From 0933cd3881ddb2c6cb7bb9df911015609ab3f58a Mon Sep 17 00:00:00 2001 From: Mison Date: Mon, 23 Mar 2026 12:12:15 +0800 Subject: [PATCH] fix(login): add workspace backoff and cleanup guard --- src/core/login.py | 78 +++++++++++++++++++++----------------- src/core/register.py | 22 +++++++++++ tests/test_login_engine.py | 57 ++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 34 deletions(-) create mode 100644 tests/test_login_engine.py diff --git a/src/core/login.py b/src/core/login.py index 1520f77..5ab780d 100644 --- a/src/core/login.py +++ b/src/core/login.py @@ -6,6 +6,7 @@ import urllib.parse import base64 import json as json_module +import time from datetime import datetime from typing import Optional, Dict, Any @@ -100,7 +101,6 @@ class LoginEngine(RegistrationEngine): def _send_verification_code_passwordless(self) -> bool: """发送验证码""" try: - import time # 记录发送时间戳 self._otp_sent_at = time.time() response = self.session.post( @@ -118,46 +118,54 @@ class LoginEngine(RegistrationEngine): self._log(f"发送验证码失败: {e}", "error") return False + def _decode_workspace_id(self, auth_cookie: str) -> str: + """从授权 Cookie 中解析 Workspace ID""" + segments = auth_cookie.split(".") + if len(segments) < 1: + raise ValueError("授权 Cookie 格式错误") + + payload = segments[0] + pad = "=" * ((4 - (len(payload) % 4)) % 4) + decoded = base64.urlsafe_b64decode((payload + pad).encode("ascii")) + auth_json = json_module.loads(decoded.decode("utf-8")) + + workspaces = auth_json.get("workspaces") or [] + if not workspaces: + raise ValueError("授权 Cookie 里没有 workspace 信息") + + workspace_id = str((workspaces[0] or {}).get("id") or "").strip() + if not workspace_id: + raise ValueError("无法解析 workspace_id") + + return workspace_id + def _get_workspace_id(self) -> Optional[str]: """获取 Workspace ID""" - try: - auth_cookie = self.session.cookies.get("oai-client-auth-session") - if not auth_cookie: - self._log("未能获取到授权 Cookie", "error") - return None + backoff_seconds = (1, 2, 4) + max_attempts = len(backoff_seconds) + 1 + for attempt in range(1, max_attempts + 1): try: - segments = auth_cookie.split(".") - if len(segments) < 1: - self._log("授权 Cookie 格式错误", "error") - return None - - # 解码第一个 segment - payload = segments[0] - pad = "=" * ((4 - (len(payload) % 4)) % 4) - decoded = base64.urlsafe_b64decode((payload + pad).encode("ascii")) - auth_json = json_module.loads(decoded.decode("utf-8")) - - workspaces = auth_json.get("workspaces") or [] - if not workspaces: - self._log("授权 Cookie 里没有 workspace 信息", "error") - return None - - workspace_id = str((workspaces[0] or {}).get("id") or "").strip() - if not workspace_id: - self._log("无法解析 workspace_id", "error") - return None - - self._log(f"Workspace ID: {workspace_id}") - return workspace_id + auth_cookie = self.session.cookies.get("oai-client-auth-session") + if auth_cookie: + workspace_id = self._decode_workspace_id(auth_cookie) + self._log(f"Workspace ID: {workspace_id}") + return workspace_id + raise ValueError("未能获取到授权 Cookie") except Exception as e: - self._log(f"解析授权 Cookie 失败: {e}", "error") - return None + level = "warning" if attempt < max_attempts else "error" + self._log( + f"获取 Workspace ID 失败: {e} (第 {attempt}/{max_attempts} 次)", + level, + ) - except Exception as e: - self._log(f"获取 Workspace ID 失败: {e}", "error") - return None + if attempt < max_attempts: + wait_seconds = backoff_seconds[attempt - 1] + self._log(f"等待 {wait_seconds} 秒后重试 Workspace ID", "warning") + time.sleep(wait_seconds) + + return None def _select_workspace(self, workspace_id: str) -> Optional[str]: """选择 Workspace""" @@ -464,3 +472,5 @@ class LoginEngine(RegistrationEngine): self._log(f"注册过程中发生未预期错误: {e}", "error") result.error_message = str(e) return result + finally: + self.close() diff --git a/src/core/register.py b/src/core/register.py index 431e093..9f7455e 100644 --- a/src/core/register.py +++ b/src/core/register.py @@ -211,6 +211,28 @@ class RegistrationEngine: self._log(f"初始化会话失败: {e}", 'error') return False + def close(self): + """关闭注册流程占用的资源""" + if self.session: + try: + self.session.close() + except Exception as e: + self._log(f"关闭注册会话失败: {e}", "warning") + finally: + self.session = None + + try: + self.http_client.close() + except Exception as e: + self._log(f"关闭 HTTP 客户端失败: {e}", "warning") + + close_email_service = getattr(self.email_service, "close", None) + if callable(close_email_service): + try: + close_email_service() + except Exception as e: + self._log(f"关闭邮箱服务失败: {e}", "warning") + def _get_device_id(self) -> Optional[str]: """获取 Device ID""" if not self.oauth_start: diff --git a/tests/test_login_engine.py b/tests/test_login_engine.py new file mode 100644 index 0000000..904c8ad --- /dev/null +++ b/tests/test_login_engine.py @@ -0,0 +1,57 @@ +import base64 +import json +from types import SimpleNamespace + +from src.core.login import LoginEngine + + +def _build_auth_cookie(workspace_id: str) -> str: + payload = base64.urlsafe_b64encode( + json.dumps({"workspaces": [{"id": workspace_id}]}).encode("utf-8") + ).decode("ascii").rstrip("=") + return f"{payload}.signature" + + +def test_get_workspace_id_retries_with_exponential_backoff(monkeypatch): + engine = LoginEngine.__new__(LoginEngine) + engine.logs = [] + engine._log = lambda message, level="info": engine.logs.append((level, message)) + + auth_cookie = _build_auth_cookie("ws-123") + cookies = SimpleNamespace() + calls = {"count": 0} + + def fake_get(name): + assert name == "oai-client-auth-session" + calls["count"] += 1 + if calls["count"] < 4: + return None + return auth_cookie + + cookies.get = fake_get + engine.session = SimpleNamespace(cookies=cookies) + + sleeps = [] + monkeypatch.setattr("src.core.login.time.sleep", lambda seconds: sleeps.append(seconds)) + + workspace_id = engine._get_workspace_id() + + assert workspace_id == "ws-123" + assert calls["count"] == 4 + assert sleeps == [1, 2, 4] + + +def test_run_always_closes_resources_on_early_return(): + engine = LoginEngine.__new__(LoginEngine) + engine.logs = [] + engine._log = lambda message, level="info": None + engine.close_called = False + engine.close = lambda: setattr(engine, "close_called", True) + + engine._check_ip_location = lambda: (False, "blocked") + + result = engine.run() + + assert result.success is False + assert result.error_message == "IP 地理位置不支持: blocked" + assert engine.close_called is True