From 5b76619d6f733e0b83b38f51c5c38511e4a71b8b Mon Sep 17 00:00:00 2001 From: Mison Date: Tue, 24 Mar 2026 10:26:59 +0800 Subject: [PATCH] feat(register): enhance workspace extraction and phase status reporting --- src/core/register.py | 263 ++++++++++++++- src/web/routes/registration.py | 353 +++++++++++++++++++- tests/test_registration_otp_phase.py | 69 ++++ tests/test_task_manager_status_broadcast.py | 32 ++ 4 files changed, 708 insertions(+), 9 deletions(-) diff --git a/src/core/register.py b/src/core/register.py index 74fbe7c..4082e83 100644 --- a/src/core/register.py +++ b/src/core/register.py @@ -139,6 +139,7 @@ class RegistrationEngine: email_service: BaseEmailService, proxy_url: Optional[str] = None, callback_logger: Optional[Callable[[str], None]] = None, + status_callback: Optional[Callable[[Dict[str, Any]], None]] = None, task_uuid: Optional[str] = None, ): """ @@ -148,11 +149,13 @@ class RegistrationEngine: email_service: 邮箱服务实例 proxy_url: 代理 URL callback_logger: 日志回调函数 + status_callback: 状态回调函数 task_uuid: 任务 UUID(用于数据库记录) """ self.email_service = email_service self.proxy_url = proxy_url self.callback_logger = callback_logger or (lambda msg: logger.info(msg)) + self.status_callback = status_callback self.task_uuid = task_uuid # 创建 HTTP 客户端 @@ -175,6 +178,7 @@ class RegistrationEngine: self.email_info: Optional[Dict[str, Any]] = None self.oauth_start: Optional[OAuthStart] = None self.session: Optional[cffi_requests.Session] = None + self.device_id: Optional[str] = None self.session_token: Optional[str] = None # 会话令牌 self.logs: list = [] self._otp_sent_at: Optional[float] = None # OTP 发送时间戳 @@ -213,6 +217,54 @@ class RegistrationEngine: """生成随机密码""" return ''.join(secrets.choice(PASSWORD_CHARSET) for _ in range(length)) + def _emit_status(self, phase: str, detail: str, **extra): + """向外部上报阶段进度。""" + if not self.status_callback: + return + + payload = { + "phase": phase, + "phase_detail": detail, + } + if self.email: + payload["email"] = self.email + payload.update({key: value for key, value in extra.items() if value is not None}) + + try: + self.status_callback(payload) + except Exception as e: + logger.warning(f"上报任务阶段状态失败: {e}") + + def _current_device_id(self) -> Optional[str]: + """优先复用现有 Device ID,避免重复触发慢请求。""" + if self.device_id: + return self.device_id + if not self.session: + return None + + did = self.session.cookies.get("oai-did") + if did: + self.device_id = did + return did + + def _log_timed_http_result( + self, + action: str, + started_at: float, + response: Optional[Any] = None, + ): + """记录 HTTP 调用的耗时与结果。""" + elapsed = max(0.0, time.time() - started_at) + parts = [f"{action} 完成,耗时 {elapsed:.1f} 秒"] + if response is not None: + status_code = getattr(response, "status_code", None) + response_url = str(getattr(response, "url", "") or "").strip() + if status_code is not None: + parts.append(f"HTTP {status_code}") + if response_url: + parts.append(f"URL: {response_url[:120]}...") + self._log(",".join(parts)) + def _record_phase_result(self, phase_result: PhaseResult) -> PhaseResult: self.phase_history = [ item for item in self.phase_history @@ -311,19 +363,33 @@ class RegistrationEngine: if not self.oauth_start: return None + cached_did = self._current_device_id() + if cached_did: + self._log(f"复用已有 Device ID: {cached_did}") + return cached_did + max_attempts = 3 for attempt in range(1, max_attempts + 1): try: if not self.session: self.session = self.http_client.session + self._emit_status( + "oauth_device_id", + f"获取 Device ID(第 {attempt}/{max_attempts} 次)", + attempt=attempt, + max_attempts=max_attempts, + ) + started_at = time.time() response = self.session.get( self.oauth_start.auth_url, timeout=20 ) + self._log_timed_http_result("获取 Device ID 请求", started_at, response) did = self.session.cookies.get("oai-did") if did: + self.device_id = did self._log(f"Device ID: {did}") return did @@ -347,8 +413,15 @@ class RegistrationEngine: def _check_sentinel(self, did: str) -> Optional[str]: """检查 Sentinel 拦截""" try: - sen_req_body = f'{{"p":"","id":"{did}","flow":"authorize_continue"}}' + device_id = did or self._current_device_id() + if not device_id: + self._log("Sentinel 检查跳过: 缺少 Device ID", "warning") + return None + self._emit_status("sentinel", "请求 Sentinel 校验令牌") + sen_req_body = f'{{"p":"","id":"{device_id}","flow":"authorize_continue"}}' + + started_at = time.time() response = self.http_client.post( OPENAI_API_ENDPOINTS["sentinel"], headers={ @@ -358,6 +431,7 @@ class RegistrationEngine: }, data=sen_req_body, ) + self._log_timed_http_result("Sentinel 校验", started_at, response) if response.status_code == 200: sen_token = response.json().get("token") @@ -719,6 +793,53 @@ class RegistrationEngine: return workspace_id return None + def _extract_workspace_id_from_text(self, text: str) -> Optional[str]: + """从 HTML/脚本文本中提取 Workspace ID。""" + if not text: + return None + + patterns = [ + r'"workspace_id"\s*:\s*"([^"]+)"', + r'"workspaceId"\s*:\s*"([^"]+)"', + r'"default_workspace_id"\s*:\s*"([^"]+)"', + r'"defaultWorkspaceId"\s*:\s*"([^"]+)"', + r'"active_workspace_id"\s*:\s*"([^"]+)"', + r'"activeWorkspaceId"\s*:\s*"([^"]+)"', + r'"workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"', + r'"default_workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"', + r'"active_workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"', + ] + for pattern in patterns: + match = re.search(pattern, text) + if match: + workspace_id = str(match.group(1) or "").strip() + if workspace_id: + return workspace_id + return None + + def _extract_workspace_id_from_url(self, url: str) -> Optional[str]: + """从 URL 查询参数或片段中提取 Workspace ID。""" + if not url: + return None + + import urllib.parse + + parsed = urllib.parse.urlparse(url) + for raw_query in (parsed.query, parsed.fragment): + query = urllib.parse.parse_qs(raw_query) + for key in ( + "workspace_id", + "workspaceId", + "default_workspace_id", + "active_workspace_id", + ): + values = query.get(key) or [] + if values: + workspace_id = str(values[0] or "").strip() + if workspace_id: + return workspace_id + return None + def _decode_cookie_json_candidates(self, cookie_value: str) -> list[Dict[str, Any]]: """尝试从完整 Cookie 或其分段中解码出 JSON。""" decoded_objects = [] @@ -760,12 +881,25 @@ class RegistrationEngine: if workspace_id: return workspace_id - for key in ("workspace_id", "default_workspace_id", "active_workspace_id"): + for key in ( + "workspace_id", + "workspaceId", + "default_workspace_id", + "defaultWorkspaceId", + "active_workspace_id", + "activeWorkspaceId", + ): workspace_id = str(auth_json.get(key) or "").strip() if workspace_id: return workspace_id - for key in ("workspace", "default_workspace", "active_workspace"): + for key in ( + "workspace", + "default_workspace", + "active_workspace", + "defaultWorkspace", + "activeWorkspace", + ): workspace = auth_json.get(key) if not isinstance(workspace, dict): continue @@ -776,6 +910,60 @@ class RegistrationEngine: return None + def _extract_workspace_id_from_response( + self, + response: Optional[Any] = None, + html: Optional[str] = None, + url: Optional[str] = None, + ) -> Optional[str]: + """统一从响应 JSON、HTML、脚本内容和 URL 中提取 Workspace ID。""" + response_url = str(getattr(response, "url", "") or "").strip() + response_text = html if html is not None else str(getattr(response, "text", "") or "") + candidate_url = url or response_url + + if response is not None: + try: + payload = response.json() + except Exception: + payload = None + workspace_id = self._extract_workspace_id_from_response_payload(payload) + if workspace_id: + return workspace_id + + for extractor in ( + lambda: self._extract_workspace_id_from_html(response_text), + lambda: self._extract_workspace_id_from_text(response_text), + lambda: self._extract_workspace_id_from_url(candidate_url), + ): + workspace_id = extractor() + if workspace_id: + return workspace_id + + return None + + def _extract_workspace_id_from_response_payload(self, payload: Any, depth: int = 0) -> Optional[str]: + """递归扫描响应载荷中的 Workspace ID。""" + if payload is None or depth > 5: + return None + + if isinstance(payload, dict): + workspace_id = self._extract_workspace_id_from_auth_json(payload) + if workspace_id: + return workspace_id + for value in payload.values(): + workspace_id = self._extract_workspace_id_from_response_payload(value, depth + 1) + if workspace_id: + return workspace_id + return None + + if isinstance(payload, list): + for item in payload: + workspace_id = self._extract_workspace_id_from_response_payload(item, depth + 1) + if workspace_id: + return workspace_id + + return None + def _select_workspace(self, workspace_id: str) -> Optional[str]: """选择 Workspace""" try: @@ -858,12 +1046,16 @@ class RegistrationEngine: return False try: - did = self.session.cookies.get("oai-did") if self.session else None + self._emit_status("login_reentry", "重新进入登录流程") + did = self._current_device_id() sen_token = self._check_sentinel(did) if did else None + self._log("登录重入:请求 authorize 页面以确认当前表单状态") + started_at = time.time() response = self.session.get( self.oauth_start.auth_url, timeout=15, ) + self._log_timed_http_result("登录重入 authorize 页面", started_at, response) html = response.text or "" if "/log-in/password" in str(getattr(response, "url", "") or "") or 'action="/log-in/password"' in html: @@ -877,6 +1069,9 @@ class RegistrationEngine: "value": self.email, } } + self._emit_status("login_reentry", "提交邮箱以推进到密码页") + self._log("登录重入:提交邮箱到 authorize/continue") + started_at = time.time() login_response = self.session.post( "https://auth.openai.com/api/accounts/authorize/continue", headers={ @@ -902,12 +1097,20 @@ class RegistrationEngine: data=json.dumps(login_data), timeout=15, ) + self._log_timed_http_result("登录重入邮箱提交", started_at, login_response) login_json = login_response.json() if login_response.status_code == 200 else {} page_type = str((login_json or {}).get("page", {}).get("type") or "").strip() continue_url = str((login_json or {}).get("continue_url") or "").strip() + self._log( + f"登录重入响应: page_type={page_type or 'unknown'}, " + f"continue_url={continue_url[:100] + '...' if continue_url else 'none'}" + ) if continue_url: try: + self._emit_status("login_reentry", "跟进登录 continue_url") + started_at = time.time() self.session.get(continue_url, timeout=15) + self._log_timed_http_result("登录重入 continue_url", started_at) except Exception: pass if login_response.status_code == 200 and page_type in {"password", "login_password"}: @@ -926,8 +1129,10 @@ class RegistrationEngine: return False try: - did = self.session.cookies.get("oai-did") if self.session else None + self._emit_status("login_password", "提交登录密码") + did = self._current_device_id() sen_token = self._check_sentinel(did) if did else None + started_at = time.time() response = self.session.post( "https://auth.openai.com/api/accounts/password/verify", headers={ @@ -955,6 +1160,7 @@ class RegistrationEngine: }), timeout=15, ) + self._log_timed_http_result("登录密码提交", started_at, response) self._log(f"登录密码提交状态: {response.status_code}") if response.status_code == 200: try: @@ -964,7 +1170,10 @@ class RegistrationEngine: continue_url = str(payload.get("continue_url") or "").strip() if continue_url: try: + self._emit_status("login_password", "跟进密码校验 continue_url") + started_at = time.time() self.session.get(continue_url, timeout=15) + self._log_timed_http_result("密码校验 continue_url", started_at) except Exception: pass return response.status_code in (200, 302, 303) @@ -977,7 +1186,7 @@ class RegistrationEngine: return False, None try: - did = self.session.cookies.get("oai-did") if self.session else None + did = self._current_device_id() sen_token = self._check_sentinel(did) if did else None response = self.session.post( "https://auth.openai.com/api/accounts/password/verify", @@ -1061,6 +1270,7 @@ class RegistrationEngine: self._log("重新初始化登录会话失败", "warning") return None, None + self._emit_status("oauth_reentry", "重新初始化 OAuth 登录会话") if not self._start_oauth(): self._log("重新开始 OAuth 登录流程失败", "warning") return None, None @@ -1074,6 +1284,7 @@ class RegistrationEngine: return None, None self._otp_sent_at = time.time() + self._emit_status("otp_secondary", "等待登录验证码邮件") if not self._submit_login_password_step(): return None, None @@ -1089,12 +1300,16 @@ class RegistrationEngine: return None, None auth_target = consent_url or self.oauth_start.auth_url + self._emit_status("workspace_extract", "请求 consent 页面并提取 Workspace ID") + self._log(f"请求 consent 页面: {auth_target[:120]}...") + started_at = time.time() auth_response = self.session.get(auth_target, timeout=20) + self._log_timed_http_result("获取 consent 页面", started_at, auth_response) current_url = str(getattr(auth_response, "url", "") or "") html = auth_response.text or "" if "sign-in-with-chatgpt/codex/consent" in current_url or 'action="/sign-in-with-chatgpt/codex/consent"' in html: - workspace_id = self._extract_workspace_id_from_html(html) + workspace_id = self._extract_workspace_id_from_response(response=auth_response, html=html, url=current_url) if not workspace_id: self._log("consent 页面缺少 workspace_id,回退到 Cookie 解析路径", "warning") return None, None @@ -1115,13 +1330,22 @@ class RegistrationEngine: max_redirects = 6 for i in range(max_redirects): + self._emit_status( + "redirect_chain", + f"跟随重定向 {i + 1}/{max_redirects}", + redirect_index=i + 1, + redirect_total=max_redirects, + redirect_url=current_url[:200], + ) self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...") + started_at = time.time() response = self.session.get( current_url, allow_redirects=False, timeout=15 ) + self._log_timed_http_result(f"重定向跳转 {i + 1}/{max_redirects}", started_at, response) location = response.headers.get("Location") or "" @@ -1137,6 +1361,7 @@ class RegistrationEngine: # 构建下一个 URL import urllib.parse next_url = urllib.parse.urljoin(current_url, location) + self._log(f"重定向下一跳: {next_url[:100]}...") # 检查是否包含回调参数 if "code=" in next_url and "state=" in next_url: @@ -1159,12 +1384,19 @@ class RegistrationEngine: self._log("OAuth 流程未初始化", "error") return None + self._emit_status("oauth_callback", "处理 OAuth 回调并交换令牌") self._log("处理 OAuth 回调...") + started_at = time.time() token_info = self.oauth_manager.handle_callback( callback_url=callback_url, expected_state=self.oauth_start.state, code_verifier=self.oauth_start.code_verifier ) + elapsed = max(0.0, time.time() - started_at) + self._log( + f"OAuth 回调处理完成,耗时 {elapsed:.1f} 秒," + f"account_id={str(token_info.get('account_id') or '').strip() or 'unknown'}" + ) self._log("OAuth 授权成功") return token_info @@ -1197,6 +1429,7 @@ class RegistrationEngine: # 1. 检查 IP 地理位置 self._log("1. 检查 IP 地理位置...") + self._emit_status("ip_check", "检查 IP 地理位置", step_index=1) ip_ok, location = self._check_ip_location() if not ip_ok: result.error_message = f"IP 地理位置不支持: {location}" @@ -1207,6 +1440,7 @@ class RegistrationEngine: # 2. 创建邮箱 self._log("2. 创建邮箱...") + self._emit_status("email_prepare", "创建邮箱地址", step_index=2) if not self._phase_email_prepare(): email_prepare_phase = self._get_phase_result(PHASE_EMAIL_PREPARE) result.error_message = ( @@ -1221,18 +1455,21 @@ class RegistrationEngine: # 3. 初始化会话 self._log("3. 初始化会话...") + self._emit_status("session_init", "初始化 HTTP 会话", step_index=3) if not self._init_session(): result.error_message = "初始化会话失败" return result # 4. 开始 OAuth 流程 self._log("4. 开始 OAuth 授权流程...") + self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4) if not self._start_oauth(): result.error_message = "开始 OAuth 流程失败" return result # 5. 获取 Device ID self._log("5. 获取 Device ID...") + self._emit_status("oauth_device_id", "获取 Device ID", step_index=5) did = self._get_device_id() if not did: result.error_message = "获取 Device ID 失败" @@ -1240,6 +1477,7 @@ class RegistrationEngine: # 6. 检查 Sentinel 拦截 self._log("6. 检查 Sentinel 拦截...") + self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6) sen_token = self._check_sentinel(did) if sen_token: self._log("Sentinel 检查通过") @@ -1248,6 +1486,7 @@ class RegistrationEngine: # 7. 提交注册表单 + 解析响应判断账号状态 self._log("7. 提交注册表单...") + self._emit_status("signup_submit", "提交注册表单", step_index=7) signup_result = self._submit_signup_form(did, sen_token) if not signup_result.success: result.error_message = f"提交注册表单失败: {signup_result.error_message}" @@ -1258,6 +1497,7 @@ class RegistrationEngine: self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送") else: self._log("8. 注册密码...") + self._emit_status("signup_password", "提交注册密码", step_index=8) password_ok, password = self._register_password() if not password_ok: result.error_message = "注册密码失败" @@ -1270,12 +1510,14 @@ class RegistrationEngine: self._otp_sent_at = time.time() else: self._log("9. 发送验证码...") + self._emit_status("otp_send", "发送验证码", step_index=9) if not self._send_verification_code(): result.error_message = "发送验证码失败" return result # 10. 获取验证码 self._log("10. 等待验证码...") + self._emit_status("otp_secondary", "等待验证码邮件", step_index=10) otp_phase_started_at = time.time() code, otp_phase = self._phase_otp_secondary( PhaseContext(otp_sent_at=self._otp_sent_at), @@ -1290,6 +1532,7 @@ class RegistrationEngine: # 11. 验证验证码 self._log("11. 验证验证码...") + self._emit_status("otp_validate", "校验验证码", step_index=11) if not self._validate_verification_code(code): result.error_message = "验证验证码失败" return result @@ -1299,6 +1542,7 @@ class RegistrationEngine: self._log("12. [已注册账号] 跳过创建用户账户") else: self._log("12. 创建用户账户...") + self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12) if not self._create_user_account(): result.error_message = "创建用户账户失败" return result @@ -1308,6 +1552,7 @@ class RegistrationEngine: if not self._is_existing_account: self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...") + self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step) workspace_id, callback_url = self._advance_login_authorization() if workspace_id and callback_url: result.workspace_id = workspace_id @@ -1316,6 +1561,7 @@ class RegistrationEngine: if not result.workspace_id: # 获取 Workspace ID self._log(f"{next_step}. 获取 Workspace ID...") + self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step) workspace_id = self._get_workspace_id() if not workspace_id: result.error_message = "获取 Workspace ID 失败" @@ -1327,6 +1573,7 @@ class RegistrationEngine: # 选择 Workspace self._log(f"{next_step}. 选择 Workspace...") + self._emit_status("workspace_select", "选择 Workspace", step_index=next_step) continue_url = self._select_workspace(result.workspace_id) if not continue_url: result.error_message = "选择 Workspace 失败" @@ -1336,6 +1583,7 @@ class RegistrationEngine: # 跟随重定向链 self._log(f"{next_step}. 跟随重定向链...") + self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step) callback_url = self._follow_redirects(continue_url) if not callback_url: result.error_message = "跟随重定向链失败" @@ -1345,6 +1593,7 @@ class RegistrationEngine: # 处理 OAuth 回调 self._log(f"{next_step}. 处理 OAuth 回调...") + self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step) token_info = self._handle_oauth_callback(callback_url) if not token_info: result.error_message = "处理 OAuth 回调失败" diff --git a/src/web/routes/registration.py b/src/web/routes/registration.py index 7b4ea98..a497bbf 100644 --- a/src/web/routes/registration.py +++ b/src/web/routes/registration.py @@ -10,7 +10,7 @@ import random import re import time from datetime import datetime -from typing import List, Optional, Dict, Tuple +from typing import List, Optional, Dict, Tuple, Any from fastapi import APIRouter, HTTPException, Query, BackgroundTasks from pydantic import BaseModel, Field @@ -24,7 +24,7 @@ from ...core.register import ( RegistrationResult, ) from ...services import EmailServiceFactory, EmailServiceType -from ...services.base import EmailProviderBackoffState, OTPTimeoutEmailServiceError +from ...services.base import BaseEmailService, EmailProviderBackoffState, OTPTimeoutEmailServiceError from ...config.settings import get_settings from ..task_manager import task_manager @@ -136,6 +136,13 @@ class BatchRegistrationRequest(BaseModel): tm_service_ids: List[int] = [] +class MockRegistrationCreateRequest(BaseModel): + """创建受控模拟任务请求""" + email_service_type: str = "tempmail" + start_delay_ms: int = Field(default=300, ge=0, le=5000) + log_delay_ms: int = Field(default=250, ge=0, le=5000) + + class RegistrationTaskResponse(BaseModel): """注册任务响应""" id: int @@ -161,6 +168,13 @@ class BatchRegistrationResponse(BaseModel): tasks: List[RegistrationTaskResponse] +class MockRegistrationTaskCreateResponse(BaseModel): + """受控模拟任务响应""" + task: RegistrationTaskResponse + batch_id: str + checks: Dict[str, Any] + + class TaskListResponse(BaseModel): """任务列表响应""" total: int @@ -232,6 +246,19 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse: ) +def _create_task_status_callback(task_uuid: str, email_service: str): + """把引擎内部阶段进度映射到 TaskManager 状态广播。""" + + def callback(payload: Dict[str, Any]) -> None: + status_payload = { + "email_service": email_service, + **payload, + } + task_manager.update_status(task_uuid, "running", **status_payload) + + return callback + + def _normalize_email_service_config( service_type: EmailServiceType, config: Optional[dict], @@ -329,6 +356,7 @@ def _run_registration_engine_attempt( actual_proxy_url: Optional[str], log_callback, db_service, + status_callback=None, ): """执行单次注册引擎尝试,并在同一临界区内维护邮箱服务退避状态。""" provider_backoff_before_run = EmailProviderBackoffState() @@ -343,6 +371,7 @@ def _run_registration_engine_attempt( email_service=email_service, proxy_url=actual_proxy_url, callback_logger=log_callback, + status_callback=status_callback, task_uuid=task_uuid, ) @@ -570,6 +599,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: crud.update_registration_task(db, task_uuid, email_service_id=None) task_manager.update_status(task_uuid, "running", email_service=active_service_type.value) + status_callback = _create_task_status_callback(task_uuid, active_service_type.value) email_service = EmailServiceFactory.create( selected_service_type, candidate_config, @@ -587,6 +617,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: actual_proxy_url=actual_proxy_url, log_callback=log_callback, db_service=db_service, + status_callback=status_callback, ) if result.success: @@ -856,6 +887,275 @@ def _make_batch_helpers(batch_id: str): return add_batch_log, update_batch_status +class _MockBackoffEmailService(BaseEmailService): + """用于真实服务验证的最小邮箱服务桩。""" + + def __init__(self): + super().__init__(service_type=EmailServiceType.DUCK_MAIL, name="mock-backoff-service") + + def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]: + return {"email": "mock@example.test", "service_id": "mock-service-id"} + + def get_verification_code( + self, + email: str, + email_id: str = None, + timeout: int = 120, + pattern: str = r"(? Optional[str]: + return None + + def list_emails(self, **kwargs) -> List[Dict[str, Any]]: + return [] + + def delete_email(self, email_id: str) -> bool: + return True + + def check_health(self) -> bool: + return True + + +def _create_persisted_log_callback(task_uuid: str, prefix: str = "", batch_id: str = ""): + """同时写入内存日志队列、批量日志通道和数据库任务日志。""" + + def callback(message: str) -> None: + full_message = f"{prefix} {message}" if prefix else message + task_manager.add_log(task_uuid, full_message) + if batch_id: + task_manager.add_batch_log(batch_id, full_message) + with get_db() as db: + crud.append_task_log(db, task_uuid, full_message) + + return callback + + +def _simulate_batch_counter_probe(batch_id: str) -> Dict[str, Any]: + """构造一个可重复的批量计数场景,验证 TaskManager 计数收口。""" + task_uuids = [str(uuid.uuid4()) for _ in range(3)] + task_statuses = ["completed", "failed", "completed"] + _init_batch_state(batch_id, task_uuids) + add_batch_log, update_batch_status = _make_batch_helpers(batch_id) + add_batch_log(f"[系统] 模拟批量任务启动,总任务: {len(task_uuids)}") + + with get_db() as db: + for index, (task_uuid, status) in enumerate(zip(task_uuids, task_statuses), start=1): + crud.create_registration_task(db, task_uuid=task_uuid, proxy=None) + error_message = None if status == "completed" else f"mock-batch-error-{index}" + crud.update_registration_task( + db, + task_uuid, + status=status, + started_at=datetime.utcnow(), + completed_at=datetime.utcnow(), + error_message=error_message, + ) + + batch_snapshot = _get_batch_snapshot(batch_id) or {} + new_completed = batch_snapshot.get("completed", 0) + 1 + new_success = batch_snapshot.get("success", 0) + new_failed = batch_snapshot.get("failed", 0) + if status == "completed": + new_success += 1 + add_batch_log(f"[任务{index}] [成功] 模拟注册成功") + else: + new_failed += 1 + add_batch_log(f"[任务{index}] [失败] 模拟注册失败: {error_message}") + update_batch_status(completed=new_completed, success=new_success, failed=new_failed) + + batch_snapshot = _get_batch_snapshot(batch_id) or {} + add_batch_log( + f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, " + f"失败: {batch_snapshot.get('failed', 0)}" + ) + update_batch_status(finished=True, status="completed") + return { + "batch_id": batch_id, + "task_uuids": task_uuids, + "snapshot": task_manager.get_batch_status(batch_id) or {}, + } + + +async def run_mock_registration_task( + task_uuid: str, + batch_id: str, + checks: Dict[str, Any], + email_service_type: str, + start_delay_ms: int, + log_delay_ms: int, +) -> None: + """通过真实服务链路执行可重复的模拟任务。""" + if start_delay_ms > 0: + await asyncio.sleep(start_delay_ms / 1000) + + loop = task_manager.get_loop() + if loop is None: + loop = asyncio.get_event_loop() + task_manager.set_loop(loop) + + log_callback = _create_persisted_log_callback(task_uuid) + delay_seconds = max(log_delay_ms, 0) / 1000 + + try: + with get_db() as db: + task = crud.update_registration_task( + db, + task_uuid, + status="running", + started_at=datetime.utcnow(), + ) + if not task: + logger.error(f"模拟任务不存在: {task_uuid}") + return + + task_manager.update_status(task_uuid, "running", email_service=email_service_type) + log_callback("[模拟] 任务已启动,开始执行真实链路探针") + if delay_seconds: + await asyncio.sleep(delay_seconds) + + with get_db() as db: + seeded_account = crud.create_account( + db, + email=checks["seeded_account_email"], + email_service="tempmail", + access_token="mock-access-token-seeded", + refresh_token="mock-refresh-token-seeded", + ) + tokenless_account = crud.create_account( + db, + email=checks["tokenless_account_email"], + email_service="tempmail", + ) + crud.update_account( + db, + tokenless_account.id, + access_token="mock-access-token-updated", + ) + partial_account = crud.create_account( + db, + email=checks["partial_account_email"], + email_service="tempmail", + access_token="mock-access-token-partial", + refresh_token="mock-refresh-token-partial", + ) + crud.update_account( + db, + partial_account.id, + refresh_token="", + ) + outlook_service = crud.create_email_service( + db, + service_type="outlook", + name=f"mock-outlook-{task_uuid[:8]}", + config={ + "accounts": [ + {"email": "first@example.test", "refresh_token": "old-first"}, + { + "email": checks["outlook_account_email"], + "refresh_token": "old-second", + }, + ] + }, + ) + crud.update_outlook_refresh_token( + db, + service_id=outlook_service.id, + email=checks["outlook_account_email"], + new_refresh_token="new-second", + ) + backoff_service = crud.create_email_service( + db, + service_type="duck_mail", + name=checks["backoff_service_name"], + config={ + "base_url": "https://mail.example.test", + "default_domain": "example.test", + }, + ) + checks["seeded_account_id"] = seeded_account.id + checks["tokenless_account_id"] = tokenless_account.id + checks["partial_account_id"] = partial_account.id + checks["outlook_service_id"] = outlook_service.id + checks["backoff_service_id"] = backoff_service.id + log_callback("[模拟] Token 同步与 Outlook refresh_token 探针已写入数据库") + if delay_seconds: + await asyncio.sleep(delay_seconds) + + mock_email_service = _MockBackoffEmailService() + backoff_states = [] + for attempt in range(1, 4): + previous_state = _get_email_service_backoff_state(backoff_service.id) + current_state = _record_email_service_timeout_backoff( + backoff_service.id, + mock_email_service, + previous_state, + ERROR_OTP_TIMEOUT_SECONDARY, + f"模拟 OTP 超时 #{attempt}", + ) + if current_state is not None: + backoff_states.append(current_state.to_dict()) + log_callback( + f"[模拟] OTP 超时退避 #{attempt}: " + f"failures={current_state.failures}, delay={current_state.delay_seconds}" + ) + if delay_seconds: + await asyncio.sleep(delay_seconds) + + batch_probe = _simulate_batch_counter_probe(batch_id) + log_callback("[模拟] 批量计数探针已完成") + if delay_seconds: + await asyncio.sleep(delay_seconds) + + result = { + "email": checks["seeded_account_email"], + "email_service": email_service_type, + "hardening_checks": { + "token_sync": { + "seeded_account_id": checks["seeded_account_id"], + "tokenless_account_id": checks["tokenless_account_id"], + "partial_account_id": checks["partial_account_id"], + }, + "outlook_refresh": { + "service_id": checks["outlook_service_id"], + "email": checks["outlook_account_email"], + }, + "batch_counter": batch_probe, + "otp_timeout_backoff": { + "service_id": checks["backoff_service_id"], + "states": backoff_states, + }, + }, + } + + with get_db() as db: + crud.update_registration_task( + db, + task_uuid, + status="completed", + completed_at=datetime.utcnow(), + result=result, + ) + task_manager.update_status( + task_uuid, + "completed", + email=checks["seeded_account_email"], + email_service=email_service_type, + ) + log_callback("[模拟] 任务完成,所有探针已收口") + except Exception as exc: + logger.exception("模拟任务执行失败: %s", task_uuid) + with get_db() as db: + crud.update_registration_task( + db, + task_uuid, + status="failed", + completed_at=datetime.utcnow(), + error_message=str(exc), + ) + task_manager.update_status(task_uuid, "failed", error=str(exc), email_service=email_service_type) + log_callback(f"[模拟] 任务失败: {exc}") + + async def run_batch_parallel( batch_id: str, task_uuids: List[str], @@ -1055,6 +1355,55 @@ async def run_batch_registration( # ============== API Endpoints ============== +@router.post("/create", response_model=MockRegistrationTaskCreateResponse) +async def create_mock_registration( + request: MockRegistrationCreateRequest, + background_tasks: BackgroundTasks, +): + """创建用于端到端验证的受控模拟任务。""" + try: + EmailServiceType(request.email_service_type) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"无效的邮箱服务类型: {request.email_service_type}" + ) + + task_uuid = str(uuid.uuid4()) + suffix = task_uuid[:8] + batch_id = str(uuid.uuid4()) + checks: Dict[str, Any] = { + "seeded_account_email": f"mock-seeded-{suffix}@example.test", + "tokenless_account_email": f"mock-tokenless-{suffix}@example.test", + "partial_account_email": f"mock-partial-{suffix}@example.test", + "outlook_account_email": f"mock-outlook-{suffix}@example.test", + "backoff_service_name": f"mock-backoff-{suffix}", + } + + with get_db() as db: + task = crud.create_registration_task( + db, + task_uuid=task_uuid, + proxy=None, + ) + + background_tasks.add_task( + run_mock_registration_task, + task_uuid, + batch_id, + checks, + request.email_service_type, + request.start_delay_ms, + request.log_delay_ms, + ) + + return MockRegistrationTaskCreateResponse( + task=task_to_response(task), + batch_id=batch_id, + checks=checks, + ) + + @router.post("/start", response_model=RegistrationTaskResponse) async def start_registration( request: RegistrationTaskCreate, diff --git a/tests/test_registration_otp_phase.py b/tests/test_registration_otp_phase.py index 2e747bc..729dd86 100644 --- a/tests/test_registration_otp_phase.py +++ b/tests/test_registration_otp_phase.py @@ -26,6 +26,36 @@ class FakeEmailService: return self.code +class FakeCookies: + def __init__(self, values): + self.values = values + + def get(self, name): + return self.values.get(name) + + +class FakeSession: + def __init__(self, cookies=None): + self.cookies = FakeCookies(cookies or {}) + self.get_calls = [] + + def get(self, *args, **kwargs): + self.get_calls.append((args, kwargs)) + raise AssertionError("unexpected network call") + + +class FakeResponse: + def __init__(self, *, url="", text="", json_payload=None): + self.url = url + self.text = text + self._json_payload = json_payload + + def json(self): + if isinstance(self._json_payload, Exception): + raise self._json_payload + return self._json_payload + + def _build_engine(monkeypatch, email_service): monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings()) return RegistrationEngine(email_service=email_service) @@ -103,3 +133,42 @@ def test_advance_login_authorization_sets_otp_anchor_before_password_submit(monk assert callback_url is None assert engine._otp_sent_at == 456.0 assert seen_anchors == [456.0, 456.0] + + +def test_get_device_id_reuses_existing_cookie_without_extra_request(monkeypatch): + email_service = FakeEmailService(code=None) + engine = _build_engine(monkeypatch, email_service) + engine.oauth_start = type("OAuthStart", (), {"auth_url": "https://auth.example.test/authorize"})() + engine.session = FakeSession(cookies={"oai-did": "did-cached"}) + + assert engine._get_device_id() == "did-cached" + assert engine.session.get_calls == [] + + +def test_extract_workspace_id_from_response_payload(monkeypatch): + email_service = FakeEmailService(code=None) + engine = _build_engine(monkeypatch, email_service) + response = FakeResponse( + url="https://auth.example.test/consent?workspace_id=ws-url", + json_payload={ + "page": { + "workspace": { + "id": "ws-json", + } + } + }, + ) + + assert engine._extract_workspace_id_from_response(response=response) == "ws-json" + + +def test_extract_workspace_id_from_response_text_when_hidden_input_missing(monkeypatch): + email_service = FakeEmailService(code=None) + engine = _build_engine(monkeypatch, email_service) + response = FakeResponse( + url="https://auth.example.test/consent", + text='', + json_payload=ValueError("not json"), + ) + + assert engine._extract_workspace_id_from_response(response=response) == "ws-script" diff --git a/tests/test_task_manager_status_broadcast.py b/tests/test_task_manager_status_broadcast.py index 270dc7d..4d087bf 100644 --- a/tests/test_task_manager_status_broadcast.py +++ b/tests/test_task_manager_status_broadcast.py @@ -1,5 +1,6 @@ import asyncio +from src.web.routes.registration import _create_task_status_callback from src.web.task_manager import task_manager @@ -38,3 +39,34 @@ def test_update_status_broadcasts_to_registered_websocket(): task_manager.unregister_websocket(task_uuid, websocket) asyncio.run(run_test()) + + +def test_task_status_callback_broadcasts_phase_fields(): + async def run_test(): + task_uuid = "test-status-phase" + websocket = FakeWebSocket() + + task_manager.set_loop(asyncio.get_running_loop()) + task_manager.register_websocket(task_uuid, websocket) + + try: + callback = _create_task_status_callback(task_uuid, "tempmail") + callback({ + "phase": "redirect_chain", + "phase_detail": "跟随重定向 1/6", + "step_index": 14, + }) + + await asyncio.sleep(0.05) + + assert websocket.messages, "expected a status message to be broadcast" + assert websocket.messages[-1]["type"] == "status" + assert websocket.messages[-1]["status"] == "running" + assert websocket.messages[-1]["email_service"] == "tempmail" + assert websocket.messages[-1]["phase"] == "redirect_chain" + assert websocket.messages[-1]["phase_detail"] == "跟随重定向 1/6" + assert websocket.messages[-1]["step_index"] == 14 + finally: + task_manager.unregister_websocket(task_uuid, websocket) + + asyncio.run(run_test())