diff --git a/src/config/constants.py b/src/config/constants.py index 1ba4d0e..46d887b 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -68,9 +68,10 @@ OPENAI_API_ENDPOINTS = { "passwordless_send_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp", "validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate", "create_account": "https://auth.openai.com/api/accounts/create_account", - "add_phone" : "https://auth.openai.com/add-phone", + "add_phone": "https://auth.openai.com/add-phone", "select_workspace": "https://auth.openai.com/api/accounts/workspace/select", - "password_verify" : "https://auth.openai.com/api/accounts/password/verify" + "send_passwordless_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp", + "password_verify": "https://auth.openai.com/api/accounts/password/verify", } # OpenAI 页面类型(用于判断账号状态) diff --git a/src/core/http_client.py b/src/core/http_client.py index f3dd876..d279f20 100644 --- a/src/core/http_client.py +++ b/src/core/http_client.py @@ -288,7 +288,7 @@ class OpenAIHTTPClient(HTTPClient): except Exception as e: logger.error(f"检查 IP 地理位置失败: {e}") - return False, None + return False, str(e) def send_openai_request( self, @@ -417,4 +417,4 @@ def create_openai_client( Returns: OpenAIHTTPClient 实例 """ - return OpenAIHTTPClient(proxy_url, config) \ No newline at end of file + return OpenAIHTTPClient(proxy_url, config) diff --git a/src/core/register.py b/src/core/register.py index b63adaf..278a381 100644 --- a/src/core/register.py +++ b/src/core/register.py @@ -6,6 +6,7 @@ import re import json import time +import hashlib import logging import secrets import string @@ -34,6 +35,18 @@ from ..config.settings import get_settings logger = logging.getLogger(__name__) +WORKSPACE_PROBE_BACKOFF_DELAYS = (1.0, 2.0, 4.0) +WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS = len(WORKSPACE_PROBE_BACKOFF_DELAYS) +PROVIDER_VERIFICATION_TIMEOUTS = { + EmailServiceType.OUTLOOK: 180, + EmailServiceType.IMAP_MAIL: 150, + EmailServiceType.MOE_MAIL: 150, + EmailServiceType.TEMPMAIL: 120, + EmailServiceType.TEMP_MAIL: 120, + EmailServiceType.DUCK_MAIL: 90, + EmailServiceType.FREEMAIL: 120, +} + @dataclass class RegistrationResult: @@ -81,6 +94,15 @@ class SignupFormResult: error_message: str = "" +@dataclass +class OTPValidationResult: + """OTP 校验结果""" + success: bool + continue_url: str = "" + response_data: Dict[str, Any] = None + error_message: str = "" + + class RegistrationEngine: """ 注册引擎 @@ -132,6 +154,22 @@ class RegistrationEngine: self.logs: list = [] self._otp_sent_at: Optional[float] = None # OTP 发送时间戳 self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录) + self._last_continue_url: Optional[str] = None # 最近一次 OTP/Workspace 返回的 continue_url + self._last_cookie_probe: Optional[Dict[str, Any]] = None # 最近一次 Cookie 探针快照 + + def close(self): + """关闭底层 HTTP 资源,避免批量任务长期运行时连接泄漏。""" + try: + self._close_http_client() + except Exception as e: + logger.warning(f"关闭注册引擎 HTTP 客户端失败: {e}") + + def _close_http_client(self): + """关闭底层 HTTP 资源并清理关联会话引用。""" + try: + self.http_client.close() + finally: + self.session = None def _log(self, message: str, level: str = "info"): """记录日志""" @@ -266,7 +304,7 @@ class RegistrationEngine: if attempt < max_attempts: time.sleep(attempt) - self.http_client.close() + self._close_http_client() self.session = self.http_client.session return None @@ -466,51 +504,111 @@ class RegistrationEngine: self._log(f"发送验证码失败: {e}", 'error') return False - def _get_verification_code(self) -> Optional[str]: + def _get_verification_code(self, otp_sent_at: Optional[float] = None) -> Optional[str]: """获取验证码""" try: - self._log(f"正在等待邮箱 {self.email} 的验证码...") + effective_otp_sent_at = otp_sent_at if otp_sent_at is not None else self._otp_sent_at + verification_timeout = self._get_provider_verification_timeout() + if effective_otp_sent_at: + self._log( + f"正在等待邮箱 {self.email} 的验证码... " + f"(provider={self.email_service.service_type.value}, timeout={verification_timeout}s, " + f"otp_sent_at={effective_otp_sent_at:.3f})" + ) + else: + self._log( + f"正在等待邮箱 {self.email} 的验证码... " + f"(provider={self.email_service.service_type.value}, timeout={verification_timeout}s)" + ) email_id = self.email_info.get("service_id") if self.email_info else None code = self.email_service.get_verification_code( email=self.email, email_id=email_id, - timeout=120, + timeout=verification_timeout, pattern=OTP_CODE_PATTERN, - otp_sent_at=self._otp_sent_at, + otp_sent_at=effective_otp_sent_at, ) if code: self._log(f"成功获取验证码: {code}") return code - else: - self._log("等待验证码超时", 'error') + + self._log("首次等待验证码超时,尝试主动刷新邮箱会话后再重试一次...", "warning") + if not self._refresh_verification_session(): + self._log("邮箱会话刷新失败,停止验证码重试", "error") return None + retry_code = self.email_service.get_verification_code( + email=self.email, + email_id=email_id, + timeout=verification_timeout, + pattern=OTP_CODE_PATTERN, + otp_sent_at=effective_otp_sent_at, + ) + if retry_code: + self._log(f"刷新会话后成功获取验证码: {retry_code}") + return retry_code + + self._log("等待验证码超时", "error") + return None + except Exception as e: self._log(f"获取验证码失败: {e}", 'error') return None - def _validate_verification_code(self, code: str) -> bool: - """验证验证码""" + def _get_provider_verification_timeout(self) -> int: + """根据邮箱 provider 推导验证码等待超时。""" + settings = get_settings() + base_timeout = max(int(getattr(settings, "email_code_timeout", 120) or 120), 30) + provider_timeout = PROVIDER_VERIFICATION_TIMEOUTS.get(self.email_service.service_type, base_timeout) + return max(base_timeout, provider_timeout) + + def _refresh_verification_session(self) -> bool: + """主动刷新邮箱服务会话,给验证码二次探测提供新连接状态。""" try: - code_body = f'{{"code":"{code}"}}' + refresh_handler = getattr(self.email_service, "refresh_session", None) + if callable(refresh_handler): + refresh_handler() + self._log("邮箱服务已执行 refresh_session()") + return True - response = self.session.post( - OPENAI_API_ENDPOINTS["validate_otp"], - headers={ - "referer": "https://auth.openai.com/email-verification", - "accept": "application/json", - "content-type": "application/json", - }, - data=code_body, - ) + refreshed = False + http_client = getattr(self.email_service, "http_client", None) + if http_client: + client_cls = http_client.__class__ + proxy_url = getattr(http_client, "proxy_url", None) + config = getattr(http_client, "config", None) + try: + http_client.close() + except Exception: + pass + setattr(self.email_service, "http_client", client_cls(proxy_url=proxy_url, config=config)) + refreshed = True - self._log(f"验证码校验状态: {response.status_code}") - return response.status_code == 200 + providers = getattr(self.email_service, "_providers", None) + provider_lock = getattr(self.email_service, "_provider_lock", None) + if isinstance(providers, dict): + if provider_lock: + with provider_lock: + providers.clear() + else: + providers.clear() + refreshed = True + reset_provider_health = getattr(self.email_service, "reset_provider_health", None) + if callable(reset_provider_health): + reset_provider_health() + refreshed = True + + if refreshed: + self._log("邮箱服务底层会话已主动刷新") + return True + + self._log("当前邮箱服务未暴露可刷新的会话句柄", "warning") + return False except Exception as e: - self._log(f"验证验证码失败: {e}", 'error') + self._log(f"刷新邮箱会话失败: {e}", "warning") return False def _create_user_account(self) -> bool: @@ -542,18 +640,756 @@ class RegistrationEngine: self._log(f"创建账户失败: {e}", 'error') return False - def _add_phone(self) -> bool: - """获取 手机验证码""" - phone_body = f'{{"code":"{code}"}}' - response = self.session.post( - OPENAI_API_ENDPOINTS["add_phone"], - headers={ - "referer": "https://auth.openai.com/api/accounts/create_account", + def _fingerprint_cookie_value(self, value: Optional[str]) -> str: + """生成 Cookie 指纹,便于审计变化而不直接输出完整值""" + normalized = str(value or "").strip() + if not normalized: + return "-" + return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:12] + + def _build_cookie_probe(self) -> Dict[str, str]: + """构建关键 Cookie 状态快照""" + cookies = self.session.cookies + auth_cookie = cookies.get("oai-client-auth-session") + session_cookie = cookies.get("__Secure-next-auth.session-token") + device_cookie = cookies.get("oai-did") + auth_segments = len((auth_cookie or "").split(".")) if auth_cookie else 0 + workspace_id = None + workspace_source = "auth_cookie_missing" + + if auth_cookie: + workspace_id, workspace_source = self._decode_workspace_id_from_auth_cookie(auth_cookie) + + return { + "did": "Y" if device_cookie else "N", + "did_fp": self._fingerprint_cookie_value(device_cookie), + "auth": "Y" if auth_cookie else "N", + "auth_segments": str(auth_segments), + "auth_jwt_state": "complete_jwt" if auth_segments == 3 else "partial_or_empty", + "auth_fp": self._fingerprint_cookie_value(auth_cookie), + "next_auth": "Y" if session_cookie else "N", + "next_auth_fp": self._fingerprint_cookie_value(session_cookie), + "workspace": workspace_id or "-", + "workspace_source": workspace_source, + } + + def _log_cookie_state(self, stage: str, reset_baseline: bool = False): + """记录关键 Cookie 状态与变化,用于观测状态机是否闭环""" + if not self.session: + self._log(f"{stage}: Session 未初始化", "warning") + return + + probe = self._build_cookie_probe() + + self._log( + f"{stage}: Cookie 探针 did={probe['did']}({probe['did_fp']}), " + f"auth={probe['auth']}(segments={probe['auth_segments']},state={probe['auth_jwt_state']},fp={probe['auth_fp']}), " + f"next_auth={probe['next_auth']}({probe['next_auth_fp']}), " + f"workspace={probe['workspace']}(source={probe['workspace_source']})" + ) + + previous_probe = None if reset_baseline else self._last_cookie_probe + if previous_probe: + changes = [] + tracked_fields = ( + ("did", "did"), + ("did_fp", "did_fp"), + ("auth", "auth"), + ("auth_segments", "auth_segments"), + ("auth_jwt_state", "auth_jwt_state"), + ("auth_fp", "auth_fp"), + ("next_auth", "next_auth"), + ("next_auth_fp", "next_auth_fp"), + ("workspace", "workspace"), + ("workspace_source", "workspace_source"), + ) + + for key, label in tracked_fields: + old_value = previous_probe.get(key) + new_value = probe.get(key) + if old_value != new_value: + changes.append(f"{label}:{old_value}->{new_value}") + + if changes: + self._log(f"{stage}: Cookie 变化 {'; '.join(changes)}") + else: + self._log(f"{stage}: Cookie 无变化") + else: + baseline_action = "重置" if reset_baseline else "建立" + self._log(f"{stage}: Cookie 探针基线已{baseline_action}") + + self._last_cookie_probe = probe + + def _decode_workspace_id_from_auth_cookie(self, auth_cookie: str) -> Tuple[Optional[str], str]: + """从授权 Cookie 中解析 Workspace ID""" + import base64 + import json as json_module + + cookie_value = str(auth_cookie or "").strip() + if not cookie_value: + return None, "empty_auth_cookie" + + segments = cookie_value.split(".") + decoded_payloads = [] + + for index, payload in enumerate(segments[:3]): + raw_segment = str(payload or "").strip() + if not raw_segment: + continue + + try: + pad = "=" * ((4 - (len(raw_segment) % 4)) % 4) + decoded = base64.urlsafe_b64decode((raw_segment + pad).encode("ascii")) + parsed = json_module.loads(decoded.decode("utf-8")) + if isinstance(parsed, dict): + decoded_payloads.append((index, parsed)) + except Exception: + continue + + if not decoded_payloads: + return None, f"cookie_decode_failed:segments={len(segments)}" + + for index, payload in decoded_payloads: + workspaces = payload.get("workspaces") or [] + if not workspaces: + continue + + workspace_id = str((workspaces[0] or {}).get("id") or "").strip() + if workspace_id: + return workspace_id, f"cookie_segment_{index}" + + return None, f"workspace_missing:segments={len(segments)}" + + def _try_upgrade_cookie_with_continue_url(self, continue_url: str, stage: str) -> Optional[str]: + """访问 continue_url,驱动授权 Cookie 升级后再次解析 workspace""" + normalized_url = str(continue_url or "").strip() + if not normalized_url: + self._log(f"{stage}: continue_url 为空,无法升级 Cookie", "warning") + return None + + self._last_continue_url = normalized_url + + for attempt, probe_delay in enumerate(WORKSPACE_PROBE_BACKOFF_DELAYS, start=1): + try: + self._log( + f"{stage}: 第 {attempt}/{WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS} 次访问 continue_url 以升级授权 Cookie" + ) + response = self.session.get( + normalized_url, + allow_redirects=True, + timeout=15, + ) + self._log(f"{stage}: Continue URL 响应状态 {response.status_code}") + self._log_cookie_state(f"{stage} 后") + + workspace_id = self._get_workspace_id(log_missing=False) + if workspace_id: + self._log(f"{stage}: Continue URL 升级后获取到 Workspace ID") + return workspace_id + + self._log( + f"{stage}: 首次探测仍未拿到 Workspace,等待 {probe_delay:.1f}s 后执行二次探测..." + ) + time.sleep(probe_delay) + workspace_id = self._get_workspace_id(log_missing=False) + if workspace_id: + self._log(f"{stage}: 二次探测后获取到 Workspace ID") + return workspace_id + except Exception as e: + self._log(f"{stage}: Continue URL 访问异常: {e}", "warning") + + if attempt < WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS: + self._log( + f"{stage}: 第 {attempt} 次二次探测仍未拿到 Workspace,继续下一轮 Continue URL 补偿..." + ) + + self._log(f"{stage}: Continue URL 已访问但 Workspace 仍不可用", "warning") + return None + + def _reset_oauth_session(self) -> bool: + """重置 OAuth 会话,用于降级登录流程""" + try: + self._log("正在重置 HTTP 会话以开始降级登录流程...") + if self.session: + self._log_cookie_state("降级流程会话重置前") + # 关闭旧客户端 + self._close_http_client() + # 创建新客户端和会话 + self.http_client = OpenAIHTTPClient(proxy_url=self.proxy_url) + self.session = self.http_client.session + self.session_token = None + self._last_continue_url = None + self._last_cookie_probe = None + + # 生成新的 OAuth URL + self.oauth_start = self.oauth_manager.start_oauth() + self._log(f"新 OAuth URL 已生成,准备开始登录握手") + self._log_cookie_state("降级流程会话重置后", reset_baseline=True) + return True + except Exception as e: + self._log(f"重置会话失败: {e}", "error") + return False + + def _submit_login_form(self, did: str, sen_token: Optional[str]) -> SignupFormResult: + """提交登录表单(screen_hint=login)""" + try: + login_body = f'{{"username":{{"value":"{self.email}","kind":"email"}},"screen_hint":"login"}}' + + headers = { + "referer": "https://auth.openai.com/login", "accept": "application/json", "content-type": "application/json", - }, - data=create_account_body, + } + + if sen_token: + sentinel = f'{{"p": "", "t": "", "c": "{sen_token}", "id": "{did}", "flow": "authorize_continue"}}' + headers["openai-sentinel-token"] = sentinel + + response = self.session.post( + OPENAI_API_ENDPOINTS["signup"], + headers=headers, + data=login_body, + ) + + self._log(f"登录表单提交状态: {response.status_code}") + self._log_cookie_state("降级登录提交邮箱后") + if response.status_code != 200: + return SignupFormResult(success=False, error_message=f"登录表单提交失败: HTTP {response.status_code}") + + response_data: Dict[str, Any] = {} + page_type = "" + try: + parsed = response.json() + if isinstance(parsed, dict): + response_data = parsed + page_type = str(parsed.get("page", {}).get("type", "") or "").strip() + else: + self._log(f"登录表单响应不是对象: {type(parsed).__name__},按 HTTP 200 继续无密码 OTP", "warning") + except Exception as parse_error: + self._log(f"解析登录表单响应失败: {parse_error},按 HTTP 200 继续无密码 OTP", "warning") + + if page_type: + self._log(f"登录表单响应页面类型: {page_type}") + else: + self._log("登录表单响应未返回 page.type,按 HTTP 200 继续无密码 OTP", "warning") + + if page_type == "login_password": + self._log("降级登录进入 login_password,按 Issue #62 继续触发无密码 OTP") + elif page_type == OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"]: + self._log("降级登录进入 email_otp_verification,继续执行无密码 OTP 闭环") + elif page_type: + self._log(f"降级登录进入 {page_type},按 Issue #62 不做页面拦截,继续触发无密码 OTP") + + return SignupFormResult( + success=True, + page_type=page_type, + response_data=response_data, + ) + except Exception as e: + self._log(f"提交登录表单异常: {e}", "error") + return SignupFormResult(success=False, error_message=str(e)) + + def _send_passwordless_otp(self) -> bool: + """发送无密码 OTP 验证码""" + try: + self._log("正在触发无密码登录 OTP...") + otp_sent_at = time.time() + self._log_cookie_state("无密码 OTP 发送前") + + response = self.session.post( + OPENAI_API_ENDPOINTS["send_passwordless_otp"], + headers={ + "referer": "https://auth.openai.com/login/password", + "accept": "application/json", + "content-type": "application/json", + }, + data="" + ) + + self._log(f"无密码 OTP 发送状态: {response.status_code}") + self._log_cookie_state("无密码 OTP 发送后") + if response.status_code != 200: + self._log(f"无密码 OTP 发送失败响应: {response.text[:200]}", "warning") + else: + self._otp_sent_at = otp_sent_at + self._log(f"无密码 OTP 时间戳已更新: {self._otp_sent_at:.3f}") + return response.status_code == 200 + except Exception as e: + self._log(f"发送无密码 OTP 失败: {e}", "error") + return False + + def _validate_verification_code(self, code: str) -> OTPValidationResult: + """验证验证码并返回 continue_url 等状态信息""" + try: + code_body = f'{{"code":"{code}"}}' + + response = self.session.post( + OPENAI_API_ENDPOINTS["validate_otp"], + headers={ + "referer": "https://auth.openai.com/email-verification", + "accept": "application/json", + "content-type": "application/json", + }, + data=code_body, + ) + + self._log(f"验证码校验状态: {response.status_code}") + self._log_cookie_state("OTP 校验响应后") + if response.status_code != 200: + return OTPValidationResult( + success=False, + error_message=f"HTTP {response.status_code}: {response.text[:200]}" + ) + + try: + parsed = response.json() + if isinstance(parsed, dict): + data = parsed + else: + self._log( + f"验证码校验响应不是对象: {type(parsed).__name__},继续依据 Cookie 状态判断", + "warning" + ) + data = {} + except Exception as parse_error: + self._log( + f"解析验证码校验响应失败: {parse_error},继续依据 Cookie 状态判断", + "warning" + ) + data = {} + + continue_url = str(data.get("continue_url") or "").strip() + if continue_url: + self._last_continue_url = continue_url + self._log(f"验证码校验返回 Continue URL: {continue_url[:60]}...") + else: + self._log("验证码校验成功,但响应中未包含 continue_url", "warning") + + return OTPValidationResult( + success=True, + continue_url=continue_url, + response_data=data, + ) + except Exception as e: + self._log(f"验证验证码失败: {e}", "error") + return OTPValidationResult(success=False, error_message=str(e)) + + def _resolve_workspace_id(self) -> Optional[str]: + """按主闭环优先、降级补偿兜底的顺序解析 Workspace ID""" + workspace_id = self._get_workspace_id(log_missing=False) + if workspace_id: + return workspace_id + + if self._last_continue_url: + self._log("主流程首次探测未获取到 Workspace,尝试复用最近的 continue_url 升级 Cookie...") + workspace_id = self._try_upgrade_cookie_with_continue_url( + self._last_continue_url, + "主流程 Continue URL 补偿" + ) + if workspace_id: + return workspace_id + + self._log("主流程 Workspace 探测仍失败,切换到降级登录补偿流程...", "warning") + return self._fallback_to_login_flow() + + def _fallback_to_login_flow(self) -> Optional[str]: + """ + [核心修复] 降级到登录流程以获取 Workspace ID + 参考 Issue #62 解决方案 + """ + self._log("-" * 40) + self._log("检测到 Cookie 缺失 Workspace,启动降级登录补偿流程...") + + # 1. 重置会话 + if not self._reset_oauth_session(): + return None + + # 2. 获取 Device ID + did = self._get_device_id() + if not did: + self._log("降级登录未能获取 Device ID", "error") + return None + + # 3. Sentinel 检查 + sen_token = self._check_sentinel(did) + + # 4. 提交登录表单 + login_res = self._submit_login_form(did, sen_token) + if not login_res.success: + self._log(f"降级登录表单失败: {login_res.error_message}", "error") + return None + + # 5. 发送无密码 OTP + if not self._send_passwordless_otp(): + self._log("降级登录未能触发无密码 OTP", "error") + return None + + login_otp_sent_at = self._otp_sent_at + if not login_otp_sent_at: + login_otp_sent_at = time.time() + self._otp_sent_at = login_otp_sent_at + self._log("降级登录未记录到无密码 OTP 时间戳,使用当前时间兜底", "warning") + + self._log(f"降级登录 OTP 时间线已校准: {login_otp_sent_at:.3f}") + + # 6. 获取邮件验证码 + self._log("正在等待新的登录验证码...") + code = self._get_verification_code(otp_sent_at=login_otp_sent_at) + if not code: + self._log("降级登录未获取到新的 OTP 验证码", "error") + return None + + # 7. 验证 OTP -> GET continue_url -> 解析 Workspace + otp_result = self._validate_verification_code(code) + if not otp_result.success: + self._log(f"OTP 校验失败: {otp_result.error_message}", "error") + return None + + self._log_cookie_state("降级登录 OTP 校验后") + + workspace_id = self._get_workspace_id(log_missing=False) + if workspace_id: + self._log("降级登录 OTP 校验后已直接获得 Workspace,无需继续访问 continue_url") + return workspace_id + + continue_url = str(otp_result.continue_url or "").strip() + if not continue_url: + self._log("OTP 校验成功但缺少 continue_url,无法执行 Cookie 升级", "error") + return None + + self._log("降级登录 OTP 校验成功,开始 GET continue_url 升级 Cookie") + workspace_id = self._try_upgrade_cookie_with_continue_url( + continue_url, + "降级登录 Continue URL" ) + if not workspace_id: + self._log("降级登录闭环完成,但仍未拿到 Workspace", "error") + return workspace_id + + def _get_workspace_id(self, log_missing: bool = True) -> Optional[str]: + """获取 Workspace ID""" + try: + self._log_cookie_state("解析 Workspace 前") + auth_cookie = self.session.cookies.get("oai-client-auth-session") + if not auth_cookie: + if log_missing: + self._log("未能获取到授权 Cookie", "error") + return None + + try: + workspace_id, source = self._decode_workspace_id_from_auth_cookie(auth_cookie) + if not workspace_id: + if log_missing: + self._log(f"授权 Cookie 中未解析到 workspace: {source}", "warning") + return None + + self._log(f"Workspace ID 解析成功: {workspace_id} (source={source})") + return workspace_id + + except Exception as e: + if log_missing: + self._log(f"解析授权 Cookie 异常: {e}", "warning") + return None + + except Exception as e: + self._log(f"获取 Workspace ID 失败: {e}", "error") + return None + + def _select_workspace(self, workspace_id: str) -> Optional[str]: + """选择 Workspace""" + try: + select_body = f'{{"workspace_id":"{workspace_id}"}}' + + response = self.session.post( + OPENAI_API_ENDPOINTS["select_workspace"], + headers={ + "referer": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent", + "content-type": "application/json", + }, + data=select_body, + ) + + if response.status_code != 200: + self._log(f"选择 workspace 失败: {response.status_code}", "error") + self._log(f"响应: {response.text[:200]}", "warning") + return None + + continue_url = str((response.json() or {}).get("continue_url") or "").strip() + if not continue_url: + self._log("workspace/select 响应里缺少 continue_url", "error") + return None + + self._log(f"Continue URL: {continue_url[:100]}...") + return continue_url + + except Exception as e: + self._log(f"选择 Workspace 失败: {e}", "error") + return None + + def _follow_redirects(self, start_url: str) -> Optional[str]: + """跟随重定向链,寻找回调 URL""" + try: + current_url = start_url + max_redirects = 6 + + for i in range(max_redirects): + self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...") + + response = self.session.get( + current_url, + allow_redirects=False, + timeout=15 + ) + + location = response.headers.get("Location") or "" + + # 如果不是重定向状态码,停止 + if response.status_code not in [301, 302, 303, 307, 308]: + self._log(f"非重定向状态码: {response.status_code}") + break + + if not location: + self._log("重定向响应缺少 Location 头") + break + + # 构建下一个 URL + import urllib.parse + next_url = urllib.parse.urljoin(current_url, location) + + # 检查是否包含回调参数 + if "code=" in next_url and "state=" in next_url: + self._log(f"找到回调 URL: {next_url[:100]}...") + return next_url + + current_url = next_url + + self._log("未能在重定向链中找到回调 URL", "error") + return None + + except Exception as e: + self._log(f"跟随重定向失败: {e}", "error") + return None + + def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]: + """处理 OAuth 回调""" + try: + if not self.oauth_start: + self._log("OAuth 流程未初始化", "error") + return None + + self._log("处理 OAuth 回调...") + token_info = self.oauth_manager.handle_callback( + callback_url=callback_url, + expected_state=self.oauth_start.state, + code_verifier=self.oauth_start.code_verifier + ) + + self._log("OAuth 授权成功") + return token_info + + except Exception as e: + self._log(f"处理 OAuth 回调失败: {e}", "error") + return None + + def run(self) -> RegistrationResult: + """ + 执行完整的注册流程 + + 支持已注册账号自动登录: + - 如果检测到邮箱已注册,自动切换到登录流程 + - 已注册账号跳过:设置密码、发送验证码、创建用户账户 + - 共用步骤:获取验证码、验证验证码、Workspace 和 OAuth 回调 + + Returns: + RegistrationResult: 注册结果 + """ + result = RegistrationResult(success=False, logs=self.logs) + + try: + self._log("=" * 60) + self._log("开始注册流程") + self._log("=" * 60) + + # 1. 检查 IP 地理位置 + self._log("1. 检查 IP 地理位置...") + ip_ok, location = self._check_ip_location() + if not ip_ok: + result.error_message = f"IP 地理位置不支持: {location}" + self._log(f"IP 检查失败: {location}", "error") + return result + + self._log(f"IP 位置: {location}") + + # 2. 创建邮箱 + self._log("2. 创建邮箱...") + if not self._create_email(): + result.error_message = "创建邮箱失败" + return result + + result.email = self.email + + # 3. 初始化会话 + self._log("3. 初始化会话...") + if not self._init_session(): + result.error_message = "初始化会话失败" + return result + + # 4. 开始 OAuth 流程 + self._log("4. 开始 OAuth 授权流程...") + if not self._start_oauth(): + result.error_message = "开始 OAuth 流程失败" + return result + + # 5. 获取 Device ID + self._log("5. 获取 Device ID...") + did = self._get_device_id() + if not did: + result.error_message = "获取 Device ID 失败" + return result + + # 6. 检查 Sentinel 拦截 + self._log("6. 检查 Sentinel 拦截...") + sen_token = self._check_sentinel(did) + if sen_token: + self._log("Sentinel 检查通过") + else: + self._log("Sentinel 检查失败或未启用", "warning") + + # 7. 提交注册表单 + 解析响应判断账号状态 + self._log("7. 提交注册表单...") + signup_result = self._submit_signup_form(did, sen_token) + if not signup_result.success: + result.error_message = f"提交注册表单失败: {signup_result.error_message}" + return result + + # 8. [已注册账号跳过] 注册密码 + if self._is_existing_account: + self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送") + else: + self._log("8. 注册密码...") + password_ok, password = self._register_password() + if not password_ok: + result.error_message = "注册密码失败" + return result + + # 9. [已注册账号跳过] 发送验证码 + if self._is_existing_account: + self._log("9. [已注册账号] 跳过发送验证码,使用自动发送的 OTP") + # 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳 + self._otp_sent_at = time.time() + else: + self._log("9. 发送验证码...") + if not self._send_verification_code(): + result.error_message = "发送验证码失败" + return result + + # 10. 获取验证码 + self._log("10. 等待验证码...") + code = self._get_verification_code() + if not code: + result.error_message = "获取验证码失败" + return result + + # 11. 验证验证码 + self._log("11. 验证验证码...") + otp_result = self._validate_verification_code(code) + if not otp_result.success: + result.error_message = f"验证验证码失败: {otp_result.error_message}" + return result + + if otp_result.continue_url: + self._try_upgrade_cookie_with_continue_url( + otp_result.continue_url, + "主流程 OTP 校验" + ) + + # 12. [已注册账号跳过] 创建用户账户 + if self._is_existing_account: + self._log("12. [已注册账号] 跳过创建用户账户") + else: + self._log("12. 创建用户账户...") + if not self._create_user_account(): + result.error_message = "创建用户账户失败" + return result + + # 13. 获取 Workspace ID + self._log("13. 获取 Workspace ID...") + workspace_id = self._resolve_workspace_id() + + if not workspace_id: + result.error_message = "获取 Workspace ID 失败 (含降级补偿)" + return result + + result.workspace_id = workspace_id + + # 14. 选择 Workspace + self._log("14. 选择 Workspace...") + continue_url = self._select_workspace(workspace_id) + if not continue_url: + result.error_message = "选择 Workspace 失败" + return result + + # 15. 跟随重定向链 + self._log("15. 跟随重定向链...") + callback_url = self._follow_redirects(continue_url) + if not callback_url: + result.error_message = "跟随重定向链失败" + return result + + # 16. 处理 OAuth 回调 + self._log("16. 处理 OAuth 回调...") + token_info = self._handle_oauth_callback(callback_url) + if not token_info: + result.error_message = "处理 OAuth 回调失败" + return result + + # 提取账户信息 + result.account_id = token_info.get("account_id", "") + result.access_token = token_info.get("access_token", "") + result.refresh_token = token_info.get("refresh_token", "") + result.id_token = token_info.get("id_token", "") + result.password = self.password or "" # 保存密码(已注册账号为空) + + # 设置来源标记 + result.source = "login" if self._is_existing_account else "register" + + # 尝试获取 session_token 从 cookie + session_cookie = self.session.cookies.get("__Secure-next-auth.session-token") + if session_cookie: + self.session_token = session_cookie + result.session_token = session_cookie + self._log(f"获取到 Session Token") + + # 17. 完成 + self._log("=" * 60) + if self._is_existing_account: + self._log("登录成功! (已注册账号)") + else: + self._log("注册成功!") + self._log(f"邮箱: {result.email}") + self._log(f"Account ID: {result.account_id}") + self._log(f"Workspace ID: {result.workspace_id}") + self._log("=" * 60) + + result.success = True + result.metadata = { + "email_service": self.email_service.service_type.value, + "proxy_used": self.proxy_url, + "registered_at": datetime.now().isoformat(), + "is_existing_account": self._is_existing_account, + } + + return result + + except Exception as e: + self._log(f"注册过程中发生未预期错误: {e}", "error") + result.error_message = str(e) + return result + finally: + try: + self._close_http_client() + except Exception as e: + logger.warning(f"关闭注册引擎 HTTP 客户端失败: {e}") def save_to_database(self, result: RegistrationResult) -> bool: """ diff --git a/src/database/crud.py b/src/database/crud.py index e42da90..b47d09c 100644 --- a/src/database/crud.py +++ b/src/database/crud.py @@ -2,7 +2,7 @@ 数据库 CRUD 操作 """ -from typing import List, Optional, Dict, Any, Union +from typing import List, Optional, Dict, Any, Union, Iterable, Set from datetime import datetime, timedelta from sqlalchemy.orm import Session from sqlalchemy import and_, or_, desc, asc, func @@ -464,9 +464,13 @@ def get_proxies( return query.all() -def get_enabled_proxies(db: Session) -> List[Proxy]: +def get_enabled_proxies(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> List[Proxy]: """获取所有启用的代理""" - return db.query(Proxy).filter(Proxy.enabled == True).all() + query = db.query(Proxy).filter(Proxy.enabled == True) + excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])} + if excluded: + query = query.filter(~Proxy.id.in_(excluded)) + return query.all() def update_proxy( @@ -517,14 +521,18 @@ def update_proxy_last_used(db: Session, proxy_id: int) -> bool: return True -def get_random_proxy(db: Session) -> Optional[Proxy]: +def get_random_proxy(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> Optional[Proxy]: """随机获取一个启用的代理,优先返回 is_default=True 的代理""" import random + excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])} # 优先返回默认代理 - default_proxy = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True).first() + default_query = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True) + if excluded: + default_query = default_query.filter(~Proxy.id.in_(excluded)) + default_proxy = default_query.first() if default_proxy: return default_proxy - proxies = get_enabled_proxies(db) + proxies = get_enabled_proxies(db, exclude_ids=excluded) if not proxies: return None return random.choice(proxies) diff --git a/src/web/routes/registration.py b/src/web/routes/registration.py index 99807f5..4a303d4 100644 --- a/src/web/routes/registration.py +++ b/src/web/routes/registration.py @@ -6,6 +6,7 @@ import asyncio import logging import uuid import random +import re from datetime import datetime from typing import List, Optional, Dict, Tuple, Any @@ -32,7 +33,16 @@ batch_tasks: Dict[str, dict] = {} # ============== Proxy Helper Functions ============== -def get_proxy_for_registration(db) -> Tuple[Optional[str], Optional[int]]: +RETRYABLE_PROXY_ERROR_PATTERN = re.compile( + r"(?:curl(?:[^0-9]{0,8})?(35|56)\b|curl:\s*\((35|56)\))", + re.IGNORECASE, +) + + +def get_proxy_for_registration( + db, + exclude_proxy_ids: Optional[List[int]] = None, +) -> Tuple[Optional[str], Optional[int]]: """ 获取用于注册的代理 @@ -45,7 +55,7 @@ def get_proxy_for_registration(db) -> Tuple[Optional[str], Optional[int]]: Tuple[proxy_url, proxy_id]: 代理 URL 和代理 ID(如果来自代理列表) """ # 先尝试从代理列表中获取 - proxy = crud.get_random_proxy(db) + proxy = crud.get_random_proxy(db, exclude_ids=exclude_proxy_ids) if proxy: return proxy.proxy_url, proxy.id @@ -64,6 +74,27 @@ def update_proxy_usage(db, proxy_id: Optional[int]): crud.update_proxy_last_used(db, proxy_id) +def is_retryable_proxy_error(error_message: Optional[str]) -> bool: + """判断是否属于可通过切换代理自愈的 curl 网络错误。""" + message = str(error_message or "").strip() + if not message: + return False + return RETRYABLE_PROXY_ERROR_PATTERN.search(message) is not None + + +def disable_proxy_for_network_error(db, proxy_id: Optional[int], reason: str) -> bool: + """将当前数据库代理标记为失效,避免后续再次被选中。""" + if not proxy_id: + return False + + proxy = crud.update_proxy(db, proxy_id, enabled=False) + if not proxy: + return False + + logger.warning(f"代理 {proxy_id} 因网络错误已自动禁用: {reason}") + return True + + # ============== Pydantic Models ============== class RegistrationTaskCreate(BaseModel): @@ -246,51 +277,44 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: logger.error(f"任务不存在: {task_uuid}") return - # 确定使用的代理 - # 如果前端传入了代理参数,使用传入的 - # 否则从代理列表或系统设置中获取 - actual_proxy_url = proxy - proxy_id = None + resolved_email_service_id = email_service_id or task.email_service_id - if not actual_proxy_url: - actual_proxy_url, proxy_id = get_proxy_for_registration(db) - if actual_proxy_url: - logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...") - - # 更新任务的代理记录 - crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url) - - # 创建邮箱服务 - service_type = EmailServiceType(email_service_type) + # 更新 TaskManager 状态 + task_manager.update_status(task_uuid, "running") settings = get_settings() + log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id) - # 优先使用数据库中配置的邮箱服务 - if email_service_id: - from ...database.models import EmailService as EmailServiceModel - db_service = db.query(EmailServiceModel).filter( - EmailServiceModel.id == email_service_id, - EmailServiceModel.enabled == True - ).first() + def build_email_service(active_proxy_url: Optional[str]): + requested_service_type = EmailServiceType(email_service_type) - if db_service: - service_type = EmailServiceType(db_service.service_type) - config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url) - # 更新任务关联的邮箱服务 - crud.update_registration_task(db, task_uuid, email_service_id=db_service.id) - logger.info(f"使用数据库邮箱服务: {db_service.name} (ID: {db_service.id}, 类型: {service_type.value})") - else: - raise ValueError(f"邮箱服务不存在或已禁用: {email_service_id}") - else: - # 使用默认配置或传入的配置 + if resolved_email_service_id: + from ...database.models import EmailService as EmailServiceModel + db_service = db.query(EmailServiceModel).filter( + EmailServiceModel.id == resolved_email_service_id, + EmailServiceModel.enabled == True + ).first() + + if db_service: + selected_service_type = EmailServiceType(db_service.service_type) + config = _normalize_email_service_config(selected_service_type, db_service.config, active_proxy_url) + crud.update_registration_task(db, task_uuid, email_service_id=db_service.id) + logger.info( + f"使用数据库邮箱服务: {db_service.name} " + f"(ID: {db_service.id}, 类型: {selected_service_type.value})" + ) + email_service = EmailServiceFactory.create(selected_service_type, config) + return email_service, selected_service_type + raise ValueError(f"邮箱服务不存在或已禁用: {resolved_email_service_id}") + + service_type = requested_service_type if service_type == EmailServiceType.TEMPMAIL: config = { "base_url": settings.tempmail_base_url, "timeout": settings.tempmail_timeout, "max_retries": settings.tempmail_max_retries, - "proxy_url": actual_proxy_url, + "proxy_url": active_proxy_url, } elif service_type == EmailServiceType.MOE_MAIL: - # 检查数据库中是否有可用的自定义域名服务 from ...database.models import EmailService as EmailServiceModel db_service = db.query(EmailServiceModel).filter( EmailServiceModel.service_type == "moe_mail", @@ -298,21 +322,19 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: ).order_by(EmailServiceModel.priority.asc()).first() if db_service and db_service.config: - config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url) + config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url) crud.update_registration_task(db, task_uuid, email_service_id=db_service.id) logger.info(f"使用数据库自定义域名服务: {db_service.name}") elif settings.custom_domain_base_url and settings.custom_domain_api_key: config = { "base_url": settings.custom_domain_base_url, "api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "", - "proxy_url": actual_proxy_url, + "proxy_url": active_proxy_url, } else: raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置") elif service_type == EmailServiceType.OUTLOOK: - # 检查数据库中是否有可用的 Outlook 账户 from ...database.models import EmailService as EmailServiceModel, Account - # 获取所有启用的 Outlook 服务 outlook_services = db.query(EmailServiceModel).filter( EmailServiceModel.service_type == "outlook", EmailServiceModel.enabled == True @@ -327,14 +349,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: email = svc.config.get("email") if svc.config else None if not email: continue - # 检查是否已在 accounts 表中注册 existing = db.query(Account).filter(Account.email == email).first() if not existing: selected_service = svc logger.info(f"选择未注册的 Outlook 账户: {email}") break - else: - logger.info(f"跳过已注册的 Outlook 账户: {email}") + logger.info(f"跳过已注册的 Outlook 账户: {email}") if selected_service and selected_service.config: config = selected_service.config.copy() @@ -352,7 +372,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: ).order_by(EmailServiceModel.priority.asc()).first() if db_service and db_service.config: - config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url) + config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url) crud.update_registration_task(db, task_uuid, email_service_id=db_service.id) logger.info(f"使用数据库 DuckMail 服务: {db_service.name}") else: @@ -366,7 +386,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: ).order_by(EmailServiceModel.priority.asc()).first() if db_service and db_service.config: - config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url) + config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url) crud.update_registration_task(db, task_uuid, email_service_id=db_service.id) logger.info(f"使用数据库 Freemail 服务: {db_service.name}") else: @@ -380,7 +400,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: ).order_by(EmailServiceModel.priority.asc()).first() if db_service and db_service.config: - config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url) + config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url) crud.update_registration_task(db, task_uuid, email_service_id=db_service.id) logger.info(f"使用数据库 IMAP 邮箱服务: {db_service.name}") else: @@ -388,23 +408,56 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: else: config = email_service_config or {} - email_service = EmailServiceFactory.create(service_type, config) + email_service = EmailServiceFactory.create(service_type, config) + return email_service, service_type - # 在 WebSocket 状态里附带邮箱服务类型,前端可同步更新任务卡片 - task_manager.update_status(task_uuid, "running", email_service=service_type.value) + requested_proxy = proxy + exhausted_proxy_ids = set() + result = None + active_service_type = EmailServiceType(email_service_type) - # 创建注册引擎 - 使用 TaskManager 的日志回调 - log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id) + while True: + actual_proxy_url = requested_proxy + proxy_id = None - engine = LoginEngine( - email_service=email_service, - proxy_url=actual_proxy_url, - callback_logger=log_callback, - task_uuid=task_uuid - ) + if not actual_proxy_url: + actual_proxy_url, proxy_id = get_proxy_for_registration( + db, + exclude_proxy_ids=list(exhausted_proxy_ids), + ) + if actual_proxy_url: + logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...") - # 执行注册 - result = engine.run() + crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url) + email_service, active_service_type = build_email_service(actual_proxy_url) + task_manager.update_status(task_uuid, "running", email_service=active_service_type.value) + engine = LoginEngine( + email_service=email_service, + proxy_url=actual_proxy_url, + callback_logger=log_callback, + task_uuid=task_uuid + ) + + result = engine.run() + if result.success: + break + + if is_retryable_proxy_error(result.error_message): + log_callback(f"[代理] 检测到可重试网络错误: {result.error_message}") + if proxy_id and disable_proxy_for_network_error(db, proxy_id, result.error_message): + exhausted_proxy_ids.add(proxy_id) + log_callback(f"[代理] 当前代理已标记失效并从代理池移除: {proxy_id}") + + next_proxy_url, next_proxy_id = get_proxy_for_registration( + db, + exclude_proxy_ids=list(exhausted_proxy_ids), + ) + if next_proxy_url and (next_proxy_url != actual_proxy_url or next_proxy_id != proxy_id): + requested_proxy = None + log_callback(f"[代理] 切换到新代理后重试注册: {next_proxy_url[:50]}...") + continue + + break if result.success: # 更新代理使用时间 @@ -506,7 +559,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: completed_at=datetime.utcnow(), result={ **result.to_dict(), - "email_service": service_type.value, + "email_service": active_service_type.value, } ) @@ -515,7 +568,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: task_uuid, "completed", email=result.email, - email_service=service_type.value, + email_service=active_service_type.value, ) logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}") @@ -533,7 +586,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: task_uuid, "failed", error=result.error_message, - email_service=service_type.value, + email_service=active_service_type.value, ) logger.warning(f"注册任务失败: {task_uuid}, 原因: {result.error_message}") diff --git a/tests/test_register_fallback_flow.py b/tests/test_register_fallback_flow.py new file mode 100644 index 0000000..1a3027d --- /dev/null +++ b/tests/test_register_fallback_flow.py @@ -0,0 +1,439 @@ +from types import SimpleNamespace + +import pytest + +from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES +from src.core import register +from src.services.base import BaseEmailService + + +class DummyEmailService(BaseEmailService): + def __init__(self): + super().__init__(EmailServiceType.TEMPMAIL, name="dummy") + + def create_email(self, config=None): + return {"email": "tester@example.com", "service_id": "svc-1"} + + def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None): + return "123456" + + def list_emails(self, **kwargs): + return [] + + def delete_email(self, email_id): + return True + + def check_health(self): + return True + + def refresh_session(self): + return None + + +class FakeResponse: + def __init__(self, status_code=200, payload=None, text=""): + self.status_code = status_code + self._payload = payload if payload is not None else {} + self.text = text + + def json(self): + return self._payload + + +class BrokenJSONResponse(FakeResponse): + def json(self): + raise ValueError("bad json") + + +class FakeSession: + def __init__(self, post_handler=None, get_handler=None, cookies=None): + self.post_handler = post_handler + self.get_handler = get_handler + self.cookies = cookies or {} + self.post_calls = [] + self.get_calls = [] + + def post(self, url, **kwargs): + self.post_calls.append({"url": url, "kwargs": kwargs}) + if self.post_handler is None: + raise AssertionError("unexpected post call") + return self.post_handler(url, **kwargs) + + def get(self, url, **kwargs): + self.get_calls.append({"url": url, "kwargs": kwargs}) + if self.get_handler is None: + raise AssertionError("unexpected get call") + return self.get_handler(url, **kwargs) + + +class DummyHTTPClient: + def __init__(self, proxy_url=None): + self.proxy_url = proxy_url + self.session = FakeSession() + self.closed = False + + def close(self): + self.closed = True + + def post(self, url, **kwargs): + raise AssertionError("unexpected http client post") + + +class DummyOAuthManager: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def start_oauth(self): + return SimpleNamespace( + auth_url="https://auth.example/start", + state="state-1", + code_verifier="verifier-1", + redirect_uri="http://localhost/callback", + ) + + +def make_engine(monkeypatch, email_service=None): + monkeypatch.setattr( + register, + "get_settings", + lambda: SimpleNamespace( + openai_client_id="client-id", + openai_auth_url="https://auth.example/authorize", + openai_token_url="https://auth.example/token", + openai_redirect_uri="http://localhost/callback", + openai_scope="openid email profile offline_access", + ), + ) + monkeypatch.setattr(register, "OpenAIHTTPClient", DummyHTTPClient) + monkeypatch.setattr(register, "OAuthManager", DummyOAuthManager) + + engine = register.RegistrationEngine(email_service or DummyEmailService()) + engine.email = "tester@example.com" + engine.email_info = {"email": "tester@example.com", "service_id": "svc-1"} + return engine + + +@pytest.mark.parametrize( + "page_type", + [ + "login_password", + OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"], + "consent_required", + "some_other_page", + ], +) +def test_submit_login_form_accepts_any_http_200_page(monkeypatch, page_type): + engine = make_engine(monkeypatch) + engine.session = FakeSession( + post_handler=lambda url, **kwargs: FakeResponse( + status_code=200, + payload={"page": {"type": page_type}}, + ) + ) + + result = engine._submit_login_form("did-1", "sen-1") + + assert result.success is True + assert result.page_type == page_type + assert result.error_message == "" + + +def test_submit_login_form_accepts_http_200_even_when_json_is_invalid(monkeypatch): + engine = make_engine(monkeypatch) + engine.session = FakeSession( + post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200) + ) + + result = engine._submit_login_form("did-1", "sen-1") + + assert result.success is True + assert result.page_type == "" + assert result.response_data == {} + assert result.error_message == "" + + +def test_send_passwordless_otp_posts_empty_body(monkeypatch): + engine = make_engine(monkeypatch) + engine.session = FakeSession( + post_handler=lambda url, **kwargs: FakeResponse(status_code=200) + ) + + success = engine._send_passwordless_otp() + + assert success is True + assert len(engine.session.post_calls) == 1 + call = engine.session.post_calls[0] + assert call["url"] == OPENAI_API_ENDPOINTS["send_passwordless_otp"] + assert call["kwargs"]["data"] == "" + assert engine._otp_sent_at is not None + + +def test_send_passwordless_otp_does_not_update_timestamp_on_failure(monkeypatch): + engine = make_engine(monkeypatch) + engine._otp_sent_at = 1234.5 + engine.session = FakeSession( + post_handler=lambda url, **kwargs: FakeResponse(status_code=500, text="server error") + ) + + success = engine._send_passwordless_otp() + + assert success is False + assert engine._otp_sent_at == 1234.5 + + +def test_get_verification_code_passes_explicit_otp_timestamp(monkeypatch): + captured = {} + + class RecordingEmailService(DummyEmailService): + def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None): + captured["email"] = email + captured["email_id"] = email_id + captured["timeout"] = timeout + captured["pattern"] = pattern + captured["otp_sent_at"] = otp_sent_at + return "654321" + + engine = make_engine(monkeypatch, email_service=RecordingEmailService()) + + code = engine._get_verification_code(otp_sent_at=1234.5) + + assert code == "654321" + assert captured["email"] == "tester@example.com" + assert captured["email_id"] == "svc-1" + assert captured["timeout"] == 120 + assert captured["otp_sent_at"] == 1234.5 + + +def test_validate_verification_code_accepts_http_200_even_when_json_is_invalid(monkeypatch): + engine = make_engine(monkeypatch) + engine.session = FakeSession( + post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200) + ) + + result = engine._validate_verification_code("123456") + + assert result.success is True + assert result.continue_url == "" + assert result.response_data == {} + + +def test_run_closes_http_client_on_early_failure(monkeypatch): + engine = make_engine(monkeypatch) + tracking_client = DummyHTTPClient() + engine.http_client = tracking_client + + monkeypatch.setattr(engine, "_check_ip_location", lambda: (False, None)) + + result = engine.run() + + assert result.success is False + assert tracking_client.closed is True + assert engine.session is None + + +def test_fallback_to_login_flow_forces_otp_and_continue_url(monkeypatch): + engine = make_engine(monkeypatch) + steps = [] + captured = {} + + monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True) + monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1") + monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1") + monkeypatch.setattr( + engine, + "_submit_login_form", + lambda did, sen: steps.append("submit_login_form") + or register.SignupFormResult(success=True, page_type="login_password"), + ) + + def fake_send_passwordless_otp(): + steps.append("send_passwordless_otp") + engine._otp_sent_at = 4567.89 + return True + + def fake_get_verification_code(otp_sent_at=None): + steps.append("get_verification_code") + captured["otp_sent_at"] = otp_sent_at + return "123456" + + monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp) + monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code) + monkeypatch.setattr( + engine, + "_validate_verification_code", + lambda code: steps.append("validate_verification_code") + or register.OTPValidationResult(success=True, continue_url="https://auth.example/continue"), + ) + + def fake_try_upgrade(continue_url, stage): + steps.append("get_continue_url_and_parse_workspace") + captured["continue_url"] = continue_url + captured["stage"] = stage + return "ws-123" + + monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fake_try_upgrade) + + workspace_id = engine._fallback_to_login_flow() + + assert workspace_id == "ws-123" + assert steps == [ + "reset_session", + "get_device_id", + "check_sentinel", + "submit_login_form", + "send_passwordless_otp", + "get_verification_code", + "validate_verification_code", + "get_continue_url_and_parse_workspace", + ] + assert captured["continue_url"] == "https://auth.example/continue" + assert captured["stage"] == "降级登录 Continue URL" + assert captured["otp_sent_at"] == 4567.89 + + +def test_fallback_to_login_flow_requires_continue_url(monkeypatch): + engine = make_engine(monkeypatch) + + monkeypatch.setattr(engine, "_reset_oauth_session", lambda: True) + monkeypatch.setattr(engine, "_get_device_id", lambda: "did-1") + monkeypatch.setattr(engine, "_check_sentinel", lambda did: "sen-1") + monkeypatch.setattr( + engine, + "_submit_login_form", + lambda did, sen: register.SignupFormResult(success=True, page_type="login_password"), + ) + + def fake_send_passwordless_otp(): + engine._otp_sent_at = 9876.5 + return True + + monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp) + monkeypatch.setattr(engine, "_get_verification_code", lambda otp_sent_at=None: "123456") + monkeypatch.setattr( + engine, + "_validate_verification_code", + lambda code: register.OTPValidationResult(success=True, continue_url=""), + ) + + def fail_if_called(*args, **kwargs): + raise AssertionError("continue_url 缺失时不应尝试升级 Cookie") + + monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called) + + assert engine._fallback_to_login_flow() is None + + +def test_fallback_to_login_flow_accepts_workspace_without_continue_url(monkeypatch): + engine = make_engine(monkeypatch) + steps = [] + + monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True) + monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1") + monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1") + monkeypatch.setattr( + engine, + "_submit_login_form", + lambda did, sen: steps.append("submit_login_form") + or register.SignupFormResult(success=True, page_type="login_password"), + ) + + def fake_send_passwordless_otp(): + steps.append("send_passwordless_otp") + engine._otp_sent_at = 4567.89 + return True + + monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp) + monkeypatch.setattr( + engine, + "_get_verification_code", + lambda otp_sent_at=None: steps.append("get_verification_code") or "123456", + ) + monkeypatch.setattr( + engine, + "_validate_verification_code", + lambda code: steps.append("validate_verification_code") + or register.OTPValidationResult(success=True, continue_url=""), + ) + monkeypatch.setattr( + engine, + "_get_workspace_id", + lambda log_missing=True: steps.append("get_workspace_id") or "ws-cookie", + ) + + def fail_if_called(*args, **kwargs): + raise AssertionError("已有 workspace 时不应继续访问 continue_url") + + monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called) + + workspace_id = engine._fallback_to_login_flow() + + assert workspace_id == "ws-cookie" + assert steps == [ + "reset_session", + "get_device_id", + "check_sentinel", + "submit_login_form", + "send_passwordless_otp", + "get_verification_code", + "validate_verification_code", + "get_workspace_id", + ] + + +def test_get_verification_code_uses_provider_timeout_and_refreshes_once(monkeypatch): + captured = {"calls": [], "refresh_count": 0} + + class RefreshableOutlookService(DummyEmailService): + def __init__(self): + super().__init__() + self.service_type = EmailServiceType.OUTLOOK + + def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None): + captured["calls"].append( + { + "timeout": timeout, + "otp_sent_at": otp_sent_at, + "email": email, + "email_id": email_id, + } + ) + if len(captured["calls"]) == 1: + return None + return "987654" + + def refresh_session(self): + captured["refresh_count"] += 1 + + engine = make_engine(monkeypatch, email_service=RefreshableOutlookService()) + + code = engine._get_verification_code(otp_sent_at=2468.0) + + assert code == "987654" + assert captured["refresh_count"] == 1 + assert len(captured["calls"]) == 2 + assert captured["calls"][0]["timeout"] == 180 + assert captured["calls"][1]["timeout"] == 180 + assert captured["calls"][0]["otp_sent_at"] == 2468.0 + + +def test_try_upgrade_cookie_with_continue_url_retries_with_second_probe(monkeypatch): + engine = make_engine(monkeypatch) + sleep_calls = [] + workspace_results = iter([None, None, None, None, None, "ws-delayed"]) + engine.session = FakeSession( + get_handler=lambda url, **kwargs: FakeResponse(status_code=302), + cookies={}, + ) + + monkeypatch.setattr(engine, "_log_cookie_state", lambda *args, **kwargs: None) + monkeypatch.setattr(engine, "_get_workspace_id", lambda log_missing=False: next(workspace_results)) + monkeypatch.setattr(register.time, "sleep", lambda seconds: sleep_calls.append(seconds)) + + workspace_id = engine._try_upgrade_cookie_with_continue_url( + "https://auth.example/continue", + "降级登录 Continue URL", + ) + + assert workspace_id == "ws-delayed" + assert len(engine.session.get_calls) == 3 + assert sleep_calls == [1.0, 2.0, 4.0] diff --git a/tests/test_registration_proxy_failover.py b/tests/test_registration_proxy_failover.py new file mode 100644 index 0000000..d482d2b --- /dev/null +++ b/tests/test_registration_proxy_failover.py @@ -0,0 +1,102 @@ +from types import SimpleNamespace + +from src.database import crud +from src.database.session import DatabaseSessionManager +from src.web.routes import registration +from src.core.register import RegistrationResult + + +def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch, tmp_path): + manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db") + manager.create_tables() + manager.migrate_tables() + + with manager.session_scope() as session: + primary_proxy = crud.create_proxy( + session, + name="primary", + type="http", + host="127.0.0.1", + port=8001, + ) + crud.update_proxy(session, primary_proxy.id, is_default=True) + backup_proxy = crud.create_proxy( + session, + name="backup", + type="http", + host="127.0.0.1", + port=8002, + ) + email_service = crud.create_email_service( + session, + service_type="tempmail", + name="tempmail-db", + config={"base_url": "https://mail.example/api"}, + ) + crud.create_registration_task(session, task_uuid="task-proxy-failover") + primary_proxy_id = primary_proxy.id + backup_proxy_id = backup_proxy.id + email_service_id = email_service.id + + monkeypatch.setattr(registration, "get_db", manager.session_scope) + monkeypatch.setattr( + registration, + "EmailServiceFactory", + SimpleNamespace(create=lambda service_type, config: SimpleNamespace(service_type=service_type, config=config)), + ) + + attempted_proxies = [] + saved_results = [] + + class FakeRegistrationEngine: + def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None): + self.proxy_url = proxy_url + + def run(self): + attempted_proxies.append(self.proxy_url) + if self.proxy_url.endswith(":8001"): + return RegistrationResult( + success=False, + email="proxy@example.com", + error_message="OpenAI 请求失败: curl: (35) TLS handshake failed", + ) + + return RegistrationResult( + success=True, + email="proxy@example.com", + access_token="access-token", + workspace_id="ws-123", + ) + + def save_to_database(self, result): + saved_results.append(result.email) + return True + + monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine) + + registration._run_sync_registration_task( + task_uuid="task-proxy-failover", + email_service_type="tempmail", + proxy=None, + email_service_config=None, + email_service_id=email_service_id, + ) + + assert attempted_proxies == [ + "http://127.0.0.1:8001", + "http://127.0.0.1:8002", + ] + assert saved_results == ["proxy@example.com"] + + with manager.session_scope() as session: + disabled_primary = crud.get_proxy_by_id(session, primary_proxy_id) + active_backup = crud.get_proxy_by_id(session, backup_proxy_id) + task = crud.get_registration_task_by_uuid(session, "task-proxy-failover") + + assert disabled_primary is not None + assert disabled_primary.enabled is False + assert active_backup is not None + assert active_backup.enabled is True + assert task is not None + assert task.status == "completed" + assert task.proxy == "http://127.0.0.1:8002"