fix(harden): isolate resource cleanup and self-healing flow

This commit is contained in:
Mison
2026-03-23 11:25:51 +08:00
parent cf571d37c1
commit 43149ff079
7 changed files with 1542 additions and 103 deletions

View File

@@ -68,9 +68,10 @@ OPENAI_API_ENDPOINTS = {
"passwordless_send_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
"validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate",
"create_account": "https://auth.openai.com/api/accounts/create_account",
"add_phone" : "https://auth.openai.com/add-phone",
"add_phone": "https://auth.openai.com/add-phone",
"select_workspace": "https://auth.openai.com/api/accounts/workspace/select",
"password_verify" : "https://auth.openai.com/api/accounts/password/verify"
"send_passwordless_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
"password_verify": "https://auth.openai.com/api/accounts/password/verify",
}
# OpenAI 页面类型(用于判断账号状态)

View File

@@ -288,7 +288,7 @@ class OpenAIHTTPClient(HTTPClient):
except Exception as e:
logger.error(f"检查 IP 地理位置失败: {e}")
return False, None
return False, str(e)
def send_openai_request(
self,
@@ -417,4 +417,4 @@ def create_openai_client(
Returns:
OpenAIHTTPClient 实例
"""
return OpenAIHTTPClient(proxy_url, config)
return OpenAIHTTPClient(proxy_url, config)

View File

@@ -6,6 +6,7 @@
import re
import json
import time
import hashlib
import logging
import secrets
import string
@@ -34,6 +35,18 @@ from ..config.settings import get_settings
logger = logging.getLogger(__name__)
WORKSPACE_PROBE_BACKOFF_DELAYS = (1.0, 2.0, 4.0)
WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS = len(WORKSPACE_PROBE_BACKOFF_DELAYS)
PROVIDER_VERIFICATION_TIMEOUTS = {
EmailServiceType.OUTLOOK: 180,
EmailServiceType.IMAP_MAIL: 150,
EmailServiceType.MOE_MAIL: 150,
EmailServiceType.TEMPMAIL: 120,
EmailServiceType.TEMP_MAIL: 120,
EmailServiceType.DUCK_MAIL: 90,
EmailServiceType.FREEMAIL: 120,
}
@dataclass
class RegistrationResult:
@@ -81,6 +94,15 @@ class SignupFormResult:
error_message: str = ""
@dataclass
class OTPValidationResult:
"""OTP 校验结果"""
success: bool
continue_url: str = ""
response_data: Dict[str, Any] = None
error_message: str = ""
class RegistrationEngine:
"""
注册引擎
@@ -132,6 +154,22 @@ class RegistrationEngine:
self.logs: list = []
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录)
self._last_continue_url: Optional[str] = None # 最近一次 OTP/Workspace 返回的 continue_url
self._last_cookie_probe: Optional[Dict[str, Any]] = None # 最近一次 Cookie 探针快照
def close(self):
"""关闭底层 HTTP 资源,避免批量任务长期运行时连接泄漏。"""
try:
self._close_http_client()
except Exception as e:
logger.warning(f"关闭注册引擎 HTTP 客户端失败: {e}")
def _close_http_client(self):
"""关闭底层 HTTP 资源并清理关联会话引用。"""
try:
self.http_client.close()
finally:
self.session = None
def _log(self, message: str, level: str = "info"):
"""记录日志"""
@@ -266,7 +304,7 @@ class RegistrationEngine:
if attempt < max_attempts:
time.sleep(attempt)
self.http_client.close()
self._close_http_client()
self.session = self.http_client.session
return None
@@ -466,51 +504,111 @@ class RegistrationEngine:
self._log(f"发送验证码失败: {e}", 'error')
return False
def _get_verification_code(self) -> Optional[str]:
def _get_verification_code(self, otp_sent_at: Optional[float] = None) -> Optional[str]:
"""获取验证码"""
try:
self._log(f"正在等待邮箱 {self.email} 的验证码...")
effective_otp_sent_at = otp_sent_at if otp_sent_at is not None else self._otp_sent_at
verification_timeout = self._get_provider_verification_timeout()
if effective_otp_sent_at:
self._log(
f"正在等待邮箱 {self.email} 的验证码... "
f"(provider={self.email_service.service_type.value}, timeout={verification_timeout}s, "
f"otp_sent_at={effective_otp_sent_at:.3f})"
)
else:
self._log(
f"正在等待邮箱 {self.email} 的验证码... "
f"(provider={self.email_service.service_type.value}, timeout={verification_timeout}s)"
)
email_id = self.email_info.get("service_id") if self.email_info else None
code = self.email_service.get_verification_code(
email=self.email,
email_id=email_id,
timeout=120,
timeout=verification_timeout,
pattern=OTP_CODE_PATTERN,
otp_sent_at=self._otp_sent_at,
otp_sent_at=effective_otp_sent_at,
)
if code:
self._log(f"成功获取验证码: {code}")
return code
else:
self._log("等待验证码超时", 'error')
self._log("首次等待验证码超时,尝试主动刷新邮箱会话后再重试一次...", "warning")
if not self._refresh_verification_session():
self._log("邮箱会话刷新失败,停止验证码重试", "error")
return None
retry_code = self.email_service.get_verification_code(
email=self.email,
email_id=email_id,
timeout=verification_timeout,
pattern=OTP_CODE_PATTERN,
otp_sent_at=effective_otp_sent_at,
)
if retry_code:
self._log(f"刷新会话后成功获取验证码: {retry_code}")
return retry_code
self._log("等待验证码超时", "error")
return None
except Exception as e:
self._log(f"获取验证码失败: {e}", 'error')
return None
def _validate_verification_code(self, code: str) -> bool:
"""验证验证码"""
def _get_provider_verification_timeout(self) -> int:
"""根据邮箱 provider 推导验证码等待超时。"""
settings = get_settings()
base_timeout = max(int(getattr(settings, "email_code_timeout", 120) or 120), 30)
provider_timeout = PROVIDER_VERIFICATION_TIMEOUTS.get(self.email_service.service_type, base_timeout)
return max(base_timeout, provider_timeout)
def _refresh_verification_session(self) -> bool:
"""主动刷新邮箱服务会话,给验证码二次探测提供新连接状态。"""
try:
code_body = f'{{"code":"{code}"}}'
refresh_handler = getattr(self.email_service, "refresh_session", None)
if callable(refresh_handler):
refresh_handler()
self._log("邮箱服务已执行 refresh_session()")
return True
response = self.session.post(
OPENAI_API_ENDPOINTS["validate_otp"],
headers={
"referer": "https://auth.openai.com/email-verification",
"accept": "application/json",
"content-type": "application/json",
},
data=code_body,
)
refreshed = False
http_client = getattr(self.email_service, "http_client", None)
if http_client:
client_cls = http_client.__class__
proxy_url = getattr(http_client, "proxy_url", None)
config = getattr(http_client, "config", None)
try:
http_client.close()
except Exception:
pass
setattr(self.email_service, "http_client", client_cls(proxy_url=proxy_url, config=config))
refreshed = True
self._log(f"验证码校验状态: {response.status_code}")
return response.status_code == 200
providers = getattr(self.email_service, "_providers", None)
provider_lock = getattr(self.email_service, "_provider_lock", None)
if isinstance(providers, dict):
if provider_lock:
with provider_lock:
providers.clear()
else:
providers.clear()
refreshed = True
reset_provider_health = getattr(self.email_service, "reset_provider_health", None)
if callable(reset_provider_health):
reset_provider_health()
refreshed = True
if refreshed:
self._log("邮箱服务底层会话已主动刷新")
return True
self._log("当前邮箱服务未暴露可刷新的会话句柄", "warning")
return False
except Exception as e:
self._log(f"验证验证码失败: {e}", 'error')
self._log(f"刷新邮箱会话失败: {e}", "warning")
return False
def _create_user_account(self) -> bool:
@@ -542,18 +640,756 @@ class RegistrationEngine:
self._log(f"创建账户失败: {e}", 'error')
return False
def _add_phone(self) -> bool:
"""获取 手机验证码"""
phone_body = f'{{"code":"{code}"}}'
response = self.session.post(
OPENAI_API_ENDPOINTS["add_phone"],
headers={
"referer": "https://auth.openai.com/api/accounts/create_account",
def _fingerprint_cookie_value(self, value: Optional[str]) -> str:
"""生成 Cookie 指纹,便于审计变化而不直接输出完整值"""
normalized = str(value or "").strip()
if not normalized:
return "-"
return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:12]
def _build_cookie_probe(self) -> Dict[str, str]:
"""构建关键 Cookie 状态快照"""
cookies = self.session.cookies
auth_cookie = cookies.get("oai-client-auth-session")
session_cookie = cookies.get("__Secure-next-auth.session-token")
device_cookie = cookies.get("oai-did")
auth_segments = len((auth_cookie or "").split(".")) if auth_cookie else 0
workspace_id = None
workspace_source = "auth_cookie_missing"
if auth_cookie:
workspace_id, workspace_source = self._decode_workspace_id_from_auth_cookie(auth_cookie)
return {
"did": "Y" if device_cookie else "N",
"did_fp": self._fingerprint_cookie_value(device_cookie),
"auth": "Y" if auth_cookie else "N",
"auth_segments": str(auth_segments),
"auth_jwt_state": "complete_jwt" if auth_segments == 3 else "partial_or_empty",
"auth_fp": self._fingerprint_cookie_value(auth_cookie),
"next_auth": "Y" if session_cookie else "N",
"next_auth_fp": self._fingerprint_cookie_value(session_cookie),
"workspace": workspace_id or "-",
"workspace_source": workspace_source,
}
def _log_cookie_state(self, stage: str, reset_baseline: bool = False):
"""记录关键 Cookie 状态与变化,用于观测状态机是否闭环"""
if not self.session:
self._log(f"{stage}: Session 未初始化", "warning")
return
probe = self._build_cookie_probe()
self._log(
f"{stage}: Cookie 探针 did={probe['did']}({probe['did_fp']}), "
f"auth={probe['auth']}(segments={probe['auth_segments']},state={probe['auth_jwt_state']},fp={probe['auth_fp']}), "
f"next_auth={probe['next_auth']}({probe['next_auth_fp']}), "
f"workspace={probe['workspace']}(source={probe['workspace_source']})"
)
previous_probe = None if reset_baseline else self._last_cookie_probe
if previous_probe:
changes = []
tracked_fields = (
("did", "did"),
("did_fp", "did_fp"),
("auth", "auth"),
("auth_segments", "auth_segments"),
("auth_jwt_state", "auth_jwt_state"),
("auth_fp", "auth_fp"),
("next_auth", "next_auth"),
("next_auth_fp", "next_auth_fp"),
("workspace", "workspace"),
("workspace_source", "workspace_source"),
)
for key, label in tracked_fields:
old_value = previous_probe.get(key)
new_value = probe.get(key)
if old_value != new_value:
changes.append(f"{label}:{old_value}->{new_value}")
if changes:
self._log(f"{stage}: Cookie 变化 {'; '.join(changes)}")
else:
self._log(f"{stage}: Cookie 无变化")
else:
baseline_action = "重置" if reset_baseline else "建立"
self._log(f"{stage}: Cookie 探针基线已{baseline_action}")
self._last_cookie_probe = probe
def _decode_workspace_id_from_auth_cookie(self, auth_cookie: str) -> Tuple[Optional[str], str]:
"""从授权 Cookie 中解析 Workspace ID"""
import base64
import json as json_module
cookie_value = str(auth_cookie or "").strip()
if not cookie_value:
return None, "empty_auth_cookie"
segments = cookie_value.split(".")
decoded_payloads = []
for index, payload in enumerate(segments[:3]):
raw_segment = str(payload or "").strip()
if not raw_segment:
continue
try:
pad = "=" * ((4 - (len(raw_segment) % 4)) % 4)
decoded = base64.urlsafe_b64decode((raw_segment + pad).encode("ascii"))
parsed = json_module.loads(decoded.decode("utf-8"))
if isinstance(parsed, dict):
decoded_payloads.append((index, parsed))
except Exception:
continue
if not decoded_payloads:
return None, f"cookie_decode_failed:segments={len(segments)}"
for index, payload in decoded_payloads:
workspaces = payload.get("workspaces") or []
if not workspaces:
continue
workspace_id = str((workspaces[0] or {}).get("id") or "").strip()
if workspace_id:
return workspace_id, f"cookie_segment_{index}"
return None, f"workspace_missing:segments={len(segments)}"
def _try_upgrade_cookie_with_continue_url(self, continue_url: str, stage: str) -> Optional[str]:
"""访问 continue_url驱动授权 Cookie 升级后再次解析 workspace"""
normalized_url = str(continue_url or "").strip()
if not normalized_url:
self._log(f"{stage}: continue_url 为空,无法升级 Cookie", "warning")
return None
self._last_continue_url = normalized_url
for attempt, probe_delay in enumerate(WORKSPACE_PROBE_BACKOFF_DELAYS, start=1):
try:
self._log(
f"{stage}: 第 {attempt}/{WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS} 次访问 continue_url 以升级授权 Cookie"
)
response = self.session.get(
normalized_url,
allow_redirects=True,
timeout=15,
)
self._log(f"{stage}: Continue URL 响应状态 {response.status_code}")
self._log_cookie_state(f"{stage}")
workspace_id = self._get_workspace_id(log_missing=False)
if workspace_id:
self._log(f"{stage}: Continue URL 升级后获取到 Workspace ID")
return workspace_id
self._log(
f"{stage}: 首次探测仍未拿到 Workspace等待 {probe_delay:.1f}s 后执行二次探测..."
)
time.sleep(probe_delay)
workspace_id = self._get_workspace_id(log_missing=False)
if workspace_id:
self._log(f"{stage}: 二次探测后获取到 Workspace ID")
return workspace_id
except Exception as e:
self._log(f"{stage}: Continue URL 访问异常: {e}", "warning")
if attempt < WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS:
self._log(
f"{stage}: 第 {attempt} 次二次探测仍未拿到 Workspace继续下一轮 Continue URL 补偿..."
)
self._log(f"{stage}: Continue URL 已访问但 Workspace 仍不可用", "warning")
return None
def _reset_oauth_session(self) -> bool:
"""重置 OAuth 会话,用于降级登录流程"""
try:
self._log("正在重置 HTTP 会话以开始降级登录流程...")
if self.session:
self._log_cookie_state("降级流程会话重置前")
# 关闭旧客户端
self._close_http_client()
# 创建新客户端和会话
self.http_client = OpenAIHTTPClient(proxy_url=self.proxy_url)
self.session = self.http_client.session
self.session_token = None
self._last_continue_url = None
self._last_cookie_probe = None
# 生成新的 OAuth URL
self.oauth_start = self.oauth_manager.start_oauth()
self._log(f"新 OAuth URL 已生成,准备开始登录握手")
self._log_cookie_state("降级流程会话重置后", reset_baseline=True)
return True
except Exception as e:
self._log(f"重置会话失败: {e}", "error")
return False
def _submit_login_form(self, did: str, sen_token: Optional[str]) -> SignupFormResult:
"""提交登录表单screen_hint=login"""
try:
login_body = f'{{"username":{{"value":"{self.email}","kind":"email"}},"screen_hint":"login"}}'
headers = {
"referer": "https://auth.openai.com/login",
"accept": "application/json",
"content-type": "application/json",
},
data=create_account_body,
}
if sen_token:
sentinel = f'{{"p": "", "t": "", "c": "{sen_token}", "id": "{did}", "flow": "authorize_continue"}}'
headers["openai-sentinel-token"] = sentinel
response = self.session.post(
OPENAI_API_ENDPOINTS["signup"],
headers=headers,
data=login_body,
)
self._log(f"登录表单提交状态: {response.status_code}")
self._log_cookie_state("降级登录提交邮箱后")
if response.status_code != 200:
return SignupFormResult(success=False, error_message=f"登录表单提交失败: HTTP {response.status_code}")
response_data: Dict[str, Any] = {}
page_type = ""
try:
parsed = response.json()
if isinstance(parsed, dict):
response_data = parsed
page_type = str(parsed.get("page", {}).get("type", "") or "").strip()
else:
self._log(f"登录表单响应不是对象: {type(parsed).__name__},按 HTTP 200 继续无密码 OTP", "warning")
except Exception as parse_error:
self._log(f"解析登录表单响应失败: {parse_error},按 HTTP 200 继续无密码 OTP", "warning")
if page_type:
self._log(f"登录表单响应页面类型: {page_type}")
else:
self._log("登录表单响应未返回 page.type按 HTTP 200 继续无密码 OTP", "warning")
if page_type == "login_password":
self._log("降级登录进入 login_password按 Issue #62 继续触发无密码 OTP")
elif page_type == OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"]:
self._log("降级登录进入 email_otp_verification继续执行无密码 OTP 闭环")
elif page_type:
self._log(f"降级登录进入 {page_type},按 Issue #62 不做页面拦截,继续触发无密码 OTP")
return SignupFormResult(
success=True,
page_type=page_type,
response_data=response_data,
)
except Exception as e:
self._log(f"提交登录表单异常: {e}", "error")
return SignupFormResult(success=False, error_message=str(e))
def _send_passwordless_otp(self) -> bool:
"""发送无密码 OTP 验证码"""
try:
self._log("正在触发无密码登录 OTP...")
otp_sent_at = time.time()
self._log_cookie_state("无密码 OTP 发送前")
response = self.session.post(
OPENAI_API_ENDPOINTS["send_passwordless_otp"],
headers={
"referer": "https://auth.openai.com/login/password",
"accept": "application/json",
"content-type": "application/json",
},
data=""
)
self._log(f"无密码 OTP 发送状态: {response.status_code}")
self._log_cookie_state("无密码 OTP 发送后")
if response.status_code != 200:
self._log(f"无密码 OTP 发送失败响应: {response.text[:200]}", "warning")
else:
self._otp_sent_at = otp_sent_at
self._log(f"无密码 OTP 时间戳已更新: {self._otp_sent_at:.3f}")
return response.status_code == 200
except Exception as e:
self._log(f"发送无密码 OTP 失败: {e}", "error")
return False
def _validate_verification_code(self, code: str) -> OTPValidationResult:
"""验证验证码并返回 continue_url 等状态信息"""
try:
code_body = f'{{"code":"{code}"}}'
response = self.session.post(
OPENAI_API_ENDPOINTS["validate_otp"],
headers={
"referer": "https://auth.openai.com/email-verification",
"accept": "application/json",
"content-type": "application/json",
},
data=code_body,
)
self._log(f"验证码校验状态: {response.status_code}")
self._log_cookie_state("OTP 校验响应后")
if response.status_code != 200:
return OTPValidationResult(
success=False,
error_message=f"HTTP {response.status_code}: {response.text[:200]}"
)
try:
parsed = response.json()
if isinstance(parsed, dict):
data = parsed
else:
self._log(
f"验证码校验响应不是对象: {type(parsed).__name__},继续依据 Cookie 状态判断",
"warning"
)
data = {}
except Exception as parse_error:
self._log(
f"解析验证码校验响应失败: {parse_error},继续依据 Cookie 状态判断",
"warning"
)
data = {}
continue_url = str(data.get("continue_url") or "").strip()
if continue_url:
self._last_continue_url = continue_url
self._log(f"验证码校验返回 Continue URL: {continue_url[:60]}...")
else:
self._log("验证码校验成功,但响应中未包含 continue_url", "warning")
return OTPValidationResult(
success=True,
continue_url=continue_url,
response_data=data,
)
except Exception as e:
self._log(f"验证验证码失败: {e}", "error")
return OTPValidationResult(success=False, error_message=str(e))
def _resolve_workspace_id(self) -> Optional[str]:
"""按主闭环优先、降级补偿兜底的顺序解析 Workspace ID"""
workspace_id = self._get_workspace_id(log_missing=False)
if workspace_id:
return workspace_id
if self._last_continue_url:
self._log("主流程首次探测未获取到 Workspace尝试复用最近的 continue_url 升级 Cookie...")
workspace_id = self._try_upgrade_cookie_with_continue_url(
self._last_continue_url,
"主流程 Continue URL 补偿"
)
if workspace_id:
return workspace_id
self._log("主流程 Workspace 探测仍失败,切换到降级登录补偿流程...", "warning")
return self._fallback_to_login_flow()
def _fallback_to_login_flow(self) -> Optional[str]:
"""
[核心修复] 降级到登录流程以获取 Workspace ID
参考 Issue #62 解决方案
"""
self._log("-" * 40)
self._log("检测到 Cookie 缺失 Workspace启动降级登录补偿流程...")
# 1. 重置会话
if not self._reset_oauth_session():
return None
# 2. 获取 Device ID
did = self._get_device_id()
if not did:
self._log("降级登录未能获取 Device ID", "error")
return None
# 3. Sentinel 检查
sen_token = self._check_sentinel(did)
# 4. 提交登录表单
login_res = self._submit_login_form(did, sen_token)
if not login_res.success:
self._log(f"降级登录表单失败: {login_res.error_message}", "error")
return None
# 5. 发送无密码 OTP
if not self._send_passwordless_otp():
self._log("降级登录未能触发无密码 OTP", "error")
return None
login_otp_sent_at = self._otp_sent_at
if not login_otp_sent_at:
login_otp_sent_at = time.time()
self._otp_sent_at = login_otp_sent_at
self._log("降级登录未记录到无密码 OTP 时间戳,使用当前时间兜底", "warning")
self._log(f"降级登录 OTP 时间线已校准: {login_otp_sent_at:.3f}")
# 6. 获取邮件验证码
self._log("正在等待新的登录验证码...")
code = self._get_verification_code(otp_sent_at=login_otp_sent_at)
if not code:
self._log("降级登录未获取到新的 OTP 验证码", "error")
return None
# 7. 验证 OTP -> GET continue_url -> 解析 Workspace
otp_result = self._validate_verification_code(code)
if not otp_result.success:
self._log(f"OTP 校验失败: {otp_result.error_message}", "error")
return None
self._log_cookie_state("降级登录 OTP 校验后")
workspace_id = self._get_workspace_id(log_missing=False)
if workspace_id:
self._log("降级登录 OTP 校验后已直接获得 Workspace无需继续访问 continue_url")
return workspace_id
continue_url = str(otp_result.continue_url or "").strip()
if not continue_url:
self._log("OTP 校验成功但缺少 continue_url无法执行 Cookie 升级", "error")
return None
self._log("降级登录 OTP 校验成功,开始 GET continue_url 升级 Cookie")
workspace_id = self._try_upgrade_cookie_with_continue_url(
continue_url,
"降级登录 Continue URL"
)
if not workspace_id:
self._log("降级登录闭环完成,但仍未拿到 Workspace", "error")
return workspace_id
def _get_workspace_id(self, log_missing: bool = True) -> Optional[str]:
"""获取 Workspace ID"""
try:
self._log_cookie_state("解析 Workspace 前")
auth_cookie = self.session.cookies.get("oai-client-auth-session")
if not auth_cookie:
if log_missing:
self._log("未能获取到授权 Cookie", "error")
return None
try:
workspace_id, source = self._decode_workspace_id_from_auth_cookie(auth_cookie)
if not workspace_id:
if log_missing:
self._log(f"授权 Cookie 中未解析到 workspace: {source}", "warning")
return None
self._log(f"Workspace ID 解析成功: {workspace_id} (source={source})")
return workspace_id
except Exception as e:
if log_missing:
self._log(f"解析授权 Cookie 异常: {e}", "warning")
return None
except Exception as e:
self._log(f"获取 Workspace ID 失败: {e}", "error")
return None
def _select_workspace(self, workspace_id: str) -> Optional[str]:
"""选择 Workspace"""
try:
select_body = f'{{"workspace_id":"{workspace_id}"}}'
response = self.session.post(
OPENAI_API_ENDPOINTS["select_workspace"],
headers={
"referer": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent",
"content-type": "application/json",
},
data=select_body,
)
if response.status_code != 200:
self._log(f"选择 workspace 失败: {response.status_code}", "error")
self._log(f"响应: {response.text[:200]}", "warning")
return None
continue_url = str((response.json() or {}).get("continue_url") or "").strip()
if not continue_url:
self._log("workspace/select 响应里缺少 continue_url", "error")
return None
self._log(f"Continue URL: {continue_url[:100]}...")
return continue_url
except Exception as e:
self._log(f"选择 Workspace 失败: {e}", "error")
return None
def _follow_redirects(self, start_url: str) -> Optional[str]:
"""跟随重定向链,寻找回调 URL"""
try:
current_url = start_url
max_redirects = 6
for i in range(max_redirects):
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
response = self.session.get(
current_url,
allow_redirects=False,
timeout=15
)
location = response.headers.get("Location") or ""
# 如果不是重定向状态码,停止
if response.status_code not in [301, 302, 303, 307, 308]:
self._log(f"非重定向状态码: {response.status_code}")
break
if not location:
self._log("重定向响应缺少 Location 头")
break
# 构建下一个 URL
import urllib.parse
next_url = urllib.parse.urljoin(current_url, location)
# 检查是否包含回调参数
if "code=" in next_url and "state=" in next_url:
self._log(f"找到回调 URL: {next_url[:100]}...")
return next_url
current_url = next_url
self._log("未能在重定向链中找到回调 URL", "error")
return None
except Exception as e:
self._log(f"跟随重定向失败: {e}", "error")
return None
def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]:
"""处理 OAuth 回调"""
try:
if not self.oauth_start:
self._log("OAuth 流程未初始化", "error")
return None
self._log("处理 OAuth 回调...")
token_info = self.oauth_manager.handle_callback(
callback_url=callback_url,
expected_state=self.oauth_start.state,
code_verifier=self.oauth_start.code_verifier
)
self._log("OAuth 授权成功")
return token_info
except Exception as e:
self._log(f"处理 OAuth 回调失败: {e}", "error")
return None
def run(self) -> RegistrationResult:
"""
执行完整的注册流程
支持已注册账号自动登录:
- 如果检测到邮箱已注册,自动切换到登录流程
- 已注册账号跳过:设置密码、发送验证码、创建用户账户
- 共用步骤获取验证码、验证验证码、Workspace 和 OAuth 回调
Returns:
RegistrationResult: 注册结果
"""
result = RegistrationResult(success=False, logs=self.logs)
try:
self._log("=" * 60)
self._log("开始注册流程")
self._log("=" * 60)
# 1. 检查 IP 地理位置
self._log("1. 检查 IP 地理位置...")
ip_ok, location = self._check_ip_location()
if not ip_ok:
result.error_message = f"IP 地理位置不支持: {location}"
self._log(f"IP 检查失败: {location}", "error")
return result
self._log(f"IP 位置: {location}")
# 2. 创建邮箱
self._log("2. 创建邮箱...")
if not self._create_email():
result.error_message = "创建邮箱失败"
return result
result.email = self.email
# 3. 初始化会话
self._log("3. 初始化会话...")
if not self._init_session():
result.error_message = "初始化会话失败"
return result
# 4. 开始 OAuth 流程
self._log("4. 开始 OAuth 授权流程...")
if not self._start_oauth():
result.error_message = "开始 OAuth 流程失败"
return result
# 5. 获取 Device ID
self._log("5. 获取 Device ID...")
did = self._get_device_id()
if not did:
result.error_message = "获取 Device ID 失败"
return result
# 6. 检查 Sentinel 拦截
self._log("6. 检查 Sentinel 拦截...")
sen_token = self._check_sentinel(did)
if sen_token:
self._log("Sentinel 检查通过")
else:
self._log("Sentinel 检查失败或未启用", "warning")
# 7. 提交注册表单 + 解析响应判断账号状态
self._log("7. 提交注册表单...")
signup_result = self._submit_signup_form(did, sen_token)
if not signup_result.success:
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
return result
# 8. [已注册账号跳过] 注册密码
if self._is_existing_account:
self._log("8. [已注册账号] 跳过密码设置OTP 已自动发送")
else:
self._log("8. 注册密码...")
password_ok, password = self._register_password()
if not password_ok:
result.error_message = "注册密码失败"
return result
# 9. [已注册账号跳过] 发送验证码
if self._is_existing_account:
self._log("9. [已注册账号] 跳过发送验证码,使用自动发送的 OTP")
# 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳
self._otp_sent_at = time.time()
else:
self._log("9. 发送验证码...")
if not self._send_verification_code():
result.error_message = "发送验证码失败"
return result
# 10. 获取验证码
self._log("10. 等待验证码...")
code = self._get_verification_code()
if not code:
result.error_message = "获取验证码失败"
return result
# 11. 验证验证码
self._log("11. 验证验证码...")
otp_result = self._validate_verification_code(code)
if not otp_result.success:
result.error_message = f"验证验证码失败: {otp_result.error_message}"
return result
if otp_result.continue_url:
self._try_upgrade_cookie_with_continue_url(
otp_result.continue_url,
"主流程 OTP 校验"
)
# 12. [已注册账号跳过] 创建用户账户
if self._is_existing_account:
self._log("12. [已注册账号] 跳过创建用户账户")
else:
self._log("12. 创建用户账户...")
if not self._create_user_account():
result.error_message = "创建用户账户失败"
return result
# 13. 获取 Workspace ID
self._log("13. 获取 Workspace ID...")
workspace_id = self._resolve_workspace_id()
if not workspace_id:
result.error_message = "获取 Workspace ID 失败 (含降级补偿)"
return result
result.workspace_id = workspace_id
# 14. 选择 Workspace
self._log("14. 选择 Workspace...")
continue_url = self._select_workspace(workspace_id)
if not continue_url:
result.error_message = "选择 Workspace 失败"
return result
# 15. 跟随重定向链
self._log("15. 跟随重定向链...")
callback_url = self._follow_redirects(continue_url)
if not callback_url:
result.error_message = "跟随重定向链失败"
return result
# 16. 处理 OAuth 回调
self._log("16. 处理 OAuth 回调...")
token_info = self._handle_oauth_callback(callback_url)
if not token_info:
result.error_message = "处理 OAuth 回调失败"
return result
# 提取账户信息
result.account_id = token_info.get("account_id", "")
result.access_token = token_info.get("access_token", "")
result.refresh_token = token_info.get("refresh_token", "")
result.id_token = token_info.get("id_token", "")
result.password = self.password or "" # 保存密码(已注册账号为空)
# 设置来源标记
result.source = "login" if self._is_existing_account else "register"
# 尝试获取 session_token 从 cookie
session_cookie = self.session.cookies.get("__Secure-next-auth.session-token")
if session_cookie:
self.session_token = session_cookie
result.session_token = session_cookie
self._log(f"获取到 Session Token")
# 17. 完成
self._log("=" * 60)
if self._is_existing_account:
self._log("登录成功! (已注册账号)")
else:
self._log("注册成功!")
self._log(f"邮箱: {result.email}")
self._log(f"Account ID: {result.account_id}")
self._log(f"Workspace ID: {result.workspace_id}")
self._log("=" * 60)
result.success = True
result.metadata = {
"email_service": self.email_service.service_type.value,
"proxy_used": self.proxy_url,
"registered_at": datetime.now().isoformat(),
"is_existing_account": self._is_existing_account,
}
return result
except Exception as e:
self._log(f"注册过程中发生未预期错误: {e}", "error")
result.error_message = str(e)
return result
finally:
try:
self._close_http_client()
except Exception as e:
logger.warning(f"关闭注册引擎 HTTP 客户端失败: {e}")
def save_to_database(self, result: RegistrationResult) -> bool:
"""

View File

@@ -2,7 +2,7 @@
数据库 CRUD 操作
"""
from typing import List, Optional, Dict, Any, Union
from typing import List, Optional, Dict, Any, Union, Iterable, Set
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, desc, asc, func
@@ -464,9 +464,13 @@ def get_proxies(
return query.all()
def get_enabled_proxies(db: Session) -> List[Proxy]:
def get_enabled_proxies(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> List[Proxy]:
"""获取所有启用的代理"""
return db.query(Proxy).filter(Proxy.enabled == True).all()
query = db.query(Proxy).filter(Proxy.enabled == True)
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
if excluded:
query = query.filter(~Proxy.id.in_(excluded))
return query.all()
def update_proxy(
@@ -517,14 +521,18 @@ def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
return True
def get_random_proxy(db: Session) -> Optional[Proxy]:
def get_random_proxy(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> Optional[Proxy]:
"""随机获取一个启用的代理,优先返回 is_default=True 的代理"""
import random
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
# 优先返回默认代理
default_proxy = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True).first()
default_query = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True)
if excluded:
default_query = default_query.filter(~Proxy.id.in_(excluded))
default_proxy = default_query.first()
if default_proxy:
return default_proxy
proxies = get_enabled_proxies(db)
proxies = get_enabled_proxies(db, exclude_ids=excluded)
if not proxies:
return None
return random.choice(proxies)

View File

@@ -6,6 +6,7 @@ import asyncio
import logging
import uuid
import random
import re
from datetime import datetime
from typing import List, Optional, Dict, Tuple, Any
@@ -32,7 +33,16 @@ batch_tasks: Dict[str, dict] = {}
# ============== Proxy Helper Functions ==============
def get_proxy_for_registration(db) -> Tuple[Optional[str], Optional[int]]:
RETRYABLE_PROXY_ERROR_PATTERN = re.compile(
r"(?:curl(?:[^0-9]{0,8})?(35|56)\b|curl:\s*\((35|56)\))",
re.IGNORECASE,
)
def get_proxy_for_registration(
db,
exclude_proxy_ids: Optional[List[int]] = None,
) -> Tuple[Optional[str], Optional[int]]:
"""
获取用于注册的代理
@@ -45,7 +55,7 @@ def get_proxy_for_registration(db) -> Tuple[Optional[str], Optional[int]]:
Tuple[proxy_url, proxy_id]: 代理 URL 和代理 ID如果来自代理列表
"""
# 先尝试从代理列表中获取
proxy = crud.get_random_proxy(db)
proxy = crud.get_random_proxy(db, exclude_ids=exclude_proxy_ids)
if proxy:
return proxy.proxy_url, proxy.id
@@ -64,6 +74,27 @@ def update_proxy_usage(db, proxy_id: Optional[int]):
crud.update_proxy_last_used(db, proxy_id)
def is_retryable_proxy_error(error_message: Optional[str]) -> bool:
"""判断是否属于可通过切换代理自愈的 curl 网络错误。"""
message = str(error_message or "").strip()
if not message:
return False
return RETRYABLE_PROXY_ERROR_PATTERN.search(message) is not None
def disable_proxy_for_network_error(db, proxy_id: Optional[int], reason: str) -> bool:
"""将当前数据库代理标记为失效,避免后续再次被选中。"""
if not proxy_id:
return False
proxy = crud.update_proxy(db, proxy_id, enabled=False)
if not proxy:
return False
logger.warning(f"代理 {proxy_id} 因网络错误已自动禁用: {reason}")
return True
# ============== Pydantic Models ==============
class RegistrationTaskCreate(BaseModel):
@@ -246,51 +277,44 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
logger.error(f"任务不存在: {task_uuid}")
return
# 确定使用的代理
# 如果前端传入了代理参数,使用传入的
# 否则从代理列表或系统设置中获取
actual_proxy_url = proxy
proxy_id = None
resolved_email_service_id = email_service_id or task.email_service_id
if not actual_proxy_url:
actual_proxy_url, proxy_id = get_proxy_for_registration(db)
if actual_proxy_url:
logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...")
# 更新任务的代理记录
crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url)
# 创建邮箱服务
service_type = EmailServiceType(email_service_type)
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "running")
settings = get_settings()
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
# 优先使用数据库中配置的邮箱服务
if email_service_id:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == email_service_id,
EmailServiceModel.enabled == True
).first()
def build_email_service(active_proxy_url: Optional[str]):
requested_service_type = EmailServiceType(email_service_type)
if db_service:
service_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
# 更新任务关联的邮箱服务
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库邮箱服务: {db_service.name} (ID: {db_service.id}, 类型: {service_type.value})")
else:
raise ValueError(f"邮箱服务不存在或已禁用: {email_service_id}")
else:
# 使用默认配置或传入的配置
if resolved_email_service_id:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == resolved_email_service_id,
EmailServiceModel.enabled == True
).first()
if db_service:
selected_service_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(selected_service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(
f"使用数据库邮箱服务: {db_service.name} "
f"(ID: {db_service.id}, 类型: {selected_service_type.value})"
)
email_service = EmailServiceFactory.create(selected_service_type, config)
return email_service, selected_service_type
raise ValueError(f"邮箱服务不存在或已禁用: {resolved_email_service_id}")
service_type = requested_service_type
if service_type == EmailServiceType.TEMPMAIL:
config = {
"base_url": settings.tempmail_base_url,
"timeout": settings.tempmail_timeout,
"max_retries": settings.tempmail_max_retries,
"proxy_url": actual_proxy_url,
"proxy_url": active_proxy_url,
}
elif service_type == EmailServiceType.MOE_MAIL:
# 检查数据库中是否有可用的自定义域名服务
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "moe_mail",
@@ -298,21 +322,19 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库自定义域名服务: {db_service.name}")
elif settings.custom_domain_base_url and settings.custom_domain_api_key:
config = {
"base_url": settings.custom_domain_base_url,
"api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "",
"proxy_url": actual_proxy_url,
"proxy_url": active_proxy_url,
}
else:
raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置")
elif service_type == EmailServiceType.OUTLOOK:
# 检查数据库中是否有可用的 Outlook 账户
from ...database.models import EmailService as EmailServiceModel, Account
# 获取所有启用的 Outlook 服务
outlook_services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook",
EmailServiceModel.enabled == True
@@ -327,14 +349,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
email = svc.config.get("email") if svc.config else None
if not email:
continue
# 检查是否已在 accounts 表中注册
existing = db.query(Account).filter(Account.email == email).first()
if not existing:
selected_service = svc
logger.info(f"选择未注册的 Outlook 账户: {email}")
break
else:
logger.info(f"跳过已注册的 Outlook 账户: {email}")
logger.info(f"跳过已注册的 Outlook 账户: {email}")
if selected_service and selected_service.config:
config = selected_service.config.copy()
@@ -352,7 +372,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 DuckMail 服务: {db_service.name}")
else:
@@ -366,7 +386,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 Freemail 服务: {db_service.name}")
else:
@@ -380,7 +400,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 IMAP 邮箱服务: {db_service.name}")
else:
@@ -388,23 +408,56 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
else:
config = email_service_config or {}
email_service = EmailServiceFactory.create(service_type, config)
email_service = EmailServiceFactory.create(service_type, config)
return email_service, service_type
# 在 WebSocket 状态里附带邮箱服务类型,前端可同步更新任务卡片
task_manager.update_status(task_uuid, "running", email_service=service_type.value)
requested_proxy = proxy
exhausted_proxy_ids = set()
result = None
active_service_type = EmailServiceType(email_service_type)
# 创建注册引擎 - 使用 TaskManager 的日志回调
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
while True:
actual_proxy_url = requested_proxy
proxy_id = None
engine = LoginEngine(
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
task_uuid=task_uuid
)
if not actual_proxy_url:
actual_proxy_url, proxy_id = get_proxy_for_registration(
db,
exclude_proxy_ids=list(exhausted_proxy_ids),
)
if actual_proxy_url:
logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...")
# 执行注册
result = engine.run()
crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url)
email_service, active_service_type = build_email_service(actual_proxy_url)
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value)
engine = LoginEngine(
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
task_uuid=task_uuid
)
result = engine.run()
if result.success:
break
if is_retryable_proxy_error(result.error_message):
log_callback(f"[代理] 检测到可重试网络错误: {result.error_message}")
if proxy_id and disable_proxy_for_network_error(db, proxy_id, result.error_message):
exhausted_proxy_ids.add(proxy_id)
log_callback(f"[代理] 当前代理已标记失效并从代理池移除: {proxy_id}")
next_proxy_url, next_proxy_id = get_proxy_for_registration(
db,
exclude_proxy_ids=list(exhausted_proxy_ids),
)
if next_proxy_url and (next_proxy_url != actual_proxy_url or next_proxy_id != proxy_id):
requested_proxy = None
log_callback(f"[代理] 切换到新代理后重试注册: {next_proxy_url[:50]}...")
continue
break
if result.success:
# 更新代理使用时间
@@ -506,7 +559,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
completed_at=datetime.utcnow(),
result={
**result.to_dict(),
"email_service": service_type.value,
"email_service": active_service_type.value,
}
)
@@ -515,7 +568,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
task_uuid,
"completed",
email=result.email,
email_service=service_type.value,
email_service=active_service_type.value,
)
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
@@ -533,7 +586,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
task_uuid,
"failed",
error=result.error_message,
email_service=service_type.value,
email_service=active_service_type.value,
)
logger.warning(f"注册任务失败: {task_uuid}, 原因: {result.error_message}")

View File

@@ -0,0 +1,439 @@
from types import SimpleNamespace
import pytest
from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES
from src.core import register
from src.services.base import BaseEmailService
class DummyEmailService(BaseEmailService):
def __init__(self):
super().__init__(EmailServiceType.TEMPMAIL, name="dummy")
def create_email(self, config=None):
return {"email": "tester@example.com", "service_id": "svc-1"}
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
return "123456"
def list_emails(self, **kwargs):
return []
def delete_email(self, email_id):
return True
def check_health(self):
return True
def refresh_session(self):
return None
class FakeResponse:
def __init__(self, status_code=200, payload=None, text=""):
self.status_code = status_code
self._payload = payload if payload is not None else {}
self.text = text
def json(self):
return self._payload
class BrokenJSONResponse(FakeResponse):
def json(self):
raise ValueError("bad json")
class FakeSession:
def __init__(self, post_handler=None, get_handler=None, cookies=None):
self.post_handler = post_handler
self.get_handler = get_handler
self.cookies = cookies or {}
self.post_calls = []
self.get_calls = []
def post(self, url, **kwargs):
self.post_calls.append({"url": url, "kwargs": kwargs})
if self.post_handler is None:
raise AssertionError("unexpected post call")
return self.post_handler(url, **kwargs)
def get(self, url, **kwargs):
self.get_calls.append({"url": url, "kwargs": kwargs})
if self.get_handler is None:
raise AssertionError("unexpected get call")
return self.get_handler(url, **kwargs)
class DummyHTTPClient:
def __init__(self, proxy_url=None):
self.proxy_url = proxy_url
self.session = FakeSession()
self.closed = False
def close(self):
self.closed = True
def post(self, url, **kwargs):
raise AssertionError("unexpected http client post")
class DummyOAuthManager:
def __init__(self, **kwargs):
self.kwargs = kwargs
def start_oauth(self):
return SimpleNamespace(
auth_url="https://auth.example/start",
state="state-1",
code_verifier="verifier-1",
redirect_uri="http://localhost/callback",
)
def make_engine(monkeypatch, email_service=None):
monkeypatch.setattr(
register,
"get_settings",
lambda: SimpleNamespace(
openai_client_id="client-id",
openai_auth_url="https://auth.example/authorize",
openai_token_url="https://auth.example/token",
openai_redirect_uri="http://localhost/callback",
openai_scope="openid email profile offline_access",
),
)
monkeypatch.setattr(register, "OpenAIHTTPClient", DummyHTTPClient)
monkeypatch.setattr(register, "OAuthManager", DummyOAuthManager)
engine = register.RegistrationEngine(email_service or DummyEmailService())
engine.email = "tester@example.com"
engine.email_info = {"email": "tester@example.com", "service_id": "svc-1"}
return engine
@pytest.mark.parametrize(
"page_type",
[
"login_password",
OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"],
"consent_required",
"some_other_page",
],
)
def test_submit_login_form_accepts_any_http_200_page(monkeypatch, page_type):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(
status_code=200,
payload={"page": {"type": page_type}},
)
)
result = engine._submit_login_form("did-1", "sen-1")
assert result.success is True
assert result.page_type == page_type
assert result.error_message == ""
def test_submit_login_form_accepts_http_200_even_when_json_is_invalid(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
)
result = engine._submit_login_form("did-1", "sen-1")
assert result.success is True
assert result.page_type == ""
assert result.response_data == {}
assert result.error_message == ""
def test_send_passwordless_otp_posts_empty_body(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(status_code=200)
)
success = engine._send_passwordless_otp()
assert success is True
assert len(engine.session.post_calls) == 1
call = engine.session.post_calls[0]
assert call["url"] == OPENAI_API_ENDPOINTS["send_passwordless_otp"]
assert call["kwargs"]["data"] == ""
assert engine._otp_sent_at is not None
def test_send_passwordless_otp_does_not_update_timestamp_on_failure(monkeypatch):
engine = make_engine(monkeypatch)
engine._otp_sent_at = 1234.5
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(status_code=500, text="server error")
)
success = engine._send_passwordless_otp()
assert success is False
assert engine._otp_sent_at == 1234.5
def test_get_verification_code_passes_explicit_otp_timestamp(monkeypatch):
captured = {}
class RecordingEmailService(DummyEmailService):
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
captured["email"] = email
captured["email_id"] = email_id
captured["timeout"] = timeout
captured["pattern"] = pattern
captured["otp_sent_at"] = otp_sent_at
return "654321"
engine = make_engine(monkeypatch, email_service=RecordingEmailService())
code = engine._get_verification_code(otp_sent_at=1234.5)
assert code == "654321"
assert captured["email"] == "tester@example.com"
assert captured["email_id"] == "svc-1"
assert captured["timeout"] == 120
assert captured["otp_sent_at"] == 1234.5
def test_validate_verification_code_accepts_http_200_even_when_json_is_invalid(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
)
result = engine._validate_verification_code("123456")
assert result.success is True
assert result.continue_url == ""
assert result.response_data == {}
def test_run_closes_http_client_on_early_failure(monkeypatch):
engine = make_engine(monkeypatch)
tracking_client = DummyHTTPClient()
engine.http_client = tracking_client
monkeypatch.setattr(engine, "_check_ip_location", lambda: (False, None))
result = engine.run()
assert result.success is False
assert tracking_client.closed is True
assert engine.session is None
def test_fallback_to_login_flow_forces_otp_and_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
steps = []
captured = {}
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: steps.append("submit_login_form")
or register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
steps.append("send_passwordless_otp")
engine._otp_sent_at = 4567.89
return True
def fake_get_verification_code(otp_sent_at=None):
steps.append("get_verification_code")
captured["otp_sent_at"] = otp_sent_at
return "123456"
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: steps.append("validate_verification_code")
or register.OTPValidationResult(success=True, continue_url="https://auth.example/continue"),
)
def fake_try_upgrade(continue_url, stage):
steps.append("get_continue_url_and_parse_workspace")
captured["continue_url"] = continue_url
captured["stage"] = stage
return "ws-123"
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fake_try_upgrade)
workspace_id = engine._fallback_to_login_flow()
assert workspace_id == "ws-123"
assert steps == [
"reset_session",
"get_device_id",
"check_sentinel",
"submit_login_form",
"send_passwordless_otp",
"get_verification_code",
"validate_verification_code",
"get_continue_url_and_parse_workspace",
]
assert captured["continue_url"] == "https://auth.example/continue"
assert captured["stage"] == "降级登录 Continue URL"
assert captured["otp_sent_at"] == 4567.89
def test_fallback_to_login_flow_requires_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: True)
monkeypatch.setattr(engine, "_get_device_id", lambda: "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
engine._otp_sent_at = 9876.5
return True
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(engine, "_get_verification_code", lambda otp_sent_at=None: "123456")
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: register.OTPValidationResult(success=True, continue_url=""),
)
def fail_if_called(*args, **kwargs):
raise AssertionError("continue_url 缺失时不应尝试升级 Cookie")
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
assert engine._fallback_to_login_flow() is None
def test_fallback_to_login_flow_accepts_workspace_without_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
steps = []
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: steps.append("submit_login_form")
or register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
steps.append("send_passwordless_otp")
engine._otp_sent_at = 4567.89
return True
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(
engine,
"_get_verification_code",
lambda otp_sent_at=None: steps.append("get_verification_code") or "123456",
)
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: steps.append("validate_verification_code")
or register.OTPValidationResult(success=True, continue_url=""),
)
monkeypatch.setattr(
engine,
"_get_workspace_id",
lambda log_missing=True: steps.append("get_workspace_id") or "ws-cookie",
)
def fail_if_called(*args, **kwargs):
raise AssertionError("已有 workspace 时不应继续访问 continue_url")
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
workspace_id = engine._fallback_to_login_flow()
assert workspace_id == "ws-cookie"
assert steps == [
"reset_session",
"get_device_id",
"check_sentinel",
"submit_login_form",
"send_passwordless_otp",
"get_verification_code",
"validate_verification_code",
"get_workspace_id",
]
def test_get_verification_code_uses_provider_timeout_and_refreshes_once(monkeypatch):
captured = {"calls": [], "refresh_count": 0}
class RefreshableOutlookService(DummyEmailService):
def __init__(self):
super().__init__()
self.service_type = EmailServiceType.OUTLOOK
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
captured["calls"].append(
{
"timeout": timeout,
"otp_sent_at": otp_sent_at,
"email": email,
"email_id": email_id,
}
)
if len(captured["calls"]) == 1:
return None
return "987654"
def refresh_session(self):
captured["refresh_count"] += 1
engine = make_engine(monkeypatch, email_service=RefreshableOutlookService())
code = engine._get_verification_code(otp_sent_at=2468.0)
assert code == "987654"
assert captured["refresh_count"] == 1
assert len(captured["calls"]) == 2
assert captured["calls"][0]["timeout"] == 180
assert captured["calls"][1]["timeout"] == 180
assert captured["calls"][0]["otp_sent_at"] == 2468.0
def test_try_upgrade_cookie_with_continue_url_retries_with_second_probe(monkeypatch):
engine = make_engine(monkeypatch)
sleep_calls = []
workspace_results = iter([None, None, None, None, None, "ws-delayed"])
engine.session = FakeSession(
get_handler=lambda url, **kwargs: FakeResponse(status_code=302),
cookies={},
)
monkeypatch.setattr(engine, "_log_cookie_state", lambda *args, **kwargs: None)
monkeypatch.setattr(engine, "_get_workspace_id", lambda log_missing=False: next(workspace_results))
monkeypatch.setattr(register.time, "sleep", lambda seconds: sleep_calls.append(seconds))
workspace_id = engine._try_upgrade_cookie_with_continue_url(
"https://auth.example/continue",
"降级登录 Continue URL",
)
assert workspace_id == "ws-delayed"
assert len(engine.session.get_calls) == 3
assert sleep_calls == [1.0, 2.0, 4.0]

View File

@@ -0,0 +1,102 @@
from types import SimpleNamespace
from src.database import crud
from src.database.session import DatabaseSessionManager
from src.web.routes import registration
from src.core.register import RegistrationResult
def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch, tmp_path):
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
manager.create_tables()
manager.migrate_tables()
with manager.session_scope() as session:
primary_proxy = crud.create_proxy(
session,
name="primary",
type="http",
host="127.0.0.1",
port=8001,
)
crud.update_proxy(session, primary_proxy.id, is_default=True)
backup_proxy = crud.create_proxy(
session,
name="backup",
type="http",
host="127.0.0.1",
port=8002,
)
email_service = crud.create_email_service(
session,
service_type="tempmail",
name="tempmail-db",
config={"base_url": "https://mail.example/api"},
)
crud.create_registration_task(session, task_uuid="task-proxy-failover")
primary_proxy_id = primary_proxy.id
backup_proxy_id = backup_proxy.id
email_service_id = email_service.id
monkeypatch.setattr(registration, "get_db", manager.session_scope)
monkeypatch.setattr(
registration,
"EmailServiceFactory",
SimpleNamespace(create=lambda service_type, config: SimpleNamespace(service_type=service_type, config=config)),
)
attempted_proxies = []
saved_results = []
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.proxy_url = proxy_url
def run(self):
attempted_proxies.append(self.proxy_url)
if self.proxy_url.endswith(":8001"):
return RegistrationResult(
success=False,
email="proxy@example.com",
error_message="OpenAI 请求失败: curl: (35) TLS handshake failed",
)
return RegistrationResult(
success=True,
email="proxy@example.com",
access_token="access-token",
workspace_id="ws-123",
)
def save_to_database(self, result):
saved_results.append(result.email)
return True
monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine)
registration._run_sync_registration_task(
task_uuid="task-proxy-failover",
email_service_type="tempmail",
proxy=None,
email_service_config=None,
email_service_id=email_service_id,
)
assert attempted_proxies == [
"http://127.0.0.1:8001",
"http://127.0.0.1:8002",
]
assert saved_results == ["proxy@example.com"]
with manager.session_scope() as session:
disabled_primary = crud.get_proxy_by_id(session, primary_proxy_id)
active_backup = crud.get_proxy_by_id(session, backup_proxy_id)
task = crud.get_registration_task_by_uuid(session, "task-proxy-failover")
assert disabled_primary is not None
assert disabled_primary.enabled is False
assert active_backup is not None
assert active_backup.enabled is True
assert task is not None
assert task.status == "completed"
assert task.proxy == "http://127.0.0.1:8002"