mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-06-25 17:24:06 +08:00
fix(harden): isolate resource cleanup and self-healing flow
This commit is contained in:
@@ -68,9 +68,10 @@ OPENAI_API_ENDPOINTS = {
|
||||
"passwordless_send_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
|
||||
"validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate",
|
||||
"create_account": "https://auth.openai.com/api/accounts/create_account",
|
||||
"add_phone" : "https://auth.openai.com/add-phone",
|
||||
"add_phone": "https://auth.openai.com/add-phone",
|
||||
"select_workspace": "https://auth.openai.com/api/accounts/workspace/select",
|
||||
"password_verify" : "https://auth.openai.com/api/accounts/password/verify"
|
||||
"send_passwordless_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
|
||||
"password_verify": "https://auth.openai.com/api/accounts/password/verify",
|
||||
}
|
||||
|
||||
# OpenAI 页面类型(用于判断账号状态)
|
||||
|
||||
@@ -288,7 +288,7 @@ class OpenAIHTTPClient(HTTPClient):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查 IP 地理位置失败: {e}")
|
||||
return False, None
|
||||
return False, str(e)
|
||||
|
||||
def send_openai_request(
|
||||
self,
|
||||
@@ -417,4 +417,4 @@ def create_openai_client(
|
||||
Returns:
|
||||
OpenAIHTTPClient 实例
|
||||
"""
|
||||
return OpenAIHTTPClient(proxy_url, config)
|
||||
return OpenAIHTTPClient(proxy_url, config)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
@@ -34,6 +35,18 @@ from ..config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKSPACE_PROBE_BACKOFF_DELAYS = (1.0, 2.0, 4.0)
|
||||
WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS = len(WORKSPACE_PROBE_BACKOFF_DELAYS)
|
||||
PROVIDER_VERIFICATION_TIMEOUTS = {
|
||||
EmailServiceType.OUTLOOK: 180,
|
||||
EmailServiceType.IMAP_MAIL: 150,
|
||||
EmailServiceType.MOE_MAIL: 150,
|
||||
EmailServiceType.TEMPMAIL: 120,
|
||||
EmailServiceType.TEMP_MAIL: 120,
|
||||
EmailServiceType.DUCK_MAIL: 90,
|
||||
EmailServiceType.FREEMAIL: 120,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationResult:
|
||||
@@ -81,6 +94,15 @@ class SignupFormResult:
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OTPValidationResult:
|
||||
"""OTP 校验结果"""
|
||||
success: bool
|
||||
continue_url: str = ""
|
||||
response_data: Dict[str, Any] = None
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
class RegistrationEngine:
|
||||
"""
|
||||
注册引擎
|
||||
@@ -132,6 +154,22 @@ class RegistrationEngine:
|
||||
self.logs: list = []
|
||||
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
|
||||
self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录)
|
||||
self._last_continue_url: Optional[str] = None # 最近一次 OTP/Workspace 返回的 continue_url
|
||||
self._last_cookie_probe: Optional[Dict[str, Any]] = None # 最近一次 Cookie 探针快照
|
||||
|
||||
def close(self):
|
||||
"""关闭底层 HTTP 资源,避免批量任务长期运行时连接泄漏。"""
|
||||
try:
|
||||
self._close_http_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭注册引擎 HTTP 客户端失败: {e}")
|
||||
|
||||
def _close_http_client(self):
|
||||
"""关闭底层 HTTP 资源并清理关联会话引用。"""
|
||||
try:
|
||||
self.http_client.close()
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
def _log(self, message: str, level: str = "info"):
|
||||
"""记录日志"""
|
||||
@@ -266,7 +304,7 @@ class RegistrationEngine:
|
||||
|
||||
if attempt < max_attempts:
|
||||
time.sleep(attempt)
|
||||
self.http_client.close()
|
||||
self._close_http_client()
|
||||
self.session = self.http_client.session
|
||||
|
||||
return None
|
||||
@@ -466,51 +504,111 @@ class RegistrationEngine:
|
||||
self._log(f"发送验证码失败: {e}", 'error')
|
||||
return False
|
||||
|
||||
def _get_verification_code(self) -> Optional[str]:
|
||||
def _get_verification_code(self, otp_sent_at: Optional[float] = None) -> Optional[str]:
|
||||
"""获取验证码"""
|
||||
try:
|
||||
self._log(f"正在等待邮箱 {self.email} 的验证码...")
|
||||
effective_otp_sent_at = otp_sent_at if otp_sent_at is not None else self._otp_sent_at
|
||||
verification_timeout = self._get_provider_verification_timeout()
|
||||
if effective_otp_sent_at:
|
||||
self._log(
|
||||
f"正在等待邮箱 {self.email} 的验证码... "
|
||||
f"(provider={self.email_service.service_type.value}, timeout={verification_timeout}s, "
|
||||
f"otp_sent_at={effective_otp_sent_at:.3f})"
|
||||
)
|
||||
else:
|
||||
self._log(
|
||||
f"正在等待邮箱 {self.email} 的验证码... "
|
||||
f"(provider={self.email_service.service_type.value}, timeout={verification_timeout}s)"
|
||||
)
|
||||
|
||||
email_id = self.email_info.get("service_id") if self.email_info else None
|
||||
code = self.email_service.get_verification_code(
|
||||
email=self.email,
|
||||
email_id=email_id,
|
||||
timeout=120,
|
||||
timeout=verification_timeout,
|
||||
pattern=OTP_CODE_PATTERN,
|
||||
otp_sent_at=self._otp_sent_at,
|
||||
otp_sent_at=effective_otp_sent_at,
|
||||
)
|
||||
|
||||
if code:
|
||||
self._log(f"成功获取验证码: {code}")
|
||||
return code
|
||||
else:
|
||||
self._log("等待验证码超时", 'error')
|
||||
|
||||
self._log("首次等待验证码超时,尝试主动刷新邮箱会话后再重试一次...", "warning")
|
||||
if not self._refresh_verification_session():
|
||||
self._log("邮箱会话刷新失败,停止验证码重试", "error")
|
||||
return None
|
||||
|
||||
retry_code = self.email_service.get_verification_code(
|
||||
email=self.email,
|
||||
email_id=email_id,
|
||||
timeout=verification_timeout,
|
||||
pattern=OTP_CODE_PATTERN,
|
||||
otp_sent_at=effective_otp_sent_at,
|
||||
)
|
||||
if retry_code:
|
||||
self._log(f"刷新会话后成功获取验证码: {retry_code}")
|
||||
return retry_code
|
||||
|
||||
self._log("等待验证码超时", "error")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"获取验证码失败: {e}", 'error')
|
||||
return None
|
||||
|
||||
def _validate_verification_code(self, code: str) -> bool:
|
||||
"""验证验证码"""
|
||||
def _get_provider_verification_timeout(self) -> int:
|
||||
"""根据邮箱 provider 推导验证码等待超时。"""
|
||||
settings = get_settings()
|
||||
base_timeout = max(int(getattr(settings, "email_code_timeout", 120) or 120), 30)
|
||||
provider_timeout = PROVIDER_VERIFICATION_TIMEOUTS.get(self.email_service.service_type, base_timeout)
|
||||
return max(base_timeout, provider_timeout)
|
||||
|
||||
def _refresh_verification_session(self) -> bool:
|
||||
"""主动刷新邮箱服务会话,给验证码二次探测提供新连接状态。"""
|
||||
try:
|
||||
code_body = f'{{"code":"{code}"}}'
|
||||
refresh_handler = getattr(self.email_service, "refresh_session", None)
|
||||
if callable(refresh_handler):
|
||||
refresh_handler()
|
||||
self._log("邮箱服务已执行 refresh_session()")
|
||||
return True
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["validate_otp"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/email-verification",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=code_body,
|
||||
)
|
||||
refreshed = False
|
||||
http_client = getattr(self.email_service, "http_client", None)
|
||||
if http_client:
|
||||
client_cls = http_client.__class__
|
||||
proxy_url = getattr(http_client, "proxy_url", None)
|
||||
config = getattr(http_client, "config", None)
|
||||
try:
|
||||
http_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
setattr(self.email_service, "http_client", client_cls(proxy_url=proxy_url, config=config))
|
||||
refreshed = True
|
||||
|
||||
self._log(f"验证码校验状态: {response.status_code}")
|
||||
return response.status_code == 200
|
||||
providers = getattr(self.email_service, "_providers", None)
|
||||
provider_lock = getattr(self.email_service, "_provider_lock", None)
|
||||
if isinstance(providers, dict):
|
||||
if provider_lock:
|
||||
with provider_lock:
|
||||
providers.clear()
|
||||
else:
|
||||
providers.clear()
|
||||
refreshed = True
|
||||
|
||||
reset_provider_health = getattr(self.email_service, "reset_provider_health", None)
|
||||
if callable(reset_provider_health):
|
||||
reset_provider_health()
|
||||
refreshed = True
|
||||
|
||||
if refreshed:
|
||||
self._log("邮箱服务底层会话已主动刷新")
|
||||
return True
|
||||
|
||||
self._log("当前邮箱服务未暴露可刷新的会话句柄", "warning")
|
||||
return False
|
||||
except Exception as e:
|
||||
self._log(f"验证验证码失败: {e}", 'error')
|
||||
self._log(f"刷新邮箱会话失败: {e}", "warning")
|
||||
return False
|
||||
|
||||
def _create_user_account(self) -> bool:
|
||||
@@ -542,18 +640,756 @@ class RegistrationEngine:
|
||||
self._log(f"创建账户失败: {e}", 'error')
|
||||
return False
|
||||
|
||||
def _add_phone(self) -> bool:
|
||||
"""获取 手机验证码"""
|
||||
phone_body = f'{{"code":"{code}"}}'
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["add_phone"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/api/accounts/create_account",
|
||||
def _fingerprint_cookie_value(self, value: Optional[str]) -> str:
|
||||
"""生成 Cookie 指纹,便于审计变化而不直接输出完整值"""
|
||||
normalized = str(value or "").strip()
|
||||
if not normalized:
|
||||
return "-"
|
||||
return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
def _build_cookie_probe(self) -> Dict[str, str]:
|
||||
"""构建关键 Cookie 状态快照"""
|
||||
cookies = self.session.cookies
|
||||
auth_cookie = cookies.get("oai-client-auth-session")
|
||||
session_cookie = cookies.get("__Secure-next-auth.session-token")
|
||||
device_cookie = cookies.get("oai-did")
|
||||
auth_segments = len((auth_cookie or "").split(".")) if auth_cookie else 0
|
||||
workspace_id = None
|
||||
workspace_source = "auth_cookie_missing"
|
||||
|
||||
if auth_cookie:
|
||||
workspace_id, workspace_source = self._decode_workspace_id_from_auth_cookie(auth_cookie)
|
||||
|
||||
return {
|
||||
"did": "Y" if device_cookie else "N",
|
||||
"did_fp": self._fingerprint_cookie_value(device_cookie),
|
||||
"auth": "Y" if auth_cookie else "N",
|
||||
"auth_segments": str(auth_segments),
|
||||
"auth_jwt_state": "complete_jwt" if auth_segments == 3 else "partial_or_empty",
|
||||
"auth_fp": self._fingerprint_cookie_value(auth_cookie),
|
||||
"next_auth": "Y" if session_cookie else "N",
|
||||
"next_auth_fp": self._fingerprint_cookie_value(session_cookie),
|
||||
"workspace": workspace_id or "-",
|
||||
"workspace_source": workspace_source,
|
||||
}
|
||||
|
||||
def _log_cookie_state(self, stage: str, reset_baseline: bool = False):
|
||||
"""记录关键 Cookie 状态与变化,用于观测状态机是否闭环"""
|
||||
if not self.session:
|
||||
self._log(f"{stage}: Session 未初始化", "warning")
|
||||
return
|
||||
|
||||
probe = self._build_cookie_probe()
|
||||
|
||||
self._log(
|
||||
f"{stage}: Cookie 探针 did={probe['did']}({probe['did_fp']}), "
|
||||
f"auth={probe['auth']}(segments={probe['auth_segments']},state={probe['auth_jwt_state']},fp={probe['auth_fp']}), "
|
||||
f"next_auth={probe['next_auth']}({probe['next_auth_fp']}), "
|
||||
f"workspace={probe['workspace']}(source={probe['workspace_source']})"
|
||||
)
|
||||
|
||||
previous_probe = None if reset_baseline else self._last_cookie_probe
|
||||
if previous_probe:
|
||||
changes = []
|
||||
tracked_fields = (
|
||||
("did", "did"),
|
||||
("did_fp", "did_fp"),
|
||||
("auth", "auth"),
|
||||
("auth_segments", "auth_segments"),
|
||||
("auth_jwt_state", "auth_jwt_state"),
|
||||
("auth_fp", "auth_fp"),
|
||||
("next_auth", "next_auth"),
|
||||
("next_auth_fp", "next_auth_fp"),
|
||||
("workspace", "workspace"),
|
||||
("workspace_source", "workspace_source"),
|
||||
)
|
||||
|
||||
for key, label in tracked_fields:
|
||||
old_value = previous_probe.get(key)
|
||||
new_value = probe.get(key)
|
||||
if old_value != new_value:
|
||||
changes.append(f"{label}:{old_value}->{new_value}")
|
||||
|
||||
if changes:
|
||||
self._log(f"{stage}: Cookie 变化 {'; '.join(changes)}")
|
||||
else:
|
||||
self._log(f"{stage}: Cookie 无变化")
|
||||
else:
|
||||
baseline_action = "重置" if reset_baseline else "建立"
|
||||
self._log(f"{stage}: Cookie 探针基线已{baseline_action}")
|
||||
|
||||
self._last_cookie_probe = probe
|
||||
|
||||
def _decode_workspace_id_from_auth_cookie(self, auth_cookie: str) -> Tuple[Optional[str], str]:
|
||||
"""从授权 Cookie 中解析 Workspace ID"""
|
||||
import base64
|
||||
import json as json_module
|
||||
|
||||
cookie_value = str(auth_cookie or "").strip()
|
||||
if not cookie_value:
|
||||
return None, "empty_auth_cookie"
|
||||
|
||||
segments = cookie_value.split(".")
|
||||
decoded_payloads = []
|
||||
|
||||
for index, payload in enumerate(segments[:3]):
|
||||
raw_segment = str(payload or "").strip()
|
||||
if not raw_segment:
|
||||
continue
|
||||
|
||||
try:
|
||||
pad = "=" * ((4 - (len(raw_segment) % 4)) % 4)
|
||||
decoded = base64.urlsafe_b64decode((raw_segment + pad).encode("ascii"))
|
||||
parsed = json_module.loads(decoded.decode("utf-8"))
|
||||
if isinstance(parsed, dict):
|
||||
decoded_payloads.append((index, parsed))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not decoded_payloads:
|
||||
return None, f"cookie_decode_failed:segments={len(segments)}"
|
||||
|
||||
for index, payload in decoded_payloads:
|
||||
workspaces = payload.get("workspaces") or []
|
||||
if not workspaces:
|
||||
continue
|
||||
|
||||
workspace_id = str((workspaces[0] or {}).get("id") or "").strip()
|
||||
if workspace_id:
|
||||
return workspace_id, f"cookie_segment_{index}"
|
||||
|
||||
return None, f"workspace_missing:segments={len(segments)}"
|
||||
|
||||
def _try_upgrade_cookie_with_continue_url(self, continue_url: str, stage: str) -> Optional[str]:
|
||||
"""访问 continue_url,驱动授权 Cookie 升级后再次解析 workspace"""
|
||||
normalized_url = str(continue_url or "").strip()
|
||||
if not normalized_url:
|
||||
self._log(f"{stage}: continue_url 为空,无法升级 Cookie", "warning")
|
||||
return None
|
||||
|
||||
self._last_continue_url = normalized_url
|
||||
|
||||
for attempt, probe_delay in enumerate(WORKSPACE_PROBE_BACKOFF_DELAYS, start=1):
|
||||
try:
|
||||
self._log(
|
||||
f"{stage}: 第 {attempt}/{WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS} 次访问 continue_url 以升级授权 Cookie"
|
||||
)
|
||||
response = self.session.get(
|
||||
normalized_url,
|
||||
allow_redirects=True,
|
||||
timeout=15,
|
||||
)
|
||||
self._log(f"{stage}: Continue URL 响应状态 {response.status_code}")
|
||||
self._log_cookie_state(f"{stage} 后")
|
||||
|
||||
workspace_id = self._get_workspace_id(log_missing=False)
|
||||
if workspace_id:
|
||||
self._log(f"{stage}: Continue URL 升级后获取到 Workspace ID")
|
||||
return workspace_id
|
||||
|
||||
self._log(
|
||||
f"{stage}: 首次探测仍未拿到 Workspace,等待 {probe_delay:.1f}s 后执行二次探测..."
|
||||
)
|
||||
time.sleep(probe_delay)
|
||||
workspace_id = self._get_workspace_id(log_missing=False)
|
||||
if workspace_id:
|
||||
self._log(f"{stage}: 二次探测后获取到 Workspace ID")
|
||||
return workspace_id
|
||||
except Exception as e:
|
||||
self._log(f"{stage}: Continue URL 访问异常: {e}", "warning")
|
||||
|
||||
if attempt < WORKSPACE_CONTINUE_URL_MAX_ATTEMPTS:
|
||||
self._log(
|
||||
f"{stage}: 第 {attempt} 次二次探测仍未拿到 Workspace,继续下一轮 Continue URL 补偿..."
|
||||
)
|
||||
|
||||
self._log(f"{stage}: Continue URL 已访问但 Workspace 仍不可用", "warning")
|
||||
return None
|
||||
|
||||
def _reset_oauth_session(self) -> bool:
|
||||
"""重置 OAuth 会话,用于降级登录流程"""
|
||||
try:
|
||||
self._log("正在重置 HTTP 会话以开始降级登录流程...")
|
||||
if self.session:
|
||||
self._log_cookie_state("降级流程会话重置前")
|
||||
# 关闭旧客户端
|
||||
self._close_http_client()
|
||||
# 创建新客户端和会话
|
||||
self.http_client = OpenAIHTTPClient(proxy_url=self.proxy_url)
|
||||
self.session = self.http_client.session
|
||||
self.session_token = None
|
||||
self._last_continue_url = None
|
||||
self._last_cookie_probe = None
|
||||
|
||||
# 生成新的 OAuth URL
|
||||
self.oauth_start = self.oauth_manager.start_oauth()
|
||||
self._log(f"新 OAuth URL 已生成,准备开始登录握手")
|
||||
self._log_cookie_state("降级流程会话重置后", reset_baseline=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log(f"重置会话失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _submit_login_form(self, did: str, sen_token: Optional[str]) -> SignupFormResult:
|
||||
"""提交登录表单(screen_hint=login)"""
|
||||
try:
|
||||
login_body = f'{{"username":{{"value":"{self.email}","kind":"email"}},"screen_hint":"login"}}'
|
||||
|
||||
headers = {
|
||||
"referer": "https://auth.openai.com/login",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=create_account_body,
|
||||
}
|
||||
|
||||
if sen_token:
|
||||
sentinel = f'{{"p": "", "t": "", "c": "{sen_token}", "id": "{did}", "flow": "authorize_continue"}}'
|
||||
headers["openai-sentinel-token"] = sentinel
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["signup"],
|
||||
headers=headers,
|
||||
data=login_body,
|
||||
)
|
||||
|
||||
self._log(f"登录表单提交状态: {response.status_code}")
|
||||
self._log_cookie_state("降级登录提交邮箱后")
|
||||
if response.status_code != 200:
|
||||
return SignupFormResult(success=False, error_message=f"登录表单提交失败: HTTP {response.status_code}")
|
||||
|
||||
response_data: Dict[str, Any] = {}
|
||||
page_type = ""
|
||||
try:
|
||||
parsed = response.json()
|
||||
if isinstance(parsed, dict):
|
||||
response_data = parsed
|
||||
page_type = str(parsed.get("page", {}).get("type", "") or "").strip()
|
||||
else:
|
||||
self._log(f"登录表单响应不是对象: {type(parsed).__name__},按 HTTP 200 继续无密码 OTP", "warning")
|
||||
except Exception as parse_error:
|
||||
self._log(f"解析登录表单响应失败: {parse_error},按 HTTP 200 继续无密码 OTP", "warning")
|
||||
|
||||
if page_type:
|
||||
self._log(f"登录表单响应页面类型: {page_type}")
|
||||
else:
|
||||
self._log("登录表单响应未返回 page.type,按 HTTP 200 继续无密码 OTP", "warning")
|
||||
|
||||
if page_type == "login_password":
|
||||
self._log("降级登录进入 login_password,按 Issue #62 继续触发无密码 OTP")
|
||||
elif page_type == OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"]:
|
||||
self._log("降级登录进入 email_otp_verification,继续执行无密码 OTP 闭环")
|
||||
elif page_type:
|
||||
self._log(f"降级登录进入 {page_type},按 Issue #62 不做页面拦截,继续触发无密码 OTP")
|
||||
|
||||
return SignupFormResult(
|
||||
success=True,
|
||||
page_type=page_type,
|
||||
response_data=response_data,
|
||||
)
|
||||
except Exception as e:
|
||||
self._log(f"提交登录表单异常: {e}", "error")
|
||||
return SignupFormResult(success=False, error_message=str(e))
|
||||
|
||||
def _send_passwordless_otp(self) -> bool:
|
||||
"""发送无密码 OTP 验证码"""
|
||||
try:
|
||||
self._log("正在触发无密码登录 OTP...")
|
||||
otp_sent_at = time.time()
|
||||
self._log_cookie_state("无密码 OTP 发送前")
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["send_passwordless_otp"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/login/password",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=""
|
||||
)
|
||||
|
||||
self._log(f"无密码 OTP 发送状态: {response.status_code}")
|
||||
self._log_cookie_state("无密码 OTP 发送后")
|
||||
if response.status_code != 200:
|
||||
self._log(f"无密码 OTP 发送失败响应: {response.text[:200]}", "warning")
|
||||
else:
|
||||
self._otp_sent_at = otp_sent_at
|
||||
self._log(f"无密码 OTP 时间戳已更新: {self._otp_sent_at:.3f}")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
self._log(f"发送无密码 OTP 失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _validate_verification_code(self, code: str) -> OTPValidationResult:
|
||||
"""验证验证码并返回 continue_url 等状态信息"""
|
||||
try:
|
||||
code_body = f'{{"code":"{code}"}}'
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["validate_otp"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/email-verification",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=code_body,
|
||||
)
|
||||
|
||||
self._log(f"验证码校验状态: {response.status_code}")
|
||||
self._log_cookie_state("OTP 校验响应后")
|
||||
if response.status_code != 200:
|
||||
return OTPValidationResult(
|
||||
success=False,
|
||||
error_message=f"HTTP {response.status_code}: {response.text[:200]}"
|
||||
)
|
||||
|
||||
try:
|
||||
parsed = response.json()
|
||||
if isinstance(parsed, dict):
|
||||
data = parsed
|
||||
else:
|
||||
self._log(
|
||||
f"验证码校验响应不是对象: {type(parsed).__name__},继续依据 Cookie 状态判断",
|
||||
"warning"
|
||||
)
|
||||
data = {}
|
||||
except Exception as parse_error:
|
||||
self._log(
|
||||
f"解析验证码校验响应失败: {parse_error},继续依据 Cookie 状态判断",
|
||||
"warning"
|
||||
)
|
||||
data = {}
|
||||
|
||||
continue_url = str(data.get("continue_url") or "").strip()
|
||||
if continue_url:
|
||||
self._last_continue_url = continue_url
|
||||
self._log(f"验证码校验返回 Continue URL: {continue_url[:60]}...")
|
||||
else:
|
||||
self._log("验证码校验成功,但响应中未包含 continue_url", "warning")
|
||||
|
||||
return OTPValidationResult(
|
||||
success=True,
|
||||
continue_url=continue_url,
|
||||
response_data=data,
|
||||
)
|
||||
except Exception as e:
|
||||
self._log(f"验证验证码失败: {e}", "error")
|
||||
return OTPValidationResult(success=False, error_message=str(e))
|
||||
|
||||
def _resolve_workspace_id(self) -> Optional[str]:
|
||||
"""按主闭环优先、降级补偿兜底的顺序解析 Workspace ID"""
|
||||
workspace_id = self._get_workspace_id(log_missing=False)
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
if self._last_continue_url:
|
||||
self._log("主流程首次探测未获取到 Workspace,尝试复用最近的 continue_url 升级 Cookie...")
|
||||
workspace_id = self._try_upgrade_cookie_with_continue_url(
|
||||
self._last_continue_url,
|
||||
"主流程 Continue URL 补偿"
|
||||
)
|
||||
if workspace_id:
|
||||
return workspace_id
|
||||
|
||||
self._log("主流程 Workspace 探测仍失败,切换到降级登录补偿流程...", "warning")
|
||||
return self._fallback_to_login_flow()
|
||||
|
||||
def _fallback_to_login_flow(self) -> Optional[str]:
|
||||
"""
|
||||
[核心修复] 降级到登录流程以获取 Workspace ID
|
||||
参考 Issue #62 解决方案
|
||||
"""
|
||||
self._log("-" * 40)
|
||||
self._log("检测到 Cookie 缺失 Workspace,启动降级登录补偿流程...")
|
||||
|
||||
# 1. 重置会话
|
||||
if not self._reset_oauth_session():
|
||||
return None
|
||||
|
||||
# 2. 获取 Device ID
|
||||
did = self._get_device_id()
|
||||
if not did:
|
||||
self._log("降级登录未能获取 Device ID", "error")
|
||||
return None
|
||||
|
||||
# 3. Sentinel 检查
|
||||
sen_token = self._check_sentinel(did)
|
||||
|
||||
# 4. 提交登录表单
|
||||
login_res = self._submit_login_form(did, sen_token)
|
||||
if not login_res.success:
|
||||
self._log(f"降级登录表单失败: {login_res.error_message}", "error")
|
||||
return None
|
||||
|
||||
# 5. 发送无密码 OTP
|
||||
if not self._send_passwordless_otp():
|
||||
self._log("降级登录未能触发无密码 OTP", "error")
|
||||
return None
|
||||
|
||||
login_otp_sent_at = self._otp_sent_at
|
||||
if not login_otp_sent_at:
|
||||
login_otp_sent_at = time.time()
|
||||
self._otp_sent_at = login_otp_sent_at
|
||||
self._log("降级登录未记录到无密码 OTP 时间戳,使用当前时间兜底", "warning")
|
||||
|
||||
self._log(f"降级登录 OTP 时间线已校准: {login_otp_sent_at:.3f}")
|
||||
|
||||
# 6. 获取邮件验证码
|
||||
self._log("正在等待新的登录验证码...")
|
||||
code = self._get_verification_code(otp_sent_at=login_otp_sent_at)
|
||||
if not code:
|
||||
self._log("降级登录未获取到新的 OTP 验证码", "error")
|
||||
return None
|
||||
|
||||
# 7. 验证 OTP -> GET continue_url -> 解析 Workspace
|
||||
otp_result = self._validate_verification_code(code)
|
||||
if not otp_result.success:
|
||||
self._log(f"OTP 校验失败: {otp_result.error_message}", "error")
|
||||
return None
|
||||
|
||||
self._log_cookie_state("降级登录 OTP 校验后")
|
||||
|
||||
workspace_id = self._get_workspace_id(log_missing=False)
|
||||
if workspace_id:
|
||||
self._log("降级登录 OTP 校验后已直接获得 Workspace,无需继续访问 continue_url")
|
||||
return workspace_id
|
||||
|
||||
continue_url = str(otp_result.continue_url or "").strip()
|
||||
if not continue_url:
|
||||
self._log("OTP 校验成功但缺少 continue_url,无法执行 Cookie 升级", "error")
|
||||
return None
|
||||
|
||||
self._log("降级登录 OTP 校验成功,开始 GET continue_url 升级 Cookie")
|
||||
workspace_id = self._try_upgrade_cookie_with_continue_url(
|
||||
continue_url,
|
||||
"降级登录 Continue URL"
|
||||
)
|
||||
if not workspace_id:
|
||||
self._log("降级登录闭环完成,但仍未拿到 Workspace", "error")
|
||||
return workspace_id
|
||||
|
||||
def _get_workspace_id(self, log_missing: bool = True) -> Optional[str]:
|
||||
"""获取 Workspace ID"""
|
||||
try:
|
||||
self._log_cookie_state("解析 Workspace 前")
|
||||
auth_cookie = self.session.cookies.get("oai-client-auth-session")
|
||||
if not auth_cookie:
|
||||
if log_missing:
|
||||
self._log("未能获取到授权 Cookie", "error")
|
||||
return None
|
||||
|
||||
try:
|
||||
workspace_id, source = self._decode_workspace_id_from_auth_cookie(auth_cookie)
|
||||
if not workspace_id:
|
||||
if log_missing:
|
||||
self._log(f"授权 Cookie 中未解析到 workspace: {source}", "warning")
|
||||
return None
|
||||
|
||||
self._log(f"Workspace ID 解析成功: {workspace_id} (source={source})")
|
||||
return workspace_id
|
||||
|
||||
except Exception as e:
|
||||
if log_missing:
|
||||
self._log(f"解析授权 Cookie 异常: {e}", "warning")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"获取 Workspace ID 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _select_workspace(self, workspace_id: str) -> Optional[str]:
|
||||
"""选择 Workspace"""
|
||||
try:
|
||||
select_body = f'{{"workspace_id":"{workspace_id}"}}'
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["select_workspace"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=select_body,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self._log(f"选择 workspace 失败: {response.status_code}", "error")
|
||||
self._log(f"响应: {response.text[:200]}", "warning")
|
||||
return None
|
||||
|
||||
continue_url = str((response.json() or {}).get("continue_url") or "").strip()
|
||||
if not continue_url:
|
||||
self._log("workspace/select 响应里缺少 continue_url", "error")
|
||||
return None
|
||||
|
||||
self._log(f"Continue URL: {continue_url[:100]}...")
|
||||
return continue_url
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"选择 Workspace 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _follow_redirects(self, start_url: str) -> Optional[str]:
|
||||
"""跟随重定向链,寻找回调 URL"""
|
||||
try:
|
||||
current_url = start_url
|
||||
max_redirects = 6
|
||||
|
||||
for i in range(max_redirects):
|
||||
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
|
||||
|
||||
response = self.session.get(
|
||||
current_url,
|
||||
allow_redirects=False,
|
||||
timeout=15
|
||||
)
|
||||
|
||||
location = response.headers.get("Location") or ""
|
||||
|
||||
# 如果不是重定向状态码,停止
|
||||
if response.status_code not in [301, 302, 303, 307, 308]:
|
||||
self._log(f"非重定向状态码: {response.status_code}")
|
||||
break
|
||||
|
||||
if not location:
|
||||
self._log("重定向响应缺少 Location 头")
|
||||
break
|
||||
|
||||
# 构建下一个 URL
|
||||
import urllib.parse
|
||||
next_url = urllib.parse.urljoin(current_url, location)
|
||||
|
||||
# 检查是否包含回调参数
|
||||
if "code=" in next_url and "state=" in next_url:
|
||||
self._log(f"找到回调 URL: {next_url[:100]}...")
|
||||
return next_url
|
||||
|
||||
current_url = next_url
|
||||
|
||||
self._log("未能在重定向链中找到回调 URL", "error")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"跟随重定向失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]:
|
||||
"""处理 OAuth 回调"""
|
||||
try:
|
||||
if not self.oauth_start:
|
||||
self._log("OAuth 流程未初始化", "error")
|
||||
return None
|
||||
|
||||
self._log("处理 OAuth 回调...")
|
||||
token_info = self.oauth_manager.handle_callback(
|
||||
callback_url=callback_url,
|
||||
expected_state=self.oauth_start.state,
|
||||
code_verifier=self.oauth_start.code_verifier
|
||||
)
|
||||
|
||||
self._log("OAuth 授权成功")
|
||||
return token_info
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"处理 OAuth 回调失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def run(self) -> RegistrationResult:
|
||||
"""
|
||||
执行完整的注册流程
|
||||
|
||||
支持已注册账号自动登录:
|
||||
- 如果检测到邮箱已注册,自动切换到登录流程
|
||||
- 已注册账号跳过:设置密码、发送验证码、创建用户账户
|
||||
- 共用步骤:获取验证码、验证验证码、Workspace 和 OAuth 回调
|
||||
|
||||
Returns:
|
||||
RegistrationResult: 注册结果
|
||||
"""
|
||||
result = RegistrationResult(success=False, logs=self.logs)
|
||||
|
||||
try:
|
||||
self._log("=" * 60)
|
||||
self._log("开始注册流程")
|
||||
self._log("=" * 60)
|
||||
|
||||
# 1. 检查 IP 地理位置
|
||||
self._log("1. 检查 IP 地理位置...")
|
||||
ip_ok, location = self._check_ip_location()
|
||||
if not ip_ok:
|
||||
result.error_message = f"IP 地理位置不支持: {location}"
|
||||
self._log(f"IP 检查失败: {location}", "error")
|
||||
return result
|
||||
|
||||
self._log(f"IP 位置: {location}")
|
||||
|
||||
# 2. 创建邮箱
|
||||
self._log("2. 创建邮箱...")
|
||||
if not self._create_email():
|
||||
result.error_message = "创建邮箱失败"
|
||||
return result
|
||||
|
||||
result.email = self.email
|
||||
|
||||
# 3. 初始化会话
|
||||
self._log("3. 初始化会话...")
|
||||
if not self._init_session():
|
||||
result.error_message = "初始化会话失败"
|
||||
return result
|
||||
|
||||
# 4. 开始 OAuth 流程
|
||||
self._log("4. 开始 OAuth 授权流程...")
|
||||
if not self._start_oauth():
|
||||
result.error_message = "开始 OAuth 流程失败"
|
||||
return result
|
||||
|
||||
# 5. 获取 Device ID
|
||||
self._log("5. 获取 Device ID...")
|
||||
did = self._get_device_id()
|
||||
if not did:
|
||||
result.error_message = "获取 Device ID 失败"
|
||||
return result
|
||||
|
||||
# 6. 检查 Sentinel 拦截
|
||||
self._log("6. 检查 Sentinel 拦截...")
|
||||
sen_token = self._check_sentinel(did)
|
||||
if sen_token:
|
||||
self._log("Sentinel 检查通过")
|
||||
else:
|
||||
self._log("Sentinel 检查失败或未启用", "warning")
|
||||
|
||||
# 7. 提交注册表单 + 解析响应判断账号状态
|
||||
self._log("7. 提交注册表单...")
|
||||
signup_result = self._submit_signup_form(did, sen_token)
|
||||
if not signup_result.success:
|
||||
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
|
||||
return result
|
||||
|
||||
# 8. [已注册账号跳过] 注册密码
|
||||
if self._is_existing_account:
|
||||
self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送")
|
||||
else:
|
||||
self._log("8. 注册密码...")
|
||||
password_ok, password = self._register_password()
|
||||
if not password_ok:
|
||||
result.error_message = "注册密码失败"
|
||||
return result
|
||||
|
||||
# 9. [已注册账号跳过] 发送验证码
|
||||
if self._is_existing_account:
|
||||
self._log("9. [已注册账号] 跳过发送验证码,使用自动发送的 OTP")
|
||||
# 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳
|
||||
self._otp_sent_at = time.time()
|
||||
else:
|
||||
self._log("9. 发送验证码...")
|
||||
if not self._send_verification_code():
|
||||
result.error_message = "发送验证码失败"
|
||||
return result
|
||||
|
||||
# 10. 获取验证码
|
||||
self._log("10. 等待验证码...")
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
result.error_message = "获取验证码失败"
|
||||
return result
|
||||
|
||||
# 11. 验证验证码
|
||||
self._log("11. 验证验证码...")
|
||||
otp_result = self._validate_verification_code(code)
|
||||
if not otp_result.success:
|
||||
result.error_message = f"验证验证码失败: {otp_result.error_message}"
|
||||
return result
|
||||
|
||||
if otp_result.continue_url:
|
||||
self._try_upgrade_cookie_with_continue_url(
|
||||
otp_result.continue_url,
|
||||
"主流程 OTP 校验"
|
||||
)
|
||||
|
||||
# 12. [已注册账号跳过] 创建用户账户
|
||||
if self._is_existing_account:
|
||||
self._log("12. [已注册账号] 跳过创建用户账户")
|
||||
else:
|
||||
self._log("12. 创建用户账户...")
|
||||
if not self._create_user_account():
|
||||
result.error_message = "创建用户账户失败"
|
||||
return result
|
||||
|
||||
# 13. 获取 Workspace ID
|
||||
self._log("13. 获取 Workspace ID...")
|
||||
workspace_id = self._resolve_workspace_id()
|
||||
|
||||
if not workspace_id:
|
||||
result.error_message = "获取 Workspace ID 失败 (含降级补偿)"
|
||||
return result
|
||||
|
||||
result.workspace_id = workspace_id
|
||||
|
||||
# 14. 选择 Workspace
|
||||
self._log("14. 选择 Workspace...")
|
||||
continue_url = self._select_workspace(workspace_id)
|
||||
if not continue_url:
|
||||
result.error_message = "选择 Workspace 失败"
|
||||
return result
|
||||
|
||||
# 15. 跟随重定向链
|
||||
self._log("15. 跟随重定向链...")
|
||||
callback_url = self._follow_redirects(continue_url)
|
||||
if not callback_url:
|
||||
result.error_message = "跟随重定向链失败"
|
||||
return result
|
||||
|
||||
# 16. 处理 OAuth 回调
|
||||
self._log("16. 处理 OAuth 回调...")
|
||||
token_info = self._handle_oauth_callback(callback_url)
|
||||
if not token_info:
|
||||
result.error_message = "处理 OAuth 回调失败"
|
||||
return result
|
||||
|
||||
# 提取账户信息
|
||||
result.account_id = token_info.get("account_id", "")
|
||||
result.access_token = token_info.get("access_token", "")
|
||||
result.refresh_token = token_info.get("refresh_token", "")
|
||||
result.id_token = token_info.get("id_token", "")
|
||||
result.password = self.password or "" # 保存密码(已注册账号为空)
|
||||
|
||||
# 设置来源标记
|
||||
result.source = "login" if self._is_existing_account else "register"
|
||||
|
||||
# 尝试获取 session_token 从 cookie
|
||||
session_cookie = self.session.cookies.get("__Secure-next-auth.session-token")
|
||||
if session_cookie:
|
||||
self.session_token = session_cookie
|
||||
result.session_token = session_cookie
|
||||
self._log(f"获取到 Session Token")
|
||||
|
||||
# 17. 完成
|
||||
self._log("=" * 60)
|
||||
if self._is_existing_account:
|
||||
self._log("登录成功! (已注册账号)")
|
||||
else:
|
||||
self._log("注册成功!")
|
||||
self._log(f"邮箱: {result.email}")
|
||||
self._log(f"Account ID: {result.account_id}")
|
||||
self._log(f"Workspace ID: {result.workspace_id}")
|
||||
self._log("=" * 60)
|
||||
|
||||
result.success = True
|
||||
result.metadata = {
|
||||
"email_service": self.email_service.service_type.value,
|
||||
"proxy_used": self.proxy_url,
|
||||
"registered_at": datetime.now().isoformat(),
|
||||
"is_existing_account": self._is_existing_account,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
||||
result.error_message = str(e)
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
self._close_http_client()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭注册引擎 HTTP 客户端失败: {e}")
|
||||
|
||||
def save_to_database(self, result: RegistrationResult) -> bool:
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
数据库 CRUD 操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from typing import List, Optional, Dict, Any, Union, Iterable, Set
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, desc, asc, func
|
||||
@@ -464,9 +464,13 @@ def get_proxies(
|
||||
return query.all()
|
||||
|
||||
|
||||
def get_enabled_proxies(db: Session) -> List[Proxy]:
|
||||
def get_enabled_proxies(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> List[Proxy]:
|
||||
"""获取所有启用的代理"""
|
||||
return db.query(Proxy).filter(Proxy.enabled == True).all()
|
||||
query = db.query(Proxy).filter(Proxy.enabled == True)
|
||||
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
|
||||
if excluded:
|
||||
query = query.filter(~Proxy.id.in_(excluded))
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_proxy(
|
||||
@@ -517,14 +521,18 @@ def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_random_proxy(db: Session) -> Optional[Proxy]:
|
||||
def get_random_proxy(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> Optional[Proxy]:
|
||||
"""随机获取一个启用的代理,优先返回 is_default=True 的代理"""
|
||||
import random
|
||||
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
|
||||
# 优先返回默认代理
|
||||
default_proxy = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True).first()
|
||||
default_query = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True)
|
||||
if excluded:
|
||||
default_query = default_query.filter(~Proxy.id.in_(excluded))
|
||||
default_proxy = default_query.first()
|
||||
if default_proxy:
|
||||
return default_proxy
|
||||
proxies = get_enabled_proxies(db)
|
||||
proxies = get_enabled_proxies(db, exclude_ids=excluded)
|
||||
if not proxies:
|
||||
return None
|
||||
return random.choice(proxies)
|
||||
|
||||
@@ -6,6 +6,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import random
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple, Any
|
||||
|
||||
@@ -32,7 +33,16 @@ batch_tasks: Dict[str, dict] = {}
|
||||
|
||||
# ============== Proxy Helper Functions ==============
|
||||
|
||||
def get_proxy_for_registration(db) -> Tuple[Optional[str], Optional[int]]:
|
||||
RETRYABLE_PROXY_ERROR_PATTERN = re.compile(
|
||||
r"(?:curl(?:[^0-9]{0,8})?(35|56)\b|curl:\s*\((35|56)\))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def get_proxy_for_registration(
|
||||
db,
|
||||
exclude_proxy_ids: Optional[List[int]] = None,
|
||||
) -> Tuple[Optional[str], Optional[int]]:
|
||||
"""
|
||||
获取用于注册的代理
|
||||
|
||||
@@ -45,7 +55,7 @@ def get_proxy_for_registration(db) -> Tuple[Optional[str], Optional[int]]:
|
||||
Tuple[proxy_url, proxy_id]: 代理 URL 和代理 ID(如果来自代理列表)
|
||||
"""
|
||||
# 先尝试从代理列表中获取
|
||||
proxy = crud.get_random_proxy(db)
|
||||
proxy = crud.get_random_proxy(db, exclude_ids=exclude_proxy_ids)
|
||||
if proxy:
|
||||
return proxy.proxy_url, proxy.id
|
||||
|
||||
@@ -64,6 +74,27 @@ def update_proxy_usage(db, proxy_id: Optional[int]):
|
||||
crud.update_proxy_last_used(db, proxy_id)
|
||||
|
||||
|
||||
def is_retryable_proxy_error(error_message: Optional[str]) -> bool:
|
||||
"""判断是否属于可通过切换代理自愈的 curl 网络错误。"""
|
||||
message = str(error_message or "").strip()
|
||||
if not message:
|
||||
return False
|
||||
return RETRYABLE_PROXY_ERROR_PATTERN.search(message) is not None
|
||||
|
||||
|
||||
def disable_proxy_for_network_error(db, proxy_id: Optional[int], reason: str) -> bool:
|
||||
"""将当前数据库代理标记为失效,避免后续再次被选中。"""
|
||||
if not proxy_id:
|
||||
return False
|
||||
|
||||
proxy = crud.update_proxy(db, proxy_id, enabled=False)
|
||||
if not proxy:
|
||||
return False
|
||||
|
||||
logger.warning(f"代理 {proxy_id} 因网络错误已自动禁用: {reason}")
|
||||
return True
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class RegistrationTaskCreate(BaseModel):
|
||||
@@ -246,51 +277,44 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
logger.error(f"任务不存在: {task_uuid}")
|
||||
return
|
||||
|
||||
# 确定使用的代理
|
||||
# 如果前端传入了代理参数,使用传入的
|
||||
# 否则从代理列表或系统设置中获取
|
||||
actual_proxy_url = proxy
|
||||
proxy_id = None
|
||||
resolved_email_service_id = email_service_id or task.email_service_id
|
||||
|
||||
if not actual_proxy_url:
|
||||
actual_proxy_url, proxy_id = get_proxy_for_registration(db)
|
||||
if actual_proxy_url:
|
||||
logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...")
|
||||
|
||||
# 更新任务的代理记录
|
||||
crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url)
|
||||
|
||||
# 创建邮箱服务
|
||||
service_type = EmailServiceType(email_service_type)
|
||||
# 更新 TaskManager 状态
|
||||
task_manager.update_status(task_uuid, "running")
|
||||
settings = get_settings()
|
||||
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
|
||||
|
||||
# 优先使用数据库中配置的邮箱服务
|
||||
if email_service_id:
|
||||
from ...database.models import EmailService as EmailServiceModel
|
||||
db_service = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.id == email_service_id,
|
||||
EmailServiceModel.enabled == True
|
||||
).first()
|
||||
def build_email_service(active_proxy_url: Optional[str]):
|
||||
requested_service_type = EmailServiceType(email_service_type)
|
||||
|
||||
if db_service:
|
||||
service_type = EmailServiceType(db_service.service_type)
|
||||
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
|
||||
# 更新任务关联的邮箱服务
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
|
||||
logger.info(f"使用数据库邮箱服务: {db_service.name} (ID: {db_service.id}, 类型: {service_type.value})")
|
||||
else:
|
||||
raise ValueError(f"邮箱服务不存在或已禁用: {email_service_id}")
|
||||
else:
|
||||
# 使用默认配置或传入的配置
|
||||
if resolved_email_service_id:
|
||||
from ...database.models import EmailService as EmailServiceModel
|
||||
db_service = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.id == resolved_email_service_id,
|
||||
EmailServiceModel.enabled == True
|
||||
).first()
|
||||
|
||||
if db_service:
|
||||
selected_service_type = EmailServiceType(db_service.service_type)
|
||||
config = _normalize_email_service_config(selected_service_type, db_service.config, active_proxy_url)
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
|
||||
logger.info(
|
||||
f"使用数据库邮箱服务: {db_service.name} "
|
||||
f"(ID: {db_service.id}, 类型: {selected_service_type.value})"
|
||||
)
|
||||
email_service = EmailServiceFactory.create(selected_service_type, config)
|
||||
return email_service, selected_service_type
|
||||
raise ValueError(f"邮箱服务不存在或已禁用: {resolved_email_service_id}")
|
||||
|
||||
service_type = requested_service_type
|
||||
if service_type == EmailServiceType.TEMPMAIL:
|
||||
config = {
|
||||
"base_url": settings.tempmail_base_url,
|
||||
"timeout": settings.tempmail_timeout,
|
||||
"max_retries": settings.tempmail_max_retries,
|
||||
"proxy_url": actual_proxy_url,
|
||||
"proxy_url": active_proxy_url,
|
||||
}
|
||||
elif service_type == EmailServiceType.MOE_MAIL:
|
||||
# 检查数据库中是否有可用的自定义域名服务
|
||||
from ...database.models import EmailService as EmailServiceModel
|
||||
db_service = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.service_type == "moe_mail",
|
||||
@@ -298,21 +322,19 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
).order_by(EmailServiceModel.priority.asc()).first()
|
||||
|
||||
if db_service and db_service.config:
|
||||
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
|
||||
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
|
||||
logger.info(f"使用数据库自定义域名服务: {db_service.name}")
|
||||
elif settings.custom_domain_base_url and settings.custom_domain_api_key:
|
||||
config = {
|
||||
"base_url": settings.custom_domain_base_url,
|
||||
"api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "",
|
||||
"proxy_url": actual_proxy_url,
|
||||
"proxy_url": active_proxy_url,
|
||||
}
|
||||
else:
|
||||
raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置")
|
||||
elif service_type == EmailServiceType.OUTLOOK:
|
||||
# 检查数据库中是否有可用的 Outlook 账户
|
||||
from ...database.models import EmailService as EmailServiceModel, Account
|
||||
# 获取所有启用的 Outlook 服务
|
||||
outlook_services = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.service_type == "outlook",
|
||||
EmailServiceModel.enabled == True
|
||||
@@ -327,14 +349,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
email = svc.config.get("email") if svc.config else None
|
||||
if not email:
|
||||
continue
|
||||
# 检查是否已在 accounts 表中注册
|
||||
existing = db.query(Account).filter(Account.email == email).first()
|
||||
if not existing:
|
||||
selected_service = svc
|
||||
logger.info(f"选择未注册的 Outlook 账户: {email}")
|
||||
break
|
||||
else:
|
||||
logger.info(f"跳过已注册的 Outlook 账户: {email}")
|
||||
logger.info(f"跳过已注册的 Outlook 账户: {email}")
|
||||
|
||||
if selected_service and selected_service.config:
|
||||
config = selected_service.config.copy()
|
||||
@@ -352,7 +372,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
).order_by(EmailServiceModel.priority.asc()).first()
|
||||
|
||||
if db_service and db_service.config:
|
||||
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
|
||||
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
|
||||
logger.info(f"使用数据库 DuckMail 服务: {db_service.name}")
|
||||
else:
|
||||
@@ -366,7 +386,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
).order_by(EmailServiceModel.priority.asc()).first()
|
||||
|
||||
if db_service and db_service.config:
|
||||
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
|
||||
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
|
||||
logger.info(f"使用数据库 Freemail 服务: {db_service.name}")
|
||||
else:
|
||||
@@ -380,7 +400,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
).order_by(EmailServiceModel.priority.asc()).first()
|
||||
|
||||
if db_service and db_service.config:
|
||||
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
|
||||
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
|
||||
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
|
||||
logger.info(f"使用数据库 IMAP 邮箱服务: {db_service.name}")
|
||||
else:
|
||||
@@ -388,23 +408,56 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
else:
|
||||
config = email_service_config or {}
|
||||
|
||||
email_service = EmailServiceFactory.create(service_type, config)
|
||||
email_service = EmailServiceFactory.create(service_type, config)
|
||||
return email_service, service_type
|
||||
|
||||
# 在 WebSocket 状态里附带邮箱服务类型,前端可同步更新任务卡片
|
||||
task_manager.update_status(task_uuid, "running", email_service=service_type.value)
|
||||
requested_proxy = proxy
|
||||
exhausted_proxy_ids = set()
|
||||
result = None
|
||||
active_service_type = EmailServiceType(email_service_type)
|
||||
|
||||
# 创建注册引擎 - 使用 TaskManager 的日志回调
|
||||
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
|
||||
while True:
|
||||
actual_proxy_url = requested_proxy
|
||||
proxy_id = None
|
||||
|
||||
engine = LoginEngine(
|
||||
email_service=email_service,
|
||||
proxy_url=actual_proxy_url,
|
||||
callback_logger=log_callback,
|
||||
task_uuid=task_uuid
|
||||
)
|
||||
if not actual_proxy_url:
|
||||
actual_proxy_url, proxy_id = get_proxy_for_registration(
|
||||
db,
|
||||
exclude_proxy_ids=list(exhausted_proxy_ids),
|
||||
)
|
||||
if actual_proxy_url:
|
||||
logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...")
|
||||
|
||||
# 执行注册
|
||||
result = engine.run()
|
||||
crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url)
|
||||
email_service, active_service_type = build_email_service(actual_proxy_url)
|
||||
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value)
|
||||
engine = LoginEngine(
|
||||
email_service=email_service,
|
||||
proxy_url=actual_proxy_url,
|
||||
callback_logger=log_callback,
|
||||
task_uuid=task_uuid
|
||||
)
|
||||
|
||||
result = engine.run()
|
||||
if result.success:
|
||||
break
|
||||
|
||||
if is_retryable_proxy_error(result.error_message):
|
||||
log_callback(f"[代理] 检测到可重试网络错误: {result.error_message}")
|
||||
if proxy_id and disable_proxy_for_network_error(db, proxy_id, result.error_message):
|
||||
exhausted_proxy_ids.add(proxy_id)
|
||||
log_callback(f"[代理] 当前代理已标记失效并从代理池移除: {proxy_id}")
|
||||
|
||||
next_proxy_url, next_proxy_id = get_proxy_for_registration(
|
||||
db,
|
||||
exclude_proxy_ids=list(exhausted_proxy_ids),
|
||||
)
|
||||
if next_proxy_url and (next_proxy_url != actual_proxy_url or next_proxy_id != proxy_id):
|
||||
requested_proxy = None
|
||||
log_callback(f"[代理] 切换到新代理后重试注册: {next_proxy_url[:50]}...")
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
if result.success:
|
||||
# 更新代理使用时间
|
||||
@@ -506,7 +559,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
completed_at=datetime.utcnow(),
|
||||
result={
|
||||
**result.to_dict(),
|
||||
"email_service": service_type.value,
|
||||
"email_service": active_service_type.value,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -515,7 +568,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
task_uuid,
|
||||
"completed",
|
||||
email=result.email,
|
||||
email_service=service_type.value,
|
||||
email_service=active_service_type.value,
|
||||
)
|
||||
|
||||
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
|
||||
@@ -533,7 +586,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
task_uuid,
|
||||
"failed",
|
||||
error=result.error_message,
|
||||
email_service=service_type.value,
|
||||
email_service=active_service_type.value,
|
||||
)
|
||||
|
||||
logger.warning(f"注册任务失败: {task_uuid}, 原因: {result.error_message}")
|
||||
|
||||
439
tests/test_register_fallback_flow.py
Normal file
439
tests/test_register_fallback_flow.py
Normal file
@@ -0,0 +1,439 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES
|
||||
from src.core import register
|
||||
from src.services.base import BaseEmailService
|
||||
|
||||
|
||||
class DummyEmailService(BaseEmailService):
|
||||
def __init__(self):
|
||||
super().__init__(EmailServiceType.TEMPMAIL, name="dummy")
|
||||
|
||||
def create_email(self, config=None):
|
||||
return {"email": "tester@example.com", "service_id": "svc-1"}
|
||||
|
||||
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
|
||||
return "123456"
|
||||
|
||||
def list_emails(self, **kwargs):
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id):
|
||||
return True
|
||||
|
||||
def check_health(self):
|
||||
return True
|
||||
|
||||
def refresh_session(self):
|
||||
return None
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code=200, payload=None, text=""):
|
||||
self.status_code = status_code
|
||||
self._payload = payload if payload is not None else {}
|
||||
self.text = text
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class BrokenJSONResponse(FakeResponse):
|
||||
def json(self):
|
||||
raise ValueError("bad json")
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, post_handler=None, get_handler=None, cookies=None):
|
||||
self.post_handler = post_handler
|
||||
self.get_handler = get_handler
|
||||
self.cookies = cookies or {}
|
||||
self.post_calls = []
|
||||
self.get_calls = []
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append({"url": url, "kwargs": kwargs})
|
||||
if self.post_handler is None:
|
||||
raise AssertionError("unexpected post call")
|
||||
return self.post_handler(url, **kwargs)
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
self.get_calls.append({"url": url, "kwargs": kwargs})
|
||||
if self.get_handler is None:
|
||||
raise AssertionError("unexpected get call")
|
||||
return self.get_handler(url, **kwargs)
|
||||
|
||||
|
||||
class DummyHTTPClient:
|
||||
def __init__(self, proxy_url=None):
|
||||
self.proxy_url = proxy_url
|
||||
self.session = FakeSession()
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
raise AssertionError("unexpected http client post")
|
||||
|
||||
|
||||
class DummyOAuthManager:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def start_oauth(self):
|
||||
return SimpleNamespace(
|
||||
auth_url="https://auth.example/start",
|
||||
state="state-1",
|
||||
code_verifier="verifier-1",
|
||||
redirect_uri="http://localhost/callback",
|
||||
)
|
||||
|
||||
|
||||
def make_engine(monkeypatch, email_service=None):
|
||||
monkeypatch.setattr(
|
||||
register,
|
||||
"get_settings",
|
||||
lambda: SimpleNamespace(
|
||||
openai_client_id="client-id",
|
||||
openai_auth_url="https://auth.example/authorize",
|
||||
openai_token_url="https://auth.example/token",
|
||||
openai_redirect_uri="http://localhost/callback",
|
||||
openai_scope="openid email profile offline_access",
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(register, "OpenAIHTTPClient", DummyHTTPClient)
|
||||
monkeypatch.setattr(register, "OAuthManager", DummyOAuthManager)
|
||||
|
||||
engine = register.RegistrationEngine(email_service or DummyEmailService())
|
||||
engine.email = "tester@example.com"
|
||||
engine.email_info = {"email": "tester@example.com", "service_id": "svc-1"}
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"page_type",
|
||||
[
|
||||
"login_password",
|
||||
OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"],
|
||||
"consent_required",
|
||||
"some_other_page",
|
||||
],
|
||||
)
|
||||
def test_submit_login_form_accepts_any_http_200_page(monkeypatch, page_type):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: FakeResponse(
|
||||
status_code=200,
|
||||
payload={"page": {"type": page_type}},
|
||||
)
|
||||
)
|
||||
|
||||
result = engine._submit_login_form("did-1", "sen-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.page_type == page_type
|
||||
assert result.error_message == ""
|
||||
|
||||
|
||||
def test_submit_login_form_accepts_http_200_even_when_json_is_invalid(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
|
||||
)
|
||||
|
||||
result = engine._submit_login_form("did-1", "sen-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.page_type == ""
|
||||
assert result.response_data == {}
|
||||
assert result.error_message == ""
|
||||
|
||||
|
||||
def test_send_passwordless_otp_posts_empty_body(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: FakeResponse(status_code=200)
|
||||
)
|
||||
|
||||
success = engine._send_passwordless_otp()
|
||||
|
||||
assert success is True
|
||||
assert len(engine.session.post_calls) == 1
|
||||
call = engine.session.post_calls[0]
|
||||
assert call["url"] == OPENAI_API_ENDPOINTS["send_passwordless_otp"]
|
||||
assert call["kwargs"]["data"] == ""
|
||||
assert engine._otp_sent_at is not None
|
||||
|
||||
|
||||
def test_send_passwordless_otp_does_not_update_timestamp_on_failure(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine._otp_sent_at = 1234.5
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: FakeResponse(status_code=500, text="server error")
|
||||
)
|
||||
|
||||
success = engine._send_passwordless_otp()
|
||||
|
||||
assert success is False
|
||||
assert engine._otp_sent_at == 1234.5
|
||||
|
||||
|
||||
def test_get_verification_code_passes_explicit_otp_timestamp(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class RecordingEmailService(DummyEmailService):
|
||||
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
|
||||
captured["email"] = email
|
||||
captured["email_id"] = email_id
|
||||
captured["timeout"] = timeout
|
||||
captured["pattern"] = pattern
|
||||
captured["otp_sent_at"] = otp_sent_at
|
||||
return "654321"
|
||||
|
||||
engine = make_engine(monkeypatch, email_service=RecordingEmailService())
|
||||
|
||||
code = engine._get_verification_code(otp_sent_at=1234.5)
|
||||
|
||||
assert code == "654321"
|
||||
assert captured["email"] == "tester@example.com"
|
||||
assert captured["email_id"] == "svc-1"
|
||||
assert captured["timeout"] == 120
|
||||
assert captured["otp_sent_at"] == 1234.5
|
||||
|
||||
|
||||
def test_validate_verification_code_accepts_http_200_even_when_json_is_invalid(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
|
||||
)
|
||||
|
||||
result = engine._validate_verification_code("123456")
|
||||
|
||||
assert result.success is True
|
||||
assert result.continue_url == ""
|
||||
assert result.response_data == {}
|
||||
|
||||
|
||||
def test_run_closes_http_client_on_early_failure(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
tracking_client = DummyHTTPClient()
|
||||
engine.http_client = tracking_client
|
||||
|
||||
monkeypatch.setattr(engine, "_check_ip_location", lambda: (False, None))
|
||||
|
||||
result = engine.run()
|
||||
|
||||
assert result.success is False
|
||||
assert tracking_client.closed is True
|
||||
assert engine.session is None
|
||||
|
||||
|
||||
def test_fallback_to_login_flow_forces_otp_and_continue_url(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
steps = []
|
||||
captured = {}
|
||||
|
||||
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
|
||||
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_form",
|
||||
lambda did, sen: steps.append("submit_login_form")
|
||||
or register.SignupFormResult(success=True, page_type="login_password"),
|
||||
)
|
||||
|
||||
def fake_send_passwordless_otp():
|
||||
steps.append("send_passwordless_otp")
|
||||
engine._otp_sent_at = 4567.89
|
||||
return True
|
||||
|
||||
def fake_get_verification_code(otp_sent_at=None):
|
||||
steps.append("get_verification_code")
|
||||
captured["otp_sent_at"] = otp_sent_at
|
||||
return "123456"
|
||||
|
||||
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
|
||||
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_validate_verification_code",
|
||||
lambda code: steps.append("validate_verification_code")
|
||||
or register.OTPValidationResult(success=True, continue_url="https://auth.example/continue"),
|
||||
)
|
||||
|
||||
def fake_try_upgrade(continue_url, stage):
|
||||
steps.append("get_continue_url_and_parse_workspace")
|
||||
captured["continue_url"] = continue_url
|
||||
captured["stage"] = stage
|
||||
return "ws-123"
|
||||
|
||||
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fake_try_upgrade)
|
||||
|
||||
workspace_id = engine._fallback_to_login_flow()
|
||||
|
||||
assert workspace_id == "ws-123"
|
||||
assert steps == [
|
||||
"reset_session",
|
||||
"get_device_id",
|
||||
"check_sentinel",
|
||||
"submit_login_form",
|
||||
"send_passwordless_otp",
|
||||
"get_verification_code",
|
||||
"validate_verification_code",
|
||||
"get_continue_url_and_parse_workspace",
|
||||
]
|
||||
assert captured["continue_url"] == "https://auth.example/continue"
|
||||
assert captured["stage"] == "降级登录 Continue URL"
|
||||
assert captured["otp_sent_at"] == 4567.89
|
||||
|
||||
|
||||
def test_fallback_to_login_flow_requires_continue_url(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: "did-1")
|
||||
monkeypatch.setattr(engine, "_check_sentinel", lambda did: "sen-1")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_form",
|
||||
lambda did, sen: register.SignupFormResult(success=True, page_type="login_password"),
|
||||
)
|
||||
|
||||
def fake_send_passwordless_otp():
|
||||
engine._otp_sent_at = 9876.5
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
|
||||
monkeypatch.setattr(engine, "_get_verification_code", lambda otp_sent_at=None: "123456")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_validate_verification_code",
|
||||
lambda code: register.OTPValidationResult(success=True, continue_url=""),
|
||||
)
|
||||
|
||||
def fail_if_called(*args, **kwargs):
|
||||
raise AssertionError("continue_url 缺失时不应尝试升级 Cookie")
|
||||
|
||||
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
|
||||
|
||||
assert engine._fallback_to_login_flow() is None
|
||||
|
||||
|
||||
def test_fallback_to_login_flow_accepts_workspace_without_continue_url(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
steps = []
|
||||
|
||||
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
|
||||
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_form",
|
||||
lambda did, sen: steps.append("submit_login_form")
|
||||
or register.SignupFormResult(success=True, page_type="login_password"),
|
||||
)
|
||||
|
||||
def fake_send_passwordless_otp():
|
||||
steps.append("send_passwordless_otp")
|
||||
engine._otp_sent_at = 4567.89
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_get_verification_code",
|
||||
lambda otp_sent_at=None: steps.append("get_verification_code") or "123456",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_validate_verification_code",
|
||||
lambda code: steps.append("validate_verification_code")
|
||||
or register.OTPValidationResult(success=True, continue_url=""),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_get_workspace_id",
|
||||
lambda log_missing=True: steps.append("get_workspace_id") or "ws-cookie",
|
||||
)
|
||||
|
||||
def fail_if_called(*args, **kwargs):
|
||||
raise AssertionError("已有 workspace 时不应继续访问 continue_url")
|
||||
|
||||
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
|
||||
|
||||
workspace_id = engine._fallback_to_login_flow()
|
||||
|
||||
assert workspace_id == "ws-cookie"
|
||||
assert steps == [
|
||||
"reset_session",
|
||||
"get_device_id",
|
||||
"check_sentinel",
|
||||
"submit_login_form",
|
||||
"send_passwordless_otp",
|
||||
"get_verification_code",
|
||||
"validate_verification_code",
|
||||
"get_workspace_id",
|
||||
]
|
||||
|
||||
|
||||
def test_get_verification_code_uses_provider_timeout_and_refreshes_once(monkeypatch):
|
||||
captured = {"calls": [], "refresh_count": 0}
|
||||
|
||||
class RefreshableOutlookService(DummyEmailService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.service_type = EmailServiceType.OUTLOOK
|
||||
|
||||
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
|
||||
captured["calls"].append(
|
||||
{
|
||||
"timeout": timeout,
|
||||
"otp_sent_at": otp_sent_at,
|
||||
"email": email,
|
||||
"email_id": email_id,
|
||||
}
|
||||
)
|
||||
if len(captured["calls"]) == 1:
|
||||
return None
|
||||
return "987654"
|
||||
|
||||
def refresh_session(self):
|
||||
captured["refresh_count"] += 1
|
||||
|
||||
engine = make_engine(monkeypatch, email_service=RefreshableOutlookService())
|
||||
|
||||
code = engine._get_verification_code(otp_sent_at=2468.0)
|
||||
|
||||
assert code == "987654"
|
||||
assert captured["refresh_count"] == 1
|
||||
assert len(captured["calls"]) == 2
|
||||
assert captured["calls"][0]["timeout"] == 180
|
||||
assert captured["calls"][1]["timeout"] == 180
|
||||
assert captured["calls"][0]["otp_sent_at"] == 2468.0
|
||||
|
||||
|
||||
def test_try_upgrade_cookie_with_continue_url_retries_with_second_probe(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
sleep_calls = []
|
||||
workspace_results = iter([None, None, None, None, None, "ws-delayed"])
|
||||
engine.session = FakeSession(
|
||||
get_handler=lambda url, **kwargs: FakeResponse(status_code=302),
|
||||
cookies={},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(engine, "_log_cookie_state", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(engine, "_get_workspace_id", lambda log_missing=False: next(workspace_results))
|
||||
monkeypatch.setattr(register.time, "sleep", lambda seconds: sleep_calls.append(seconds))
|
||||
|
||||
workspace_id = engine._try_upgrade_cookie_with_continue_url(
|
||||
"https://auth.example/continue",
|
||||
"降级登录 Continue URL",
|
||||
)
|
||||
|
||||
assert workspace_id == "ws-delayed"
|
||||
assert len(engine.session.get_calls) == 3
|
||||
assert sleep_calls == [1.0, 2.0, 4.0]
|
||||
102
tests/test_registration_proxy_failover.py
Normal file
102
tests/test_registration_proxy_failover.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.database import crud
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.web.routes import registration
|
||||
from src.core.register import RegistrationResult
|
||||
|
||||
|
||||
def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch, tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
primary_proxy = crud.create_proxy(
|
||||
session,
|
||||
name="primary",
|
||||
type="http",
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
)
|
||||
crud.update_proxy(session, primary_proxy.id, is_default=True)
|
||||
backup_proxy = crud.create_proxy(
|
||||
session,
|
||||
name="backup",
|
||||
type="http",
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
)
|
||||
email_service = crud.create_email_service(
|
||||
session,
|
||||
service_type="tempmail",
|
||||
name="tempmail-db",
|
||||
config={"base_url": "https://mail.example/api"},
|
||||
)
|
||||
crud.create_registration_task(session, task_uuid="task-proxy-failover")
|
||||
primary_proxy_id = primary_proxy.id
|
||||
backup_proxy_id = backup_proxy.id
|
||||
email_service_id = email_service.id
|
||||
|
||||
monkeypatch.setattr(registration, "get_db", manager.session_scope)
|
||||
monkeypatch.setattr(
|
||||
registration,
|
||||
"EmailServiceFactory",
|
||||
SimpleNamespace(create=lambda service_type, config: SimpleNamespace(service_type=service_type, config=config)),
|
||||
)
|
||||
|
||||
attempted_proxies = []
|
||||
saved_results = []
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.proxy_url = proxy_url
|
||||
|
||||
def run(self):
|
||||
attempted_proxies.append(self.proxy_url)
|
||||
if self.proxy_url.endswith(":8001"):
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
email="proxy@example.com",
|
||||
error_message="OpenAI 请求失败: curl: (35) TLS handshake failed",
|
||||
)
|
||||
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
email="proxy@example.com",
|
||||
access_token="access-token",
|
||||
workspace_id="ws-123",
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
saved_results.append(result.email)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine)
|
||||
|
||||
registration._run_sync_registration_task(
|
||||
task_uuid="task-proxy-failover",
|
||||
email_service_type="tempmail",
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
email_service_id=email_service_id,
|
||||
)
|
||||
|
||||
assert attempted_proxies == [
|
||||
"http://127.0.0.1:8001",
|
||||
"http://127.0.0.1:8002",
|
||||
]
|
||||
assert saved_results == ["proxy@example.com"]
|
||||
|
||||
with manager.session_scope() as session:
|
||||
disabled_primary = crud.get_proxy_by_id(session, primary_proxy_id)
|
||||
active_backup = crud.get_proxy_by_id(session, backup_proxy_id)
|
||||
task = crud.get_registration_task_by_uuid(session, "task-proxy-failover")
|
||||
|
||||
assert disabled_primary is not None
|
||||
assert disabled_primary.enabled is False
|
||||
assert active_backup is not None
|
||||
assert active_backup.enabled is True
|
||||
assert task is not None
|
||||
assert task.status == "completed"
|
||||
assert task.proxy == "http://127.0.0.1:8002"
|
||||
Reference in New Issue
Block a user