feat(register): enhance workspace extraction and phase status reporting

This commit is contained in:
Mison
2026-03-24 10:26:59 +08:00
parent 67a446aca0
commit 5b76619d6f
4 changed files with 708 additions and 9 deletions

View File

@@ -139,6 +139,7 @@ class RegistrationEngine:
email_service: BaseEmailService,
proxy_url: Optional[str] = None,
callback_logger: Optional[Callable[[str], None]] = None,
status_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
task_uuid: Optional[str] = None,
):
"""
@@ -148,11 +149,13 @@ class RegistrationEngine:
email_service: 邮箱服务实例
proxy_url: 代理 URL
callback_logger: 日志回调函数
status_callback: 状态回调函数
task_uuid: 任务 UUID用于数据库记录
"""
self.email_service = email_service
self.proxy_url = proxy_url
self.callback_logger = callback_logger or (lambda msg: logger.info(msg))
self.status_callback = status_callback
self.task_uuid = task_uuid
# 创建 HTTP 客户端
@@ -175,6 +178,7 @@ class RegistrationEngine:
self.email_info: Optional[Dict[str, Any]] = None
self.oauth_start: Optional[OAuthStart] = None
self.session: Optional[cffi_requests.Session] = None
self.device_id: Optional[str] = None
self.session_token: Optional[str] = None # 会话令牌
self.logs: list = []
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
@@ -213,6 +217,54 @@ class RegistrationEngine:
"""生成随机密码"""
return ''.join(secrets.choice(PASSWORD_CHARSET) for _ in range(length))
def _emit_status(self, phase: str, detail: str, **extra):
"""向外部上报阶段进度。"""
if not self.status_callback:
return
payload = {
"phase": phase,
"phase_detail": detail,
}
if self.email:
payload["email"] = self.email
payload.update({key: value for key, value in extra.items() if value is not None})
try:
self.status_callback(payload)
except Exception as e:
logger.warning(f"上报任务阶段状态失败: {e}")
def _current_device_id(self) -> Optional[str]:
"""优先复用现有 Device ID避免重复触发慢请求。"""
if self.device_id:
return self.device_id
if not self.session:
return None
did = self.session.cookies.get("oai-did")
if did:
self.device_id = did
return did
def _log_timed_http_result(
self,
action: str,
started_at: float,
response: Optional[Any] = None,
):
"""记录 HTTP 调用的耗时与结果。"""
elapsed = max(0.0, time.time() - started_at)
parts = [f"{action} 完成,耗时 {elapsed:.1f}"]
if response is not None:
status_code = getattr(response, "status_code", None)
response_url = str(getattr(response, "url", "") or "").strip()
if status_code is not None:
parts.append(f"HTTP {status_code}")
if response_url:
parts.append(f"URL: {response_url[:120]}...")
self._log("".join(parts))
def _record_phase_result(self, phase_result: PhaseResult) -> PhaseResult:
self.phase_history = [
item for item in self.phase_history
@@ -311,19 +363,33 @@ class RegistrationEngine:
if not self.oauth_start:
return None
cached_did = self._current_device_id()
if cached_did:
self._log(f"复用已有 Device ID: {cached_did}")
return cached_did
max_attempts = 3
for attempt in range(1, max_attempts + 1):
try:
if not self.session:
self.session = self.http_client.session
self._emit_status(
"oauth_device_id",
f"获取 Device ID{attempt}/{max_attempts} 次)",
attempt=attempt,
max_attempts=max_attempts,
)
started_at = time.time()
response = self.session.get(
self.oauth_start.auth_url,
timeout=20
)
self._log_timed_http_result("获取 Device ID 请求", started_at, response)
did = self.session.cookies.get("oai-did")
if did:
self.device_id = did
self._log(f"Device ID: {did}")
return did
@@ -347,8 +413,15 @@ class RegistrationEngine:
def _check_sentinel(self, did: str) -> Optional[str]:
"""检查 Sentinel 拦截"""
try:
sen_req_body = f'{{"p":"","id":"{did}","flow":"authorize_continue"}}'
device_id = did or self._current_device_id()
if not device_id:
self._log("Sentinel 检查跳过: 缺少 Device ID", "warning")
return None
self._emit_status("sentinel", "请求 Sentinel 校验令牌")
sen_req_body = f'{{"p":"","id":"{device_id}","flow":"authorize_continue"}}'
started_at = time.time()
response = self.http_client.post(
OPENAI_API_ENDPOINTS["sentinel"],
headers={
@@ -358,6 +431,7 @@ class RegistrationEngine:
},
data=sen_req_body,
)
self._log_timed_http_result("Sentinel 校验", started_at, response)
if response.status_code == 200:
sen_token = response.json().get("token")
@@ -719,6 +793,53 @@ class RegistrationEngine:
return workspace_id
return None
def _extract_workspace_id_from_text(self, text: str) -> Optional[str]:
"""从 HTML/脚本文本中提取 Workspace ID。"""
if not text:
return None
patterns = [
r'"workspace_id"\s*:\s*"([^"]+)"',
r'"workspaceId"\s*:\s*"([^"]+)"',
r'"default_workspace_id"\s*:\s*"([^"]+)"',
r'"defaultWorkspaceId"\s*:\s*"([^"]+)"',
r'"active_workspace_id"\s*:\s*"([^"]+)"',
r'"activeWorkspaceId"\s*:\s*"([^"]+)"',
r'"workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"',
r'"default_workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"',
r'"active_workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"',
]
for pattern in patterns:
match = re.search(pattern, text)
if match:
workspace_id = str(match.group(1) or "").strip()
if workspace_id:
return workspace_id
return None
def _extract_workspace_id_from_url(self, url: str) -> Optional[str]:
"""从 URL 查询参数或片段中提取 Workspace ID。"""
if not url:
return None
import urllib.parse
parsed = urllib.parse.urlparse(url)
for raw_query in (parsed.query, parsed.fragment):
query = urllib.parse.parse_qs(raw_query)
for key in (
"workspace_id",
"workspaceId",
"default_workspace_id",
"active_workspace_id",
):
values = query.get(key) or []
if values:
workspace_id = str(values[0] or "").strip()
if workspace_id:
return workspace_id
return None
def _decode_cookie_json_candidates(self, cookie_value: str) -> list[Dict[str, Any]]:
"""尝试从完整 Cookie 或其分段中解码出 JSON。"""
decoded_objects = []
@@ -760,12 +881,25 @@ class RegistrationEngine:
if workspace_id:
return workspace_id
for key in ("workspace_id", "default_workspace_id", "active_workspace_id"):
for key in (
"workspace_id",
"workspaceId",
"default_workspace_id",
"defaultWorkspaceId",
"active_workspace_id",
"activeWorkspaceId",
):
workspace_id = str(auth_json.get(key) or "").strip()
if workspace_id:
return workspace_id
for key in ("workspace", "default_workspace", "active_workspace"):
for key in (
"workspace",
"default_workspace",
"active_workspace",
"defaultWorkspace",
"activeWorkspace",
):
workspace = auth_json.get(key)
if not isinstance(workspace, dict):
continue
@@ -776,6 +910,60 @@ class RegistrationEngine:
return None
def _extract_workspace_id_from_response(
self,
response: Optional[Any] = None,
html: Optional[str] = None,
url: Optional[str] = None,
) -> Optional[str]:
"""统一从响应 JSON、HTML、脚本内容和 URL 中提取 Workspace ID。"""
response_url = str(getattr(response, "url", "") or "").strip()
response_text = html if html is not None else str(getattr(response, "text", "") or "")
candidate_url = url or response_url
if response is not None:
try:
payload = response.json()
except Exception:
payload = None
workspace_id = self._extract_workspace_id_from_response_payload(payload)
if workspace_id:
return workspace_id
for extractor in (
lambda: self._extract_workspace_id_from_html(response_text),
lambda: self._extract_workspace_id_from_text(response_text),
lambda: self._extract_workspace_id_from_url(candidate_url),
):
workspace_id = extractor()
if workspace_id:
return workspace_id
return None
def _extract_workspace_id_from_response_payload(self, payload: Any, depth: int = 0) -> Optional[str]:
"""递归扫描响应载荷中的 Workspace ID。"""
if payload is None or depth > 5:
return None
if isinstance(payload, dict):
workspace_id = self._extract_workspace_id_from_auth_json(payload)
if workspace_id:
return workspace_id
for value in payload.values():
workspace_id = self._extract_workspace_id_from_response_payload(value, depth + 1)
if workspace_id:
return workspace_id
return None
if isinstance(payload, list):
for item in payload:
workspace_id = self._extract_workspace_id_from_response_payload(item, depth + 1)
if workspace_id:
return workspace_id
return None
def _select_workspace(self, workspace_id: str) -> Optional[str]:
"""选择 Workspace"""
try:
@@ -858,12 +1046,16 @@ class RegistrationEngine:
return False
try:
did = self.session.cookies.get("oai-did") if self.session else None
self._emit_status("login_reentry", "重新进入登录流程")
did = self._current_device_id()
sen_token = self._check_sentinel(did) if did else None
self._log("登录重入:请求 authorize 页面以确认当前表单状态")
started_at = time.time()
response = self.session.get(
self.oauth_start.auth_url,
timeout=15,
)
self._log_timed_http_result("登录重入 authorize 页面", started_at, response)
html = response.text or ""
if "/log-in/password" in str(getattr(response, "url", "") or "") or 'action="/log-in/password"' in html:
@@ -877,6 +1069,9 @@ class RegistrationEngine:
"value": self.email,
}
}
self._emit_status("login_reentry", "提交邮箱以推进到密码页")
self._log("登录重入:提交邮箱到 authorize/continue")
started_at = time.time()
login_response = self.session.post(
"https://auth.openai.com/api/accounts/authorize/continue",
headers={
@@ -902,12 +1097,20 @@ class RegistrationEngine:
data=json.dumps(login_data),
timeout=15,
)
self._log_timed_http_result("登录重入邮箱提交", started_at, login_response)
login_json = login_response.json() if login_response.status_code == 200 else {}
page_type = str((login_json or {}).get("page", {}).get("type") or "").strip()
continue_url = str((login_json or {}).get("continue_url") or "").strip()
self._log(
f"登录重入响应: page_type={page_type or 'unknown'}, "
f"continue_url={continue_url[:100] + '...' if continue_url else 'none'}"
)
if continue_url:
try:
self._emit_status("login_reentry", "跟进登录 continue_url")
started_at = time.time()
self.session.get(continue_url, timeout=15)
self._log_timed_http_result("登录重入 continue_url", started_at)
except Exception:
pass
if login_response.status_code == 200 and page_type in {"password", "login_password"}:
@@ -926,8 +1129,10 @@ class RegistrationEngine:
return False
try:
did = self.session.cookies.get("oai-did") if self.session else None
self._emit_status("login_password", "提交登录密码")
did = self._current_device_id()
sen_token = self._check_sentinel(did) if did else None
started_at = time.time()
response = self.session.post(
"https://auth.openai.com/api/accounts/password/verify",
headers={
@@ -955,6 +1160,7 @@ class RegistrationEngine:
}),
timeout=15,
)
self._log_timed_http_result("登录密码提交", started_at, response)
self._log(f"登录密码提交状态: {response.status_code}")
if response.status_code == 200:
try:
@@ -964,7 +1170,10 @@ class RegistrationEngine:
continue_url = str(payload.get("continue_url") or "").strip()
if continue_url:
try:
self._emit_status("login_password", "跟进密码校验 continue_url")
started_at = time.time()
self.session.get(continue_url, timeout=15)
self._log_timed_http_result("密码校验 continue_url", started_at)
except Exception:
pass
return response.status_code in (200, 302, 303)
@@ -977,7 +1186,7 @@ class RegistrationEngine:
return False, None
try:
did = self.session.cookies.get("oai-did") if self.session else None
did = self._current_device_id()
sen_token = self._check_sentinel(did) if did else None
response = self.session.post(
"https://auth.openai.com/api/accounts/password/verify",
@@ -1061,6 +1270,7 @@ class RegistrationEngine:
self._log("重新初始化登录会话失败", "warning")
return None, None
self._emit_status("oauth_reentry", "重新初始化 OAuth 登录会话")
if not self._start_oauth():
self._log("重新开始 OAuth 登录流程失败", "warning")
return None, None
@@ -1074,6 +1284,7 @@ class RegistrationEngine:
return None, None
self._otp_sent_at = time.time()
self._emit_status("otp_secondary", "等待登录验证码邮件")
if not self._submit_login_password_step():
return None, None
@@ -1089,12 +1300,16 @@ class RegistrationEngine:
return None, None
auth_target = consent_url or self.oauth_start.auth_url
self._emit_status("workspace_extract", "请求 consent 页面并提取 Workspace ID")
self._log(f"请求 consent 页面: {auth_target[:120]}...")
started_at = time.time()
auth_response = self.session.get(auth_target, timeout=20)
self._log_timed_http_result("获取 consent 页面", started_at, auth_response)
current_url = str(getattr(auth_response, "url", "") or "")
html = auth_response.text or ""
if "sign-in-with-chatgpt/codex/consent" in current_url or 'action="/sign-in-with-chatgpt/codex/consent"' in html:
workspace_id = self._extract_workspace_id_from_html(html)
workspace_id = self._extract_workspace_id_from_response(response=auth_response, html=html, url=current_url)
if not workspace_id:
self._log("consent 页面缺少 workspace_id回退到 Cookie 解析路径", "warning")
return None, None
@@ -1115,13 +1330,22 @@ class RegistrationEngine:
max_redirects = 6
for i in range(max_redirects):
self._emit_status(
"redirect_chain",
f"跟随重定向 {i + 1}/{max_redirects}",
redirect_index=i + 1,
redirect_total=max_redirects,
redirect_url=current_url[:200],
)
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
started_at = time.time()
response = self.session.get(
current_url,
allow_redirects=False,
timeout=15
)
self._log_timed_http_result(f"重定向跳转 {i + 1}/{max_redirects}", started_at, response)
location = response.headers.get("Location") or ""
@@ -1137,6 +1361,7 @@ class RegistrationEngine:
# 构建下一个 URL
import urllib.parse
next_url = urllib.parse.urljoin(current_url, location)
self._log(f"重定向下一跳: {next_url[:100]}...")
# 检查是否包含回调参数
if "code=" in next_url and "state=" in next_url:
@@ -1159,12 +1384,19 @@ class RegistrationEngine:
self._log("OAuth 流程未初始化", "error")
return None
self._emit_status("oauth_callback", "处理 OAuth 回调并交换令牌")
self._log("处理 OAuth 回调...")
started_at = time.time()
token_info = self.oauth_manager.handle_callback(
callback_url=callback_url,
expected_state=self.oauth_start.state,
code_verifier=self.oauth_start.code_verifier
)
elapsed = max(0.0, time.time() - started_at)
self._log(
f"OAuth 回调处理完成,耗时 {elapsed:.1f} 秒,"
f"account_id={str(token_info.get('account_id') or '').strip() or 'unknown'}"
)
self._log("OAuth 授权成功")
return token_info
@@ -1197,6 +1429,7 @@ class RegistrationEngine:
# 1. 检查 IP 地理位置
self._log("1. 检查 IP 地理位置...")
self._emit_status("ip_check", "检查 IP 地理位置", step_index=1)
ip_ok, location = self._check_ip_location()
if not ip_ok:
result.error_message = f"IP 地理位置不支持: {location}"
@@ -1207,6 +1440,7 @@ class RegistrationEngine:
# 2. 创建邮箱
self._log("2. 创建邮箱...")
self._emit_status("email_prepare", "创建邮箱地址", step_index=2)
if not self._phase_email_prepare():
email_prepare_phase = self._get_phase_result(PHASE_EMAIL_PREPARE)
result.error_message = (
@@ -1221,18 +1455,21 @@ class RegistrationEngine:
# 3. 初始化会话
self._log("3. 初始化会话...")
self._emit_status("session_init", "初始化 HTTP 会话", step_index=3)
if not self._init_session():
result.error_message = "初始化会话失败"
return result
# 4. 开始 OAuth 流程
self._log("4. 开始 OAuth 授权流程...")
self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4)
if not self._start_oauth():
result.error_message = "开始 OAuth 流程失败"
return result
# 5. 获取 Device ID
self._log("5. 获取 Device ID...")
self._emit_status("oauth_device_id", "获取 Device ID", step_index=5)
did = self._get_device_id()
if not did:
result.error_message = "获取 Device ID 失败"
@@ -1240,6 +1477,7 @@ class RegistrationEngine:
# 6. 检查 Sentinel 拦截
self._log("6. 检查 Sentinel 拦截...")
self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6)
sen_token = self._check_sentinel(did)
if sen_token:
self._log("Sentinel 检查通过")
@@ -1248,6 +1486,7 @@ class RegistrationEngine:
# 7. 提交注册表单 + 解析响应判断账号状态
self._log("7. 提交注册表单...")
self._emit_status("signup_submit", "提交注册表单", step_index=7)
signup_result = self._submit_signup_form(did, sen_token)
if not signup_result.success:
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
@@ -1258,6 +1497,7 @@ class RegistrationEngine:
self._log("8. [已注册账号] 跳过密码设置OTP 已自动发送")
else:
self._log("8. 注册密码...")
self._emit_status("signup_password", "提交注册密码", step_index=8)
password_ok, password = self._register_password()
if not password_ok:
result.error_message = "注册密码失败"
@@ -1270,12 +1510,14 @@ class RegistrationEngine:
self._otp_sent_at = time.time()
else:
self._log("9. 发送验证码...")
self._emit_status("otp_send", "发送验证码", step_index=9)
if not self._send_verification_code():
result.error_message = "发送验证码失败"
return result
# 10. 获取验证码
self._log("10. 等待验证码...")
self._emit_status("otp_secondary", "等待验证码邮件", step_index=10)
otp_phase_started_at = time.time()
code, otp_phase = self._phase_otp_secondary(
PhaseContext(otp_sent_at=self._otp_sent_at),
@@ -1290,6 +1532,7 @@ class RegistrationEngine:
# 11. 验证验证码
self._log("11. 验证验证码...")
self._emit_status("otp_validate", "校验验证码", step_index=11)
if not self._validate_verification_code(code):
result.error_message = "验证验证码失败"
return result
@@ -1299,6 +1542,7 @@ class RegistrationEngine:
self._log("12. [已注册账号] 跳过创建用户账户")
else:
self._log("12. 创建用户账户...")
self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12)
if not self._create_user_account():
result.error_message = "创建用户账户失败"
return result
@@ -1308,6 +1552,7 @@ class RegistrationEngine:
if not self._is_existing_account:
self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...")
self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step)
workspace_id, callback_url = self._advance_login_authorization()
if workspace_id and callback_url:
result.workspace_id = workspace_id
@@ -1316,6 +1561,7 @@ class RegistrationEngine:
if not result.workspace_id:
# 获取 Workspace ID
self._log(f"{next_step}. 获取 Workspace ID...")
self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step)
workspace_id = self._get_workspace_id()
if not workspace_id:
result.error_message = "获取 Workspace ID 失败"
@@ -1327,6 +1573,7 @@ class RegistrationEngine:
# 选择 Workspace
self._log(f"{next_step}. 选择 Workspace...")
self._emit_status("workspace_select", "选择 Workspace", step_index=next_step)
continue_url = self._select_workspace(result.workspace_id)
if not continue_url:
result.error_message = "选择 Workspace 失败"
@@ -1336,6 +1583,7 @@ class RegistrationEngine:
# 跟随重定向链
self._log(f"{next_step}. 跟随重定向链...")
self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step)
callback_url = self._follow_redirects(continue_url)
if not callback_url:
result.error_message = "跟随重定向链失败"
@@ -1345,6 +1593,7 @@ class RegistrationEngine:
# 处理 OAuth 回调
self._log(f"{next_step}. 处理 OAuth 回调...")
self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step)
token_info = self._handle_oauth_callback(callback_url)
if not token_info:
result.error_message = "处理 OAuth 回调失败"

View File

@@ -10,7 +10,7 @@ import random
import re
import time
from datetime import datetime
from typing import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel, Field
@@ -24,7 +24,7 @@ from ...core.register import (
RegistrationResult,
)
from ...services import EmailServiceFactory, EmailServiceType
from ...services.base import EmailProviderBackoffState, OTPTimeoutEmailServiceError
from ...services.base import BaseEmailService, EmailProviderBackoffState, OTPTimeoutEmailServiceError
from ...config.settings import get_settings
from ..task_manager import task_manager
@@ -136,6 +136,13 @@ class BatchRegistrationRequest(BaseModel):
tm_service_ids: List[int] = []
class MockRegistrationCreateRequest(BaseModel):
    """Request body for creating a controlled mock registration task."""
    # Email service type name; the create endpoint rejects values that are
    # not valid EmailServiceType members.
    email_service_type: str = "tempmail"
    # Milliseconds the mock task waits before it starts (bounded to 0..5000).
    start_delay_ms: int = Field(default=300, ge=0, le=5000)
    # Milliseconds the mock task pauses between probe steps (bounded to 0..5000).
    log_delay_ms: int = Field(default=250, ge=0, le=5000)
class RegistrationTaskResponse(BaseModel):
"""注册任务响应"""
id: int
@@ -161,6 +168,13 @@ class BatchRegistrationResponse(BaseModel):
tasks: List[RegistrationTaskResponse]
class MockRegistrationTaskCreateResponse(BaseModel):
    """Response returned after a controlled mock task has been created."""
    # API representation of the newly persisted registration task.
    task: RegistrationTaskResponse
    # Batch identifier generated for the mock run's batch-counter probe.
    batch_id: str
    # Generated fixture values (emails, service name) the mock run will use.
    checks: Dict[str, Any]
class TaskListResponse(BaseModel):
"""任务列表响应"""
total: int
@@ -232,6 +246,19 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse:
)
def _create_task_status_callback(task_uuid: str, email_service: str):
"""把引擎内部阶段进度映射到 TaskManager 状态广播。"""
def callback(payload: Dict[str, Any]) -> None:
status_payload = {
"email_service": email_service,
**payload,
}
task_manager.update_status(task_uuid, "running", **status_payload)
return callback
def _normalize_email_service_config(
service_type: EmailServiceType,
config: Optional[dict],
@@ -329,6 +356,7 @@ def _run_registration_engine_attempt(
actual_proxy_url: Optional[str],
log_callback,
db_service,
status_callback=None,
):
"""执行单次注册引擎尝试,并在同一临界区内维护邮箱服务退避状态。"""
provider_backoff_before_run = EmailProviderBackoffState()
@@ -343,6 +371,7 @@ def _run_registration_engine_attempt(
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
status_callback=status_callback,
task_uuid=task_uuid,
)
@@ -570,6 +599,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
crud.update_registration_task(db, task_uuid, email_service_id=None)
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value)
status_callback = _create_task_status_callback(task_uuid, active_service_type.value)
email_service = EmailServiceFactory.create(
selected_service_type,
candidate_config,
@@ -587,6 +617,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
actual_proxy_url=actual_proxy_url,
log_callback=log_callback,
db_service=db_service,
status_callback=status_callback,
)
if result.success:
@@ -856,6 +887,275 @@ def _make_batch_helpers(batch_id: str):
return add_batch_log, update_batch_status
class _MockBackoffEmailService(BaseEmailService):
    """Minimal email service stub used by the mock task's backoff probe.

    Every operation returns a fixed value and performs no network access.
    """

    def __init__(self):
        # Registers itself as a DUCK_MAIL-typed service under a fixed name.
        super().__init__(service_type=EmailServiceType.DUCK_MAIL, name="mock-backoff-service")

    def create_email(self, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Return a fixed mailbox description; ``config`` is ignored."""
        return {"email": "mock@example.test", "service_id": "mock-service-id"}

    def get_verification_code(
        self,
        email: str,
        email_id: Optional[str] = None,
        timeout: int = 120,
        pattern: str = r"(?<!\d)(\d{6})(?!\d)",
        otp_sent_at: Optional[float] = None,
    ) -> Optional[str]:
        """Always return ``None``, i.e. no verification code is ever found."""
        return None

    def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
        """Return an empty mailbox listing."""
        return []

    def delete_email(self, email_id: str) -> bool:
        """Report success without deleting anything."""
        return True

    def check_health(self) -> bool:
        """Always report the service as healthy."""
        return True
def _create_persisted_log_callback(task_uuid: str, prefix: str = "", batch_id: str = ""):
"""同时写入内存日志队列、批量日志通道和数据库任务日志。"""
def callback(message: str) -> None:
full_message = f"{prefix} {message}" if prefix else message
task_manager.add_log(task_uuid, full_message)
if batch_id:
task_manager.add_batch_log(batch_id, full_message)
with get_db() as db:
crud.append_task_log(db, task_uuid, full_message)
return callback
def _simulate_batch_counter_probe(batch_id: str) -> Dict[str, Any]:
    """Run a fixed three-task batch scenario to exercise batch counters.

    Three task rows are created with the outcomes completed / failed /
    completed. After each one the batch counters are read back, incremented
    and written again, and a matching batch log line is emitted. Finally the
    batch is marked finished.

    Args:
        batch_id: Identifier of the batch whose state is initialised and updated.

    Returns:
        A dict with the batch id, the generated task UUIDs and the batch
        status as reported by the task manager (``{}`` when unavailable).
    """
    task_uuids = [str(uuid.uuid4()) for _ in range(3)]
    # Fixed outcome sequence: two successes and one failure.
    task_statuses = ["completed", "failed", "completed"]
    _init_batch_state(batch_id, task_uuids)
    add_batch_log, update_batch_status = _make_batch_helpers(batch_id)
    add_batch_log(f"[系统] 模拟批量任务启动,总任务: {len(task_uuids)}")
    with get_db() as db:
        for index, (task_uuid, status) in enumerate(zip(task_uuids, task_statuses), start=1):
            crud.create_registration_task(db, task_uuid=task_uuid, proxy=None)
            # Only failed tasks carry an error message.
            error_message = None if status == "completed" else f"mock-batch-error-{index}"
            crud.update_registration_task(
                db,
                task_uuid,
                status=status,
                started_at=datetime.utcnow(),
                completed_at=datetime.utcnow(),
                error_message=error_message,
            )
            # Read-modify-write: start from the latest snapshot so the
            # counters accumulate across iterations.
            batch_snapshot = _get_batch_snapshot(batch_id) or {}
            new_completed = batch_snapshot.get("completed", 0) + 1
            new_success = batch_snapshot.get("success", 0)
            new_failed = batch_snapshot.get("failed", 0)
            if status == "completed":
                new_success += 1
                add_batch_log(f"[任务{index}] [成功] 模拟注册成功")
            else:
                new_failed += 1
                add_batch_log(f"[任务{index}] [失败] 模拟注册失败: {error_message}")
            update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
    # Re-read the snapshot so the summary line reflects the final counters.
    batch_snapshot = _get_batch_snapshot(batch_id) or {}
    add_batch_log(
        f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, "
        f"失败: {batch_snapshot.get('failed', 0)}"
    )
    update_batch_status(finished=True, status="completed")
    return {
        "batch_id": batch_id,
        "task_uuids": task_uuids,
        "snapshot": task_manager.get_batch_status(batch_id) or {},
    }
async def run_mock_registration_task(
    task_uuid: str,
    batch_id: str,
    checks: Dict[str, Any],
    email_service_type: str,
    start_delay_ms: int,
    log_delay_ms: int,
) -> None:
    """Execute a repeatable mock task against the real service layer.

    Steps, in order:
      1. Mark the task as running (aborts if the task row does not exist).
      2. Seed accounts and email services in the database and update tokens.
      3. Record three simulated OTP-timeout backoff events.
      4. Run the batch counter probe.
      5. Store the collected result and mark the task completed.

    Any exception marks the task as failed with the exception text.

    Args:
        task_uuid: UUID of the registration task row driving this run.
        batch_id: Batch identifier handed to the batch counter probe.
        checks: Fixture values (emails, service name) to use. This dict is
            mutated in place: the ids of the created rows are added to it.
        email_service_type: Service type name reported in status updates.
        start_delay_ms: Milliseconds to wait before starting.
        log_delay_ms: Milliseconds to pause between probe steps.
    """
    if start_delay_ms > 0:
        await asyncio.sleep(start_delay_ms / 1000)
    # Make sure the task manager knows an event loop before status/log calls.
    # NOTE(review): inside a coroutine asyncio.get_running_loop() is the
    # documented way to obtain the loop — consider switching.
    loop = task_manager.get_loop()
    if loop is None:
        loop = asyncio.get_event_loop()
        task_manager.set_loop(loop)
    log_callback = _create_persisted_log_callback(task_uuid)
    # Pause between steps, in seconds; negative inputs are clamped to zero.
    delay_seconds = max(log_delay_ms, 0) / 1000
    try:
        with get_db() as db:
            task = crud.update_registration_task(
                db,
                task_uuid,
                status="running",
                started_at=datetime.utcnow(),
            )
            if not task:
                # Nothing to run against; leave without touching any state.
                logger.error(f"模拟任务不存在: {task_uuid}")
                return
        task_manager.update_status(task_uuid, "running", email_service=email_service_type)
        log_callback("[模拟] 任务已启动,开始执行真实链路探针")
        if delay_seconds:
            await asyncio.sleep(delay_seconds)
        with get_db() as db:
            # Account created with both tokens already present.
            seeded_account = crud.create_account(
                db,
                email=checks["seeded_account_email"],
                email_service="tempmail",
                access_token="mock-access-token-seeded",
                refresh_token="mock-refresh-token-seeded",
            )
            # Account created without tokens, then given an access token.
            tokenless_account = crud.create_account(
                db,
                email=checks["tokenless_account_email"],
                email_service="tempmail",
            )
            crud.update_account(
                db,
                tokenless_account.id,
                access_token="mock-access-token-updated",
            )
            # Account created with both tokens, then its refresh token blanked.
            partial_account = crud.create_account(
                db,
                email=checks["partial_account_email"],
                email_service="tempmail",
                access_token="mock-access-token-partial",
                refresh_token="mock-refresh-token-partial",
            )
            crud.update_account(
                db,
                partial_account.id,
                refresh_token="",
            )
            # Outlook service with two accounts; only the second one's
            # refresh token is replaced below.
            outlook_service = crud.create_email_service(
                db,
                service_type="outlook",
                name=f"mock-outlook-{task_uuid[:8]}",
                config={
                    "accounts": [
                        {"email": "first@example.test", "refresh_token": "old-first"},
                        {
                            "email": checks["outlook_account_email"],
                            "refresh_token": "old-second",
                        },
                    ]
                },
            )
            crud.update_outlook_refresh_token(
                db,
                service_id=outlook_service.id,
                email=checks["outlook_account_email"],
                new_refresh_token="new-second",
            )
            # Service row whose id keys the backoff state recorded later.
            backoff_service = crud.create_email_service(
                db,
                service_type="duck_mail",
                name=checks["backoff_service_name"],
                config={
                    "base_url": "https://mail.example.test",
                    "default_domain": "example.test",
                },
            )
            # Record the created ids in the caller-supplied dict.
            checks["seeded_account_id"] = seeded_account.id
            checks["tokenless_account_id"] = tokenless_account.id
            checks["partial_account_id"] = partial_account.id
            checks["outlook_service_id"] = outlook_service.id
            checks["backoff_service_id"] = backoff_service.id
        log_callback("[模拟] Token 同步与 Outlook refresh_token 探针已写入数据库")
        if delay_seconds:
            await asyncio.sleep(delay_seconds)
        mock_email_service = _MockBackoffEmailService()
        backoff_states = []
        # Three consecutive simulated OTP timeouts, each fed the previous state.
        for attempt in range(1, 4):
            previous_state = _get_email_service_backoff_state(backoff_service.id)
            current_state = _record_email_service_timeout_backoff(
                backoff_service.id,
                mock_email_service,
                previous_state,
                ERROR_OTP_TIMEOUT_SECONDARY,
                f"模拟 OTP 超时 #{attempt}",
            )
            if current_state is not None:
                backoff_states.append(current_state.to_dict())
                log_callback(
                    f"[模拟] OTP 超时退避 #{attempt}: "
                    f"failures={current_state.failures}, delay={current_state.delay_seconds}"
                )
            if delay_seconds:
                await asyncio.sleep(delay_seconds)
        batch_probe = _simulate_batch_counter_probe(batch_id)
        log_callback("[模拟] 批量计数探针已完成")
        if delay_seconds:
            await asyncio.sleep(delay_seconds)
        # Summary of every probe, stored on the task row as its result.
        result = {
            "email": checks["seeded_account_email"],
            "email_service": email_service_type,
            "hardening_checks": {
                "token_sync": {
                    "seeded_account_id": checks["seeded_account_id"],
                    "tokenless_account_id": checks["tokenless_account_id"],
                    "partial_account_id": checks["partial_account_id"],
                },
                "outlook_refresh": {
                    "service_id": checks["outlook_service_id"],
                    "email": checks["outlook_account_email"],
                },
                "batch_counter": batch_probe,
                "otp_timeout_backoff": {
                    "service_id": checks["backoff_service_id"],
                    "states": backoff_states,
                },
            },
        }
        with get_db() as db:
            crud.update_registration_task(
                db,
                task_uuid,
                status="completed",
                completed_at=datetime.utcnow(),
                result=result,
            )
        task_manager.update_status(
            task_uuid,
            "completed",
            email=checks["seeded_account_email"],
            email_service=email_service_type,
        )
        log_callback("[模拟] 任务完成,所有探针已收口")
    except Exception as exc:
        # Top-level boundary of the background task: persist the failure
        # instead of letting the exception escape.
        logger.exception("模拟任务执行失败: %s", task_uuid)
        with get_db() as db:
            crud.update_registration_task(
                db,
                task_uuid,
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=str(exc),
            )
        task_manager.update_status(task_uuid, "failed", error=str(exc), email_service=email_service_type)
        log_callback(f"[模拟] 任务失败: {exc}")
async def run_batch_parallel(
batch_id: str,
task_uuids: List[str],
@@ -1055,6 +1355,55 @@ async def run_batch_registration(
# ============== API Endpoints ==============
@router.post("/create", response_model=MockRegistrationTaskCreateResponse)
async def create_mock_registration(
    request: MockRegistrationCreateRequest,
    background_tasks: BackgroundTasks,
):
    """Create a controlled mock task for end-to-end verification.

    Validates the requested email service type, persists a new registration
    task, schedules the mock run as a background task, and returns the task
    together with the generated batch id and fixture values.

    Raises:
        HTTPException: 400 when the email service type is not recognised.
    """
    # Constructing the enum is the validation; the value itself is unused.
    try:
        EmailServiceType(request.email_service_type)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"无效的邮箱服务类型: {request.email_service_type}"
        )

    task_uuid = str(uuid.uuid4())
    # Short tag that makes the fixture names of this run unique.
    suffix = task_uuid[:8]
    batch_id = str(uuid.uuid4())

    checks: Dict[str, Any] = {}
    checks["seeded_account_email"] = f"mock-seeded-{suffix}@example.test"
    checks["tokenless_account_email"] = f"mock-tokenless-{suffix}@example.test"
    checks["partial_account_email"] = f"mock-partial-{suffix}@example.test"
    checks["outlook_account_email"] = f"mock-outlook-{suffix}@example.test"
    checks["backoff_service_name"] = f"mock-backoff-{suffix}"

    with get_db() as db:
        task = crud.create_registration_task(db, task_uuid=task_uuid, proxy=None)

    run_arguments = (
        task_uuid,
        batch_id,
        checks,
        request.email_service_type,
        request.start_delay_ms,
        request.log_delay_ms,
    )
    background_tasks.add_task(run_mock_registration_task, *run_arguments)

    task_view = task_to_response(task)
    return MockRegistrationTaskCreateResponse(
        task=task_view,
        batch_id=batch_id,
        checks=checks,
    )
@router.post("/start", response_model=RegistrationTaskResponse)
async def start_registration(
request: RegistrationTaskCreate,

View File

@@ -26,6 +26,36 @@ class FakeEmailService:
return self.code
class FakeCookies:
    """Cookie-jar stand-in backed by a plain mapping."""

    def __init__(self, values):
        # The mapping is kept by reference, not copied.
        self.values = values

    def get(self, name):
        """Return the cookie stored under ``name``, or ``None`` if absent."""
        store = self.values
        return store.get(name)
class FakeSession:
    """Session stand-in that exposes cookies and forbids network access."""

    def __init__(self, cookies=None):
        initial_cookies = cookies or {}
        self.cookies = FakeCookies(initial_cookies)
        self.get_calls = []

    def get(self, *args, **kwargs):
        """Record the attempted request, then fail the test."""
        attempted_call = (args, kwargs)
        self.get_calls.append(attempted_call)
        raise AssertionError("unexpected network call")
class FakeResponse:
    """Response stand-in with a URL, a text body and a scripted JSON result."""

    def __init__(self, *, url="", text="", json_payload=None):
        # An Exception instance as json_payload makes json() raise it.
        self._json_payload = json_payload
        self.text = text
        self.url = url

    def json(self):
        """Return the scripted payload, or raise it when it is an exception."""
        outcome = self._json_payload
        if not isinstance(outcome, Exception):
            return outcome
        raise outcome
def _build_engine(monkeypatch, email_service):
    """Create a RegistrationEngine whose settings lookup yields DummySettings."""

    def fake_get_settings():
        return DummySettings()

    monkeypatch.setattr(register_module, "get_settings", fake_get_settings)
    engine = RegistrationEngine(email_service=email_service)
    return engine
@@ -103,3 +133,42 @@ def test_advance_login_authorization_sets_otp_anchor_before_password_submit(monk
assert callback_url is None
assert engine._otp_sent_at == 456.0
assert seen_anchors == [456.0, 456.0]
def test_get_device_id_reuses_existing_cookie_without_extra_request(monkeypatch):
    """A Device ID already present in the cookies is returned without a request."""
    engine = _build_engine(monkeypatch, FakeEmailService(code=None))
    oauth_start_stub = type("OAuthStart", (), {"auth_url": "https://auth.example.test/authorize"})
    engine.oauth_start = oauth_start_stub()
    engine.session = FakeSession(cookies={"oai-did": "did-cached"})

    device_id = engine._get_device_id()

    assert device_id == "did-cached"
    # FakeSession.get records every call before raising, so an empty list
    # proves no request was attempted.
    assert engine.session.get_calls == []
def test_extract_workspace_id_from_response_payload(monkeypatch):
    """An ID nested in the JSON payload takes precedence over the URL's ID."""
    engine = _build_engine(monkeypatch, FakeEmailService(code=None))
    nested_payload = {
        "page": {
            "workspace": {
                "id": "ws-json",
            }
        }
    }
    response = FakeResponse(
        url="https://auth.example.test/consent?workspace_id=ws-url",
        json_payload=nested_payload,
    )

    extracted = engine._extract_workspace_id_from_response(response=response)

    assert extracted == "ws-json"
def test_extract_workspace_id_from_response_text_when_hidden_input_missing(monkeypatch):
    """With no JSON body, the ID is recovered from inline script text."""
    engine = _build_engine(monkeypatch, FakeEmailService(code=None))
    script_body = '<script>window.__NEXT_DATA__={"activeWorkspaceId":"ws-script"}</script>'
    response = FakeResponse(
        url="https://auth.example.test/consent",
        text=script_body,
        json_payload=ValueError("not json"),
    )

    extracted = engine._extract_workspace_id_from_response(response=response)

    assert extracted == "ws-script"

View File

@@ -1,5 +1,6 @@
import asyncio
from src.web.routes.registration import _create_task_status_callback
from src.web.task_manager import task_manager
@@ -38,3 +39,34 @@ def test_update_status_broadcasts_to_registered_websocket():
task_manager.unregister_websocket(task_uuid, websocket)
asyncio.run(run_test())
def test_task_status_callback_broadcasts_phase_fields():
    """Phase fields passed to the status callback reach the registered websocket."""

    async def run_test():
        task_uuid = "test-status-phase"
        websocket = FakeWebSocket()
        task_manager.set_loop(asyncio.get_running_loop())
        task_manager.register_websocket(task_uuid, websocket)
        phase_update = {
            "phase": "redirect_chain",
            "phase_detail": "跟随重定向 1/6",
            "step_index": 14,
        }
        try:
            report_status = _create_task_status_callback(task_uuid, "tempmail")
            report_status(phase_update)
            # Give the scheduled broadcast a moment to run on the loop.
            await asyncio.sleep(0.05)

            assert websocket.messages, "expected a status message to be broadcast"
            latest = websocket.messages[-1]
            assert latest["type"] == "status"
            assert latest["status"] == "running"
            assert latest["email_service"] == "tempmail"
            assert latest["phase"] == "redirect_chain"
            assert latest["phase_detail"] == "跟随重定向 1/6"
            assert latest["step_index"] == 14
        finally:
            task_manager.unregister_websocket(task_uuid, websocket)

    asyncio.run(run_test())