mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-06 20:02:51 +08:00
feat(register): enhance workspace extraction and phase status reporting
This commit is contained in:
@@ -139,6 +139,7 @@ class RegistrationEngine:
|
||||
email_service: BaseEmailService,
|
||||
proxy_url: Optional[str] = None,
|
||||
callback_logger: Optional[Callable[[str], None]] = None,
|
||||
status_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
|
||||
task_uuid: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@@ -148,11 +149,13 @@ class RegistrationEngine:
|
||||
email_service: 邮箱服务实例
|
||||
proxy_url: 代理 URL
|
||||
callback_logger: 日志回调函数
|
||||
status_callback: 状态回调函数
|
||||
task_uuid: 任务 UUID(用于数据库记录)
|
||||
"""
|
||||
self.email_service = email_service
|
||||
self.proxy_url = proxy_url
|
||||
self.callback_logger = callback_logger or (lambda msg: logger.info(msg))
|
||||
self.status_callback = status_callback
|
||||
self.task_uuid = task_uuid
|
||||
|
||||
# 创建 HTTP 客户端
|
||||
@@ -175,6 +178,7 @@ class RegistrationEngine:
|
||||
self.email_info: Optional[Dict[str, Any]] = None
|
||||
self.oauth_start: Optional[OAuthStart] = None
|
||||
self.session: Optional[cffi_requests.Session] = None
|
||||
self.device_id: Optional[str] = None
|
||||
self.session_token: Optional[str] = None # 会话令牌
|
||||
self.logs: list = []
|
||||
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
|
||||
@@ -213,6 +217,54 @@ class RegistrationEngine:
|
||||
"""生成随机密码"""
|
||||
return ''.join(secrets.choice(PASSWORD_CHARSET) for _ in range(length))
|
||||
|
||||
def _emit_status(self, phase: str, detail: str, **extra):
|
||||
"""向外部上报阶段进度。"""
|
||||
if not self.status_callback:
|
||||
return
|
||||
|
||||
payload = {
|
||||
"phase": phase,
|
||||
"phase_detail": detail,
|
||||
}
|
||||
if self.email:
|
||||
payload["email"] = self.email
|
||||
payload.update({key: value for key, value in extra.items() if value is not None})
|
||||
|
||||
try:
|
||||
self.status_callback(payload)
|
||||
except Exception as e:
|
||||
logger.warning(f"上报任务阶段状态失败: {e}")
|
||||
|
||||
def _current_device_id(self) -> Optional[str]:
|
||||
"""优先复用现有 Device ID,避免重复触发慢请求。"""
|
||||
if self.device_id:
|
||||
return self.device_id
|
||||
if not self.session:
|
||||
return None
|
||||
|
||||
did = self.session.cookies.get("oai-did")
|
||||
if did:
|
||||
self.device_id = did
|
||||
return did
|
||||
|
||||
def _log_timed_http_result(
|
||||
self,
|
||||
action: str,
|
||||
started_at: float,
|
||||
response: Optional[Any] = None,
|
||||
):
|
||||
"""记录 HTTP 调用的耗时与结果。"""
|
||||
elapsed = max(0.0, time.time() - started_at)
|
||||
parts = [f"{action} 完成,耗时 {elapsed:.1f} 秒"]
|
||||
if response is not None:
|
||||
status_code = getattr(response, "status_code", None)
|
||||
response_url = str(getattr(response, "url", "") or "").strip()
|
||||
if status_code is not None:
|
||||
parts.append(f"HTTP {status_code}")
|
||||
if response_url:
|
||||
parts.append(f"URL: {response_url[:120]}...")
|
||||
self._log(",".join(parts))
|
||||
|
||||
def _record_phase_result(self, phase_result: PhaseResult) -> PhaseResult:
|
||||
self.phase_history = [
|
||||
item for item in self.phase_history
|
||||
@@ -311,19 +363,33 @@ class RegistrationEngine:
|
||||
if not self.oauth_start:
|
||||
return None
|
||||
|
||||
cached_did = self._current_device_id()
|
||||
if cached_did:
|
||||
self._log(f"复用已有 Device ID: {cached_did}")
|
||||
return cached_did
|
||||
|
||||
max_attempts = 3
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
if not self.session:
|
||||
self.session = self.http_client.session
|
||||
|
||||
self._emit_status(
|
||||
"oauth_device_id",
|
||||
f"获取 Device ID(第 {attempt}/{max_attempts} 次)",
|
||||
attempt=attempt,
|
||||
max_attempts=max_attempts,
|
||||
)
|
||||
started_at = time.time()
|
||||
response = self.session.get(
|
||||
self.oauth_start.auth_url,
|
||||
timeout=20
|
||||
)
|
||||
self._log_timed_http_result("获取 Device ID 请求", started_at, response)
|
||||
did = self.session.cookies.get("oai-did")
|
||||
|
||||
if did:
|
||||
self.device_id = did
|
||||
self._log(f"Device ID: {did}")
|
||||
return did
|
||||
|
||||
@@ -347,8 +413,15 @@ class RegistrationEngine:
|
||||
def _check_sentinel(self, did: str) -> Optional[str]:
|
||||
"""检查 Sentinel 拦截"""
|
||||
try:
|
||||
sen_req_body = f'{{"p":"","id":"{did}","flow":"authorize_continue"}}'
|
||||
device_id = did or self._current_device_id()
|
||||
if not device_id:
|
||||
self._log("Sentinel 检查跳过: 缺少 Device ID", "warning")
|
||||
return None
|
||||
|
||||
self._emit_status("sentinel", "请求 Sentinel 校验令牌")
|
||||
sen_req_body = f'{{"p":"","id":"{device_id}","flow":"authorize_continue"}}'
|
||||
|
||||
started_at = time.time()
|
||||
response = self.http_client.post(
|
||||
OPENAI_API_ENDPOINTS["sentinel"],
|
||||
headers={
|
||||
@@ -358,6 +431,7 @@ class RegistrationEngine:
|
||||
},
|
||||
data=sen_req_body,
|
||||
)
|
||||
self._log_timed_http_result("Sentinel 校验", started_at, response)
|
||||
|
||||
if response.status_code == 200:
|
||||
sen_token = response.json().get("token")
|
||||
@@ -719,6 +793,53 @@ class RegistrationEngine:
|
||||
return workspace_id
|
||||
return None
|
||||
|
||||
def _extract_workspace_id_from_text(self, text: str) -> Optional[str]:
|
||||
"""从 HTML/脚本文本中提取 Workspace ID。"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
patterns = [
|
||||
r'"workspace_id"\s*:\s*"([^"]+)"',
|
||||
r'"workspaceId"\s*:\s*"([^"]+)"',
|
||||
r'"default_workspace_id"\s*:\s*"([^"]+)"',
|
||||
r'"defaultWorkspaceId"\s*:\s*"([^"]+)"',
|
||||
r'"active_workspace_id"\s*:\s*"([^"]+)"',
|
||||
r'"activeWorkspaceId"\s*:\s*"([^"]+)"',
|
||||
r'"workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"',
|
||||
r'"default_workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"',
|
||||
r'"active_workspace"\s*:\s*\{[^{}]*"id"\s*:\s*"([^"]+)"',
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
workspace_id = str(match.group(1) or "").strip()
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
return None
|
||||
|
||||
def _extract_workspace_id_from_url(self, url: str) -> Optional[str]:
|
||||
"""从 URL 查询参数或片段中提取 Workspace ID。"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
import urllib.parse
|
||||
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
for raw_query in (parsed.query, parsed.fragment):
|
||||
query = urllib.parse.parse_qs(raw_query)
|
||||
for key in (
|
||||
"workspace_id",
|
||||
"workspaceId",
|
||||
"default_workspace_id",
|
||||
"active_workspace_id",
|
||||
):
|
||||
values = query.get(key) or []
|
||||
if values:
|
||||
workspace_id = str(values[0] or "").strip()
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
return None
|
||||
|
||||
def _decode_cookie_json_candidates(self, cookie_value: str) -> list[Dict[str, Any]]:
|
||||
"""尝试从完整 Cookie 或其分段中解码出 JSON。"""
|
||||
decoded_objects = []
|
||||
@@ -760,12 +881,25 @@ class RegistrationEngine:
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
for key in ("workspace_id", "default_workspace_id", "active_workspace_id"):
|
||||
for key in (
|
||||
"workspace_id",
|
||||
"workspaceId",
|
||||
"default_workspace_id",
|
||||
"defaultWorkspaceId",
|
||||
"active_workspace_id",
|
||||
"activeWorkspaceId",
|
||||
):
|
||||
workspace_id = str(auth_json.get(key) or "").strip()
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
for key in ("workspace", "default_workspace", "active_workspace"):
|
||||
for key in (
|
||||
"workspace",
|
||||
"default_workspace",
|
||||
"active_workspace",
|
||||
"defaultWorkspace",
|
||||
"activeWorkspace",
|
||||
):
|
||||
workspace = auth_json.get(key)
|
||||
if not isinstance(workspace, dict):
|
||||
continue
|
||||
@@ -776,6 +910,60 @@ class RegistrationEngine:
|
||||
|
||||
return None
|
||||
|
||||
def _extract_workspace_id_from_response(
|
||||
self,
|
||||
response: Optional[Any] = None,
|
||||
html: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""统一从响应 JSON、HTML、脚本内容和 URL 中提取 Workspace ID。"""
|
||||
response_url = str(getattr(response, "url", "") or "").strip()
|
||||
response_text = html if html is not None else str(getattr(response, "text", "") or "")
|
||||
candidate_url = url or response_url
|
||||
|
||||
if response is not None:
|
||||
try:
|
||||
payload = response.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
workspace_id = self._extract_workspace_id_from_response_payload(payload)
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
for extractor in (
|
||||
lambda: self._extract_workspace_id_from_html(response_text),
|
||||
lambda: self._extract_workspace_id_from_text(response_text),
|
||||
lambda: self._extract_workspace_id_from_url(candidate_url),
|
||||
):
|
||||
workspace_id = extractor()
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
return None
|
||||
|
||||
def _extract_workspace_id_from_response_payload(self, payload: Any, depth: int = 0) -> Optional[str]:
|
||||
"""递归扫描响应载荷中的 Workspace ID。"""
|
||||
if payload is None or depth > 5:
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
workspace_id = self._extract_workspace_id_from_auth_json(payload)
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
for value in payload.values():
|
||||
workspace_id = self._extract_workspace_id_from_response_payload(value, depth + 1)
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
return None
|
||||
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
workspace_id = self._extract_workspace_id_from_response_payload(item, depth + 1)
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
return None
|
||||
|
||||
def _select_workspace(self, workspace_id: str) -> Optional[str]:
|
||||
"""选择 Workspace"""
|
||||
try:
|
||||
@@ -858,12 +1046,16 @@ class RegistrationEngine:
|
||||
return False
|
||||
|
||||
try:
|
||||
did = self.session.cookies.get("oai-did") if self.session else None
|
||||
self._emit_status("login_reentry", "重新进入登录流程")
|
||||
did = self._current_device_id()
|
||||
sen_token = self._check_sentinel(did) if did else None
|
||||
self._log("登录重入:请求 authorize 页面以确认当前表单状态")
|
||||
started_at = time.time()
|
||||
response = self.session.get(
|
||||
self.oauth_start.auth_url,
|
||||
timeout=15,
|
||||
)
|
||||
self._log_timed_http_result("登录重入 authorize 页面", started_at, response)
|
||||
html = response.text or ""
|
||||
|
||||
if "/log-in/password" in str(getattr(response, "url", "") or "") or 'action="/log-in/password"' in html:
|
||||
@@ -877,6 +1069,9 @@ class RegistrationEngine:
|
||||
"value": self.email,
|
||||
}
|
||||
}
|
||||
self._emit_status("login_reentry", "提交邮箱以推进到密码页")
|
||||
self._log("登录重入:提交邮箱到 authorize/continue")
|
||||
started_at = time.time()
|
||||
login_response = self.session.post(
|
||||
"https://auth.openai.com/api/accounts/authorize/continue",
|
||||
headers={
|
||||
@@ -902,12 +1097,20 @@ class RegistrationEngine:
|
||||
data=json.dumps(login_data),
|
||||
timeout=15,
|
||||
)
|
||||
self._log_timed_http_result("登录重入邮箱提交", started_at, login_response)
|
||||
login_json = login_response.json() if login_response.status_code == 200 else {}
|
||||
page_type = str((login_json or {}).get("page", {}).get("type") or "").strip()
|
||||
continue_url = str((login_json or {}).get("continue_url") or "").strip()
|
||||
self._log(
|
||||
f"登录重入响应: page_type={page_type or 'unknown'}, "
|
||||
f"continue_url={continue_url[:100] + '...' if continue_url else 'none'}"
|
||||
)
|
||||
if continue_url:
|
||||
try:
|
||||
self._emit_status("login_reentry", "跟进登录 continue_url")
|
||||
started_at = time.time()
|
||||
self.session.get(continue_url, timeout=15)
|
||||
self._log_timed_http_result("登录重入 continue_url", started_at)
|
||||
except Exception:
|
||||
pass
|
||||
if login_response.status_code == 200 and page_type in {"password", "login_password"}:
|
||||
@@ -926,8 +1129,10 @@ class RegistrationEngine:
|
||||
return False
|
||||
|
||||
try:
|
||||
did = self.session.cookies.get("oai-did") if self.session else None
|
||||
self._emit_status("login_password", "提交登录密码")
|
||||
did = self._current_device_id()
|
||||
sen_token = self._check_sentinel(did) if did else None
|
||||
started_at = time.time()
|
||||
response = self.session.post(
|
||||
"https://auth.openai.com/api/accounts/password/verify",
|
||||
headers={
|
||||
@@ -955,6 +1160,7 @@ class RegistrationEngine:
|
||||
}),
|
||||
timeout=15,
|
||||
)
|
||||
self._log_timed_http_result("登录密码提交", started_at, response)
|
||||
self._log(f"登录密码提交状态: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
@@ -964,7 +1170,10 @@ class RegistrationEngine:
|
||||
continue_url = str(payload.get("continue_url") or "").strip()
|
||||
if continue_url:
|
||||
try:
|
||||
self._emit_status("login_password", "跟进密码校验 continue_url")
|
||||
started_at = time.time()
|
||||
self.session.get(continue_url, timeout=15)
|
||||
self._log_timed_http_result("密码校验 continue_url", started_at)
|
||||
except Exception:
|
||||
pass
|
||||
return response.status_code in (200, 302, 303)
|
||||
@@ -977,7 +1186,7 @@ class RegistrationEngine:
|
||||
return False, None
|
||||
|
||||
try:
|
||||
did = self.session.cookies.get("oai-did") if self.session else None
|
||||
did = self._current_device_id()
|
||||
sen_token = self._check_sentinel(did) if did else None
|
||||
response = self.session.post(
|
||||
"https://auth.openai.com/api/accounts/password/verify",
|
||||
@@ -1061,6 +1270,7 @@ class RegistrationEngine:
|
||||
self._log("重新初始化登录会话失败", "warning")
|
||||
return None, None
|
||||
|
||||
self._emit_status("oauth_reentry", "重新初始化 OAuth 登录会话")
|
||||
if not self._start_oauth():
|
||||
self._log("重新开始 OAuth 登录流程失败", "warning")
|
||||
return None, None
|
||||
@@ -1074,6 +1284,7 @@ class RegistrationEngine:
|
||||
return None, None
|
||||
|
||||
self._otp_sent_at = time.time()
|
||||
self._emit_status("otp_secondary", "等待登录验证码邮件")
|
||||
|
||||
if not self._submit_login_password_step():
|
||||
return None, None
|
||||
@@ -1089,12 +1300,16 @@ class RegistrationEngine:
|
||||
return None, None
|
||||
|
||||
auth_target = consent_url or self.oauth_start.auth_url
|
||||
self._emit_status("workspace_extract", "请求 consent 页面并提取 Workspace ID")
|
||||
self._log(f"请求 consent 页面: {auth_target[:120]}...")
|
||||
started_at = time.time()
|
||||
auth_response = self.session.get(auth_target, timeout=20)
|
||||
self._log_timed_http_result("获取 consent 页面", started_at, auth_response)
|
||||
current_url = str(getattr(auth_response, "url", "") or "")
|
||||
html = auth_response.text or ""
|
||||
|
||||
if "sign-in-with-chatgpt/codex/consent" in current_url or 'action="/sign-in-with-chatgpt/codex/consent"' in html:
|
||||
workspace_id = self._extract_workspace_id_from_html(html)
|
||||
workspace_id = self._extract_workspace_id_from_response(response=auth_response, html=html, url=current_url)
|
||||
if not workspace_id:
|
||||
self._log("consent 页面缺少 workspace_id,回退到 Cookie 解析路径", "warning")
|
||||
return None, None
|
||||
@@ -1115,13 +1330,22 @@ class RegistrationEngine:
|
||||
max_redirects = 6
|
||||
|
||||
for i in range(max_redirects):
|
||||
self._emit_status(
|
||||
"redirect_chain",
|
||||
f"跟随重定向 {i + 1}/{max_redirects}",
|
||||
redirect_index=i + 1,
|
||||
redirect_total=max_redirects,
|
||||
redirect_url=current_url[:200],
|
||||
)
|
||||
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
|
||||
|
||||
started_at = time.time()
|
||||
response = self.session.get(
|
||||
current_url,
|
||||
allow_redirects=False,
|
||||
timeout=15
|
||||
)
|
||||
self._log_timed_http_result(f"重定向跳转 {i + 1}/{max_redirects}", started_at, response)
|
||||
|
||||
location = response.headers.get("Location") or ""
|
||||
|
||||
@@ -1137,6 +1361,7 @@ class RegistrationEngine:
|
||||
# 构建下一个 URL
|
||||
import urllib.parse
|
||||
next_url = urllib.parse.urljoin(current_url, location)
|
||||
self._log(f"重定向下一跳: {next_url[:100]}...")
|
||||
|
||||
# 检查是否包含回调参数
|
||||
if "code=" in next_url and "state=" in next_url:
|
||||
@@ -1159,12 +1384,19 @@ class RegistrationEngine:
|
||||
self._log("OAuth 流程未初始化", "error")
|
||||
return None
|
||||
|
||||
self._emit_status("oauth_callback", "处理 OAuth 回调并交换令牌")
|
||||
self._log("处理 OAuth 回调...")
|
||||
started_at = time.time()
|
||||
token_info = self.oauth_manager.handle_callback(
|
||||
callback_url=callback_url,
|
||||
expected_state=self.oauth_start.state,
|
||||
code_verifier=self.oauth_start.code_verifier
|
||||
)
|
||||
elapsed = max(0.0, time.time() - started_at)
|
||||
self._log(
|
||||
f"OAuth 回调处理完成,耗时 {elapsed:.1f} 秒,"
|
||||
f"account_id={str(token_info.get('account_id') or '').strip() or 'unknown'}"
|
||||
)
|
||||
|
||||
self._log("OAuth 授权成功")
|
||||
return token_info
|
||||
@@ -1197,6 +1429,7 @@ class RegistrationEngine:
|
||||
|
||||
# 1. 检查 IP 地理位置
|
||||
self._log("1. 检查 IP 地理位置...")
|
||||
self._emit_status("ip_check", "检查 IP 地理位置", step_index=1)
|
||||
ip_ok, location = self._check_ip_location()
|
||||
if not ip_ok:
|
||||
result.error_message = f"IP 地理位置不支持: {location}"
|
||||
@@ -1207,6 +1440,7 @@ class RegistrationEngine:
|
||||
|
||||
# 2. 创建邮箱
|
||||
self._log("2. 创建邮箱...")
|
||||
self._emit_status("email_prepare", "创建邮箱地址", step_index=2)
|
||||
if not self._phase_email_prepare():
|
||||
email_prepare_phase = self._get_phase_result(PHASE_EMAIL_PREPARE)
|
||||
result.error_message = (
|
||||
@@ -1221,18 +1455,21 @@ class RegistrationEngine:
|
||||
|
||||
# 3. 初始化会话
|
||||
self._log("3. 初始化会话...")
|
||||
self._emit_status("session_init", "初始化 HTTP 会话", step_index=3)
|
||||
if not self._init_session():
|
||||
result.error_message = "初始化会话失败"
|
||||
return result
|
||||
|
||||
# 4. 开始 OAuth 流程
|
||||
self._log("4. 开始 OAuth 授权流程...")
|
||||
self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4)
|
||||
if not self._start_oauth():
|
||||
result.error_message = "开始 OAuth 流程失败"
|
||||
return result
|
||||
|
||||
# 5. 获取 Device ID
|
||||
self._log("5. 获取 Device ID...")
|
||||
self._emit_status("oauth_device_id", "获取 Device ID", step_index=5)
|
||||
did = self._get_device_id()
|
||||
if not did:
|
||||
result.error_message = "获取 Device ID 失败"
|
||||
@@ -1240,6 +1477,7 @@ class RegistrationEngine:
|
||||
|
||||
# 6. 检查 Sentinel 拦截
|
||||
self._log("6. 检查 Sentinel 拦截...")
|
||||
self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6)
|
||||
sen_token = self._check_sentinel(did)
|
||||
if sen_token:
|
||||
self._log("Sentinel 检查通过")
|
||||
@@ -1248,6 +1486,7 @@ class RegistrationEngine:
|
||||
|
||||
# 7. 提交注册表单 + 解析响应判断账号状态
|
||||
self._log("7. 提交注册表单...")
|
||||
self._emit_status("signup_submit", "提交注册表单", step_index=7)
|
||||
signup_result = self._submit_signup_form(did, sen_token)
|
||||
if not signup_result.success:
|
||||
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
|
||||
@@ -1258,6 +1497,7 @@ class RegistrationEngine:
|
||||
self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送")
|
||||
else:
|
||||
self._log("8. 注册密码...")
|
||||
self._emit_status("signup_password", "提交注册密码", step_index=8)
|
||||
password_ok, password = self._register_password()
|
||||
if not password_ok:
|
||||
result.error_message = "注册密码失败"
|
||||
@@ -1270,12 +1510,14 @@ class RegistrationEngine:
|
||||
self._otp_sent_at = time.time()
|
||||
else:
|
||||
self._log("9. 发送验证码...")
|
||||
self._emit_status("otp_send", "发送验证码", step_index=9)
|
||||
if not self._send_verification_code():
|
||||
result.error_message = "发送验证码失败"
|
||||
return result
|
||||
|
||||
# 10. 获取验证码
|
||||
self._log("10. 等待验证码...")
|
||||
self._emit_status("otp_secondary", "等待验证码邮件", step_index=10)
|
||||
otp_phase_started_at = time.time()
|
||||
code, otp_phase = self._phase_otp_secondary(
|
||||
PhaseContext(otp_sent_at=self._otp_sent_at),
|
||||
@@ -1290,6 +1532,7 @@ class RegistrationEngine:
|
||||
|
||||
# 11. 验证验证码
|
||||
self._log("11. 验证验证码...")
|
||||
self._emit_status("otp_validate", "校验验证码", step_index=11)
|
||||
if not self._validate_verification_code(code):
|
||||
result.error_message = "验证验证码失败"
|
||||
return result
|
||||
@@ -1299,6 +1542,7 @@ class RegistrationEngine:
|
||||
self._log("12. [已注册账号] 跳过创建用户账户")
|
||||
else:
|
||||
self._log("12. 创建用户账户...")
|
||||
self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12)
|
||||
if not self._create_user_account():
|
||||
result.error_message = "创建用户账户失败"
|
||||
return result
|
||||
@@ -1308,6 +1552,7 @@ class RegistrationEngine:
|
||||
|
||||
if not self._is_existing_account:
|
||||
self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...")
|
||||
self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step)
|
||||
workspace_id, callback_url = self._advance_login_authorization()
|
||||
if workspace_id and callback_url:
|
||||
result.workspace_id = workspace_id
|
||||
@@ -1316,6 +1561,7 @@ class RegistrationEngine:
|
||||
if not result.workspace_id:
|
||||
# 获取 Workspace ID
|
||||
self._log(f"{next_step}. 获取 Workspace ID...")
|
||||
self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step)
|
||||
workspace_id = self._get_workspace_id()
|
||||
if not workspace_id:
|
||||
result.error_message = "获取 Workspace ID 失败"
|
||||
@@ -1327,6 +1573,7 @@ class RegistrationEngine:
|
||||
|
||||
# 选择 Workspace
|
||||
self._log(f"{next_step}. 选择 Workspace...")
|
||||
self._emit_status("workspace_select", "选择 Workspace", step_index=next_step)
|
||||
continue_url = self._select_workspace(result.workspace_id)
|
||||
if not continue_url:
|
||||
result.error_message = "选择 Workspace 失败"
|
||||
@@ -1336,6 +1583,7 @@ class RegistrationEngine:
|
||||
|
||||
# 跟随重定向链
|
||||
self._log(f"{next_step}. 跟随重定向链...")
|
||||
self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step)
|
||||
callback_url = self._follow_redirects(continue_url)
|
||||
if not callback_url:
|
||||
result.error_message = "跟随重定向链失败"
|
||||
@@ -1345,6 +1593,7 @@ class RegistrationEngine:
|
||||
|
||||
# 处理 OAuth 回调
|
||||
self._log(f"{next_step}. 处理 OAuth 回调...")
|
||||
self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step)
|
||||
token_info = self._handle_oauth_callback(callback_url)
|
||||
if not token_info:
|
||||
result.error_message = "处理 OAuth 回调失败"
|
||||
|
||||
@@ -10,7 +10,7 @@ import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from typing import List, Optional, Dict, Tuple, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -24,7 +24,7 @@ from ...core.register import (
|
||||
RegistrationResult,
|
||||
)
|
||||
from ...services import EmailServiceFactory, EmailServiceType
|
||||
from ...services.base import EmailProviderBackoffState, OTPTimeoutEmailServiceError
|
||||
from ...services.base import BaseEmailService, EmailProviderBackoffState, OTPTimeoutEmailServiceError
|
||||
from ...config.settings import get_settings
|
||||
from ..task_manager import task_manager
|
||||
|
||||
@@ -136,6 +136,13 @@ class BatchRegistrationRequest(BaseModel):
|
||||
tm_service_ids: List[int] = []
|
||||
|
||||
|
||||
class MockRegistrationCreateRequest(BaseModel):
|
||||
"""创建受控模拟任务请求"""
|
||||
email_service_type: str = "tempmail"
|
||||
start_delay_ms: int = Field(default=300, ge=0, le=5000)
|
||||
log_delay_ms: int = Field(default=250, ge=0, le=5000)
|
||||
|
||||
|
||||
class RegistrationTaskResponse(BaseModel):
|
||||
"""注册任务响应"""
|
||||
id: int
|
||||
@@ -161,6 +168,13 @@ class BatchRegistrationResponse(BaseModel):
|
||||
tasks: List[RegistrationTaskResponse]
|
||||
|
||||
|
||||
class MockRegistrationTaskCreateResponse(BaseModel):
|
||||
"""受控模拟任务响应"""
|
||||
task: RegistrationTaskResponse
|
||||
batch_id: str
|
||||
checks: Dict[str, Any]
|
||||
|
||||
|
||||
class TaskListResponse(BaseModel):
|
||||
"""任务列表响应"""
|
||||
total: int
|
||||
@@ -232,6 +246,19 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse:
|
||||
)
|
||||
|
||||
|
||||
def _create_task_status_callback(task_uuid: str, email_service: str):
|
||||
"""把引擎内部阶段进度映射到 TaskManager 状态广播。"""
|
||||
|
||||
def callback(payload: Dict[str, Any]) -> None:
|
||||
status_payload = {
|
||||
"email_service": email_service,
|
||||
**payload,
|
||||
}
|
||||
task_manager.update_status(task_uuid, "running", **status_payload)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def _normalize_email_service_config(
|
||||
service_type: EmailServiceType,
|
||||
config: Optional[dict],
|
||||
@@ -329,6 +356,7 @@ def _run_registration_engine_attempt(
|
||||
actual_proxy_url: Optional[str],
|
||||
log_callback,
|
||||
db_service,
|
||||
status_callback=None,
|
||||
):
|
||||
"""执行单次注册引擎尝试,并在同一临界区内维护邮箱服务退避状态。"""
|
||||
provider_backoff_before_run = EmailProviderBackoffState()
|
||||
@@ -343,6 +371,7 @@ def _run_registration_engine_attempt(
|
||||
email_service=email_service,
|
||||
proxy_url=actual_proxy_url,
|
||||
callback_logger=log_callback,
|
||||
status_callback=status_callback,
|
||||
task_uuid=task_uuid,
|
||||
)
|
||||
|
||||
@@ -570,6 +599,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=None)
|
||||
|
||||
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value)
|
||||
status_callback = _create_task_status_callback(task_uuid, active_service_type.value)
|
||||
email_service = EmailServiceFactory.create(
|
||||
selected_service_type,
|
||||
candidate_config,
|
||||
@@ -587,6 +617,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
actual_proxy_url=actual_proxy_url,
|
||||
log_callback=log_callback,
|
||||
db_service=db_service,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
@@ -856,6 +887,275 @@ def _make_batch_helpers(batch_id: str):
|
||||
return add_batch_log, update_batch_status
|
||||
|
||||
|
||||
class _MockBackoffEmailService(BaseEmailService):
|
||||
"""用于真实服务验证的最小邮箱服务桩。"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(service_type=EmailServiceType.DUCK_MAIL, name="mock-backoff-service")
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
return {"email": "mock@example.test", "service_id": "mock-service-id"}
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = r"(?<!\d)(\d{6})(?!\d)",
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
return True
|
||||
|
||||
def check_health(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _create_persisted_log_callback(task_uuid: str, prefix: str = "", batch_id: str = ""):
|
||||
"""同时写入内存日志队列、批量日志通道和数据库任务日志。"""
|
||||
|
||||
def callback(message: str) -> None:
|
||||
full_message = f"{prefix} {message}" if prefix else message
|
||||
task_manager.add_log(task_uuid, full_message)
|
||||
if batch_id:
|
||||
task_manager.add_batch_log(batch_id, full_message)
|
||||
with get_db() as db:
|
||||
crud.append_task_log(db, task_uuid, full_message)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def _simulate_batch_counter_probe(batch_id: str) -> Dict[str, Any]:
|
||||
"""构造一个可重复的批量计数场景,验证 TaskManager 计数收口。"""
|
||||
task_uuids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
task_statuses = ["completed", "failed", "completed"]
|
||||
_init_batch_state(batch_id, task_uuids)
|
||||
add_batch_log, update_batch_status = _make_batch_helpers(batch_id)
|
||||
add_batch_log(f"[系统] 模拟批量任务启动,总任务: {len(task_uuids)}")
|
||||
|
||||
with get_db() as db:
|
||||
for index, (task_uuid, status) in enumerate(zip(task_uuids, task_statuses), start=1):
|
||||
crud.create_registration_task(db, task_uuid=task_uuid, proxy=None)
|
||||
error_message = None if status == "completed" else f"mock-batch-error-{index}"
|
||||
crud.update_registration_task(
|
||||
db,
|
||||
task_uuid,
|
||||
status=status,
|
||||
started_at=datetime.utcnow(),
|
||||
completed_at=datetime.utcnow(),
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
batch_snapshot = _get_batch_snapshot(batch_id) or {}
|
||||
new_completed = batch_snapshot.get("completed", 0) + 1
|
||||
new_success = batch_snapshot.get("success", 0)
|
||||
new_failed = batch_snapshot.get("failed", 0)
|
||||
if status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"[任务{index}] [成功] 模拟注册成功")
|
||||
else:
|
||||
new_failed += 1
|
||||
add_batch_log(f"[任务{index}] [失败] 模拟注册失败: {error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
|
||||
batch_snapshot = _get_batch_snapshot(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, "
|
||||
f"失败: {batch_snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"task_uuids": task_uuids,
|
||||
"snapshot": task_manager.get_batch_status(batch_id) or {},
|
||||
}
|
||||
|
||||
|
||||
async def run_mock_registration_task(
|
||||
task_uuid: str,
|
||||
batch_id: str,
|
||||
checks: Dict[str, Any],
|
||||
email_service_type: str,
|
||||
start_delay_ms: int,
|
||||
log_delay_ms: int,
|
||||
) -> None:
|
||||
"""通过真实服务链路执行可重复的模拟任务。"""
|
||||
if start_delay_ms > 0:
|
||||
await asyncio.sleep(start_delay_ms / 1000)
|
||||
|
||||
loop = task_manager.get_loop()
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager.set_loop(loop)
|
||||
|
||||
log_callback = _create_persisted_log_callback(task_uuid)
|
||||
delay_seconds = max(log_delay_ms, 0) / 1000
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
task = crud.update_registration_task(
|
||||
db,
|
||||
task_uuid,
|
||||
status="running",
|
||||
started_at=datetime.utcnow(),
|
||||
)
|
||||
if not task:
|
||||
logger.error(f"模拟任务不存在: {task_uuid}")
|
||||
return
|
||||
|
||||
task_manager.update_status(task_uuid, "running", email_service=email_service_type)
|
||||
log_callback("[模拟] 任务已启动,开始执行真实链路探针")
|
||||
if delay_seconds:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
with get_db() as db:
|
||||
seeded_account = crud.create_account(
|
||||
db,
|
||||
email=checks["seeded_account_email"],
|
||||
email_service="tempmail",
|
||||
access_token="mock-access-token-seeded",
|
||||
refresh_token="mock-refresh-token-seeded",
|
||||
)
|
||||
tokenless_account = crud.create_account(
|
||||
db,
|
||||
email=checks["tokenless_account_email"],
|
||||
email_service="tempmail",
|
||||
)
|
||||
crud.update_account(
|
||||
db,
|
||||
tokenless_account.id,
|
||||
access_token="mock-access-token-updated",
|
||||
)
|
||||
partial_account = crud.create_account(
|
||||
db,
|
||||
email=checks["partial_account_email"],
|
||||
email_service="tempmail",
|
||||
access_token="mock-access-token-partial",
|
||||
refresh_token="mock-refresh-token-partial",
|
||||
)
|
||||
crud.update_account(
|
||||
db,
|
||||
partial_account.id,
|
||||
refresh_token="",
|
||||
)
|
||||
outlook_service = crud.create_email_service(
|
||||
db,
|
||||
service_type="outlook",
|
||||
name=f"mock-outlook-{task_uuid[:8]}",
|
||||
config={
|
||||
"accounts": [
|
||||
{"email": "first@example.test", "refresh_token": "old-first"},
|
||||
{
|
||||
"email": checks["outlook_account_email"],
|
||||
"refresh_token": "old-second",
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
crud.update_outlook_refresh_token(
|
||||
db,
|
||||
service_id=outlook_service.id,
|
||||
email=checks["outlook_account_email"],
|
||||
new_refresh_token="new-second",
|
||||
)
|
||||
backoff_service = crud.create_email_service(
|
||||
db,
|
||||
service_type="duck_mail",
|
||||
name=checks["backoff_service_name"],
|
||||
config={
|
||||
"base_url": "https://mail.example.test",
|
||||
"default_domain": "example.test",
|
||||
},
|
||||
)
|
||||
checks["seeded_account_id"] = seeded_account.id
|
||||
checks["tokenless_account_id"] = tokenless_account.id
|
||||
checks["partial_account_id"] = partial_account.id
|
||||
checks["outlook_service_id"] = outlook_service.id
|
||||
checks["backoff_service_id"] = backoff_service.id
|
||||
log_callback("[模拟] Token 同步与 Outlook refresh_token 探针已写入数据库")
|
||||
if delay_seconds:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
mock_email_service = _MockBackoffEmailService()
|
||||
backoff_states = []
|
||||
for attempt in range(1, 4):
|
||||
previous_state = _get_email_service_backoff_state(backoff_service.id)
|
||||
current_state = _record_email_service_timeout_backoff(
|
||||
backoff_service.id,
|
||||
mock_email_service,
|
||||
previous_state,
|
||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
f"模拟 OTP 超时 #{attempt}",
|
||||
)
|
||||
if current_state is not None:
|
||||
backoff_states.append(current_state.to_dict())
|
||||
log_callback(
|
||||
f"[模拟] OTP 超时退避 #{attempt}: "
|
||||
f"failures={current_state.failures}, delay={current_state.delay_seconds}"
|
||||
)
|
||||
if delay_seconds:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
batch_probe = _simulate_batch_counter_probe(batch_id)
|
||||
log_callback("[模拟] 批量计数探针已完成")
|
||||
if delay_seconds:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
result = {
|
||||
"email": checks["seeded_account_email"],
|
||||
"email_service": email_service_type,
|
||||
"hardening_checks": {
|
||||
"token_sync": {
|
||||
"seeded_account_id": checks["seeded_account_id"],
|
||||
"tokenless_account_id": checks["tokenless_account_id"],
|
||||
"partial_account_id": checks["partial_account_id"],
|
||||
},
|
||||
"outlook_refresh": {
|
||||
"service_id": checks["outlook_service_id"],
|
||||
"email": checks["outlook_account_email"],
|
||||
},
|
||||
"batch_counter": batch_probe,
|
||||
"otp_timeout_backoff": {
|
||||
"service_id": checks["backoff_service_id"],
|
||||
"states": backoff_states,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
with get_db() as db:
|
||||
crud.update_registration_task(
|
||||
db,
|
||||
task_uuid,
|
||||
status="completed",
|
||||
completed_at=datetime.utcnow(),
|
||||
result=result,
|
||||
)
|
||||
task_manager.update_status(
|
||||
task_uuid,
|
||||
"completed",
|
||||
email=checks["seeded_account_email"],
|
||||
email_service=email_service_type,
|
||||
)
|
||||
log_callback("[模拟] 任务完成,所有探针已收口")
|
||||
except Exception as exc:
|
||||
logger.exception("模拟任务执行失败: %s", task_uuid)
|
||||
with get_db() as db:
|
||||
crud.update_registration_task(
|
||||
db,
|
||||
task_uuid,
|
||||
status="failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
error_message=str(exc),
|
||||
)
|
||||
task_manager.update_status(task_uuid, "failed", error=str(exc), email_service=email_service_type)
|
||||
log_callback(f"[模拟] 任务失败: {exc}")
|
||||
|
||||
|
||||
async def run_batch_parallel(
|
||||
batch_id: str,
|
||||
task_uuids: List[str],
|
||||
@@ -1055,6 +1355,55 @@ async def run_batch_registration(
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
|
||||
@router.post("/create", response_model=MockRegistrationTaskCreateResponse)
|
||||
async def create_mock_registration(
|
||||
request: MockRegistrationCreateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
"""创建用于端到端验证的受控模拟任务。"""
|
||||
try:
|
||||
EmailServiceType(request.email_service_type)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"无效的邮箱服务类型: {request.email_service_type}"
|
||||
)
|
||||
|
||||
task_uuid = str(uuid.uuid4())
|
||||
suffix = task_uuid[:8]
|
||||
batch_id = str(uuid.uuid4())
|
||||
checks: Dict[str, Any] = {
|
||||
"seeded_account_email": f"mock-seeded-{suffix}@example.test",
|
||||
"tokenless_account_email": f"mock-tokenless-{suffix}@example.test",
|
||||
"partial_account_email": f"mock-partial-{suffix}@example.test",
|
||||
"outlook_account_email": f"mock-outlook-{suffix}@example.test",
|
||||
"backoff_service_name": f"mock-backoff-{suffix}",
|
||||
}
|
||||
|
||||
with get_db() as db:
|
||||
task = crud.create_registration_task(
|
||||
db,
|
||||
task_uuid=task_uuid,
|
||||
proxy=None,
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
run_mock_registration_task,
|
||||
task_uuid,
|
||||
batch_id,
|
||||
checks,
|
||||
request.email_service_type,
|
||||
request.start_delay_ms,
|
||||
request.log_delay_ms,
|
||||
)
|
||||
|
||||
return MockRegistrationTaskCreateResponse(
|
||||
task=task_to_response(task),
|
||||
batch_id=batch_id,
|
||||
checks=checks,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/start", response_model=RegistrationTaskResponse)
|
||||
async def start_registration(
|
||||
request: RegistrationTaskCreate,
|
||||
|
||||
@@ -26,6 +26,36 @@ class FakeEmailService:
|
||||
return self.code
|
||||
|
||||
|
||||
class FakeCookies:
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
def get(self, name):
|
||||
return self.values.get(name)
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, cookies=None):
|
||||
self.cookies = FakeCookies(cookies or {})
|
||||
self.get_calls = []
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
self.get_calls.append((args, kwargs))
|
||||
raise AssertionError("unexpected network call")
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, *, url="", text="", json_payload=None):
|
||||
self.url = url
|
||||
self.text = text
|
||||
self._json_payload = json_payload
|
||||
|
||||
def json(self):
|
||||
if isinstance(self._json_payload, Exception):
|
||||
raise self._json_payload
|
||||
return self._json_payload
|
||||
|
||||
|
||||
def _build_engine(monkeypatch, email_service):
|
||||
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
|
||||
return RegistrationEngine(email_service=email_service)
|
||||
@@ -103,3 +133,42 @@ def test_advance_login_authorization_sets_otp_anchor_before_password_submit(monk
|
||||
assert callback_url is None
|
||||
assert engine._otp_sent_at == 456.0
|
||||
assert seen_anchors == [456.0, 456.0]
|
||||
|
||||
|
||||
def test_get_device_id_reuses_existing_cookie_without_extra_request(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.oauth_start = type("OAuthStart", (), {"auth_url": "https://auth.example.test/authorize"})()
|
||||
engine.session = FakeSession(cookies={"oai-did": "did-cached"})
|
||||
|
||||
assert engine._get_device_id() == "did-cached"
|
||||
assert engine.session.get_calls == []
|
||||
|
||||
|
||||
def test_extract_workspace_id_from_response_payload(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
response = FakeResponse(
|
||||
url="https://auth.example.test/consent?workspace_id=ws-url",
|
||||
json_payload={
|
||||
"page": {
|
||||
"workspace": {
|
||||
"id": "ws-json",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert engine._extract_workspace_id_from_response(response=response) == "ws-json"
|
||||
|
||||
|
||||
def test_extract_workspace_id_from_response_text_when_hidden_input_missing(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
response = FakeResponse(
|
||||
url="https://auth.example.test/consent",
|
||||
text='<script>window.__NEXT_DATA__={"activeWorkspaceId":"ws-script"}</script>',
|
||||
json_payload=ValueError("not json"),
|
||||
)
|
||||
|
||||
assert engine._extract_workspace_id_from_response(response=response) == "ws-script"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from src.web.routes.registration import _create_task_status_callback
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
@@ -38,3 +39,34 @@ def test_update_status_broadcasts_to_registered_websocket():
|
||||
task_manager.unregister_websocket(task_uuid, websocket)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
def test_task_status_callback_broadcasts_phase_fields():
|
||||
async def run_test():
|
||||
task_uuid = "test-status-phase"
|
||||
websocket = FakeWebSocket()
|
||||
|
||||
task_manager.set_loop(asyncio.get_running_loop())
|
||||
task_manager.register_websocket(task_uuid, websocket)
|
||||
|
||||
try:
|
||||
callback = _create_task_status_callback(task_uuid, "tempmail")
|
||||
callback({
|
||||
"phase": "redirect_chain",
|
||||
"phase_detail": "跟随重定向 1/6",
|
||||
"step_index": 14,
|
||||
})
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert websocket.messages, "expected a status message to be broadcast"
|
||||
assert websocket.messages[-1]["type"] == "status"
|
||||
assert websocket.messages[-1]["status"] == "running"
|
||||
assert websocket.messages[-1]["email_service"] == "tempmail"
|
||||
assert websocket.messages[-1]["phase"] == "redirect_chain"
|
||||
assert websocket.messages[-1]["phase_detail"] == "跟随重定向 1/6"
|
||||
assert websocket.messages[-1]["step_index"] == 14
|
||||
finally:
|
||||
task_manager.unregister_websocket(task_uuid, websocket)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
Reference in New Issue
Block a user