mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-06 20:02:51 +08:00
fix(login): add workspace backoff and cleanup guard
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
import urllib.parse
|
||||
import base64
|
||||
import json as json_module
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
@@ -100,7 +101,6 @@ class LoginEngine(RegistrationEngine):
|
||||
def _send_verification_code_passwordless(self) -> bool:
|
||||
"""发送验证码"""
|
||||
try:
|
||||
import time
|
||||
# 记录发送时间戳
|
||||
self._otp_sent_at = time.time()
|
||||
response = self.session.post(
|
||||
@@ -118,21 +118,12 @@ class LoginEngine(RegistrationEngine):
|
||||
self._log(f"发送验证码失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _get_workspace_id(self) -> Optional[str]:
|
||||
"""获取 Workspace ID"""
|
||||
try:
|
||||
auth_cookie = self.session.cookies.get("oai-client-auth-session")
|
||||
if not auth_cookie:
|
||||
self._log("未能获取到授权 Cookie", "error")
|
||||
return None
|
||||
|
||||
try:
|
||||
def _decode_workspace_id(self, auth_cookie: str) -> str:
|
||||
"""从授权 Cookie 中解析 Workspace ID"""
|
||||
segments = auth_cookie.split(".")
|
||||
if len(segments) < 1:
|
||||
self._log("授权 Cookie 格式错误", "error")
|
||||
return None
|
||||
raise ValueError("授权 Cookie 格式错误")
|
||||
|
||||
# 解码第一个 segment
|
||||
payload = segments[0]
|
||||
pad = "=" * ((4 - (len(payload) % 4)) % 4)
|
||||
decoded = base64.urlsafe_b64decode((payload + pad).encode("ascii"))
|
||||
@@ -140,23 +131,40 @@ class LoginEngine(RegistrationEngine):
|
||||
|
||||
workspaces = auth_json.get("workspaces") or []
|
||||
if not workspaces:
|
||||
self._log("授权 Cookie 里没有 workspace 信息", "error")
|
||||
return None
|
||||
raise ValueError("授权 Cookie 里没有 workspace 信息")
|
||||
|
||||
workspace_id = str((workspaces[0] or {}).get("id") or "").strip()
|
||||
if not workspace_id:
|
||||
self._log("无法解析 workspace_id", "error")
|
||||
return None
|
||||
raise ValueError("无法解析 workspace_id")
|
||||
|
||||
return workspace_id
|
||||
|
||||
def _get_workspace_id(self) -> Optional[str]:
|
||||
"""获取 Workspace ID"""
|
||||
backoff_seconds = (1, 2, 4)
|
||||
max_attempts = len(backoff_seconds) + 1
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
auth_cookie = self.session.cookies.get("oai-client-auth-session")
|
||||
if auth_cookie:
|
||||
workspace_id = self._decode_workspace_id(auth_cookie)
|
||||
self._log(f"Workspace ID: {workspace_id}")
|
||||
return workspace_id
|
||||
|
||||
raise ValueError("未能获取到授权 Cookie")
|
||||
except Exception as e:
|
||||
self._log(f"解析授权 Cookie 失败: {e}", "error")
|
||||
return None
|
||||
level = "warning" if attempt < max_attempts else "error"
|
||||
self._log(
|
||||
f"获取 Workspace ID 失败: {e} (第 {attempt}/{max_attempts} 次)",
|
||||
level,
|
||||
)
|
||||
|
||||
if attempt < max_attempts:
|
||||
wait_seconds = backoff_seconds[attempt - 1]
|
||||
self._log(f"等待 {wait_seconds} 秒后重试 Workspace ID", "warning")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"获取 Workspace ID 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _select_workspace(self, workspace_id: str) -> Optional[str]:
|
||||
@@ -464,3 +472,5 @@ class LoginEngine(RegistrationEngine):
|
||||
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
||||
result.error_message = str(e)
|
||||
return result
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
@@ -211,6 +211,28 @@ class RegistrationEngine:
|
||||
self._log(f"初始化会话失败: {e}", 'error')
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""关闭注册流程占用的资源"""
|
||||
if self.session:
|
||||
try:
|
||||
self.session.close()
|
||||
except Exception as e:
|
||||
self._log(f"关闭注册会话失败: {e}", "warning")
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
try:
|
||||
self.http_client.close()
|
||||
except Exception as e:
|
||||
self._log(f"关闭 HTTP 客户端失败: {e}", "warning")
|
||||
|
||||
close_email_service = getattr(self.email_service, "close", None)
|
||||
if callable(close_email_service):
|
||||
try:
|
||||
close_email_service()
|
||||
except Exception as e:
|
||||
self._log(f"关闭邮箱服务失败: {e}", "warning")
|
||||
|
||||
def _get_device_id(self) -> Optional[str]:
|
||||
"""获取 Device ID"""
|
||||
if not self.oauth_start:
|
||||
|
||||
57
tests/test_login_engine.py
Normal file
57
tests/test_login_engine.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import base64
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.core.login import LoginEngine
|
||||
|
||||
|
||||
def _build_auth_cookie(workspace_id: str) -> str:
|
||||
payload = base64.urlsafe_b64encode(
|
||||
json.dumps({"workspaces": [{"id": workspace_id}]}).encode("utf-8")
|
||||
).decode("ascii").rstrip("=")
|
||||
return f"{payload}.signature"
|
||||
|
||||
|
||||
def test_get_workspace_id_retries_with_exponential_backoff(monkeypatch):
|
||||
engine = LoginEngine.__new__(LoginEngine)
|
||||
engine.logs = []
|
||||
engine._log = lambda message, level="info": engine.logs.append((level, message))
|
||||
|
||||
auth_cookie = _build_auth_cookie("ws-123")
|
||||
cookies = SimpleNamespace()
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_get(name):
|
||||
assert name == "oai-client-auth-session"
|
||||
calls["count"] += 1
|
||||
if calls["count"] < 4:
|
||||
return None
|
||||
return auth_cookie
|
||||
|
||||
cookies.get = fake_get
|
||||
engine.session = SimpleNamespace(cookies=cookies)
|
||||
|
||||
sleeps = []
|
||||
monkeypatch.setattr("src.core.login.time.sleep", lambda seconds: sleeps.append(seconds))
|
||||
|
||||
workspace_id = engine._get_workspace_id()
|
||||
|
||||
assert workspace_id == "ws-123"
|
||||
assert calls["count"] == 4
|
||||
assert sleeps == [1, 2, 4]
|
||||
|
||||
|
||||
def test_run_always_closes_resources_on_early_return():
|
||||
engine = LoginEngine.__new__(LoginEngine)
|
||||
engine.logs = []
|
||||
engine._log = lambda message, level="info": None
|
||||
engine.close_called = False
|
||||
engine.close = lambda: setattr(engine, "close_called", True)
|
||||
|
||||
engine._check_ip_location = lambda: (False, "blocked")
|
||||
|
||||
result = engine.run()
|
||||
|
||||
assert result.success is False
|
||||
assert result.error_message == "IP 地理位置不支持: blocked"
|
||||
assert engine.close_called is True
|
||||
Reference in New Issue
Block a user