mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-06 20:02:51 +08:00
feat(email): implement cancellation handling in email services
This commit is contained in:
@@ -31,6 +31,7 @@ from ..database import crud
|
||||
from ..database.session import get_db
|
||||
from ..services import BaseEmailService
|
||||
from ..services.base import (
|
||||
EmailServiceCancelledError,
|
||||
EmailProviderBackoffState,
|
||||
OTP_NO_OPENAI_SENDER_ERROR,
|
||||
OTPNoOpenAISenderEmailServiceError,
|
||||
@@ -43,6 +44,15 @@ PHASE_EMAIL_PREPARE = "email_prepare"
|
||||
PHASE_OTP_SECONDARY = "otp_secondary"
|
||||
ERROR_EMAIL_PROVIDER_RATE_LIMITED = "EMAIL_PROVIDER_RATE_LIMITED"
|
||||
ERROR_OTP_TIMEOUT_SECONDARY = "OTP_TIMEOUT_SECONDARY"
|
||||
ERROR_TASK_CANCELLED = "TASK_CANCELLED"
|
||||
|
||||
|
||||
class TaskCancelledError(Exception):
|
||||
"""注册任务被主动取消。"""
|
||||
|
||||
def __init__(self, message: str = "任务已取消"):
|
||||
super().__init__(message)
|
||||
self.error_code = ERROR_TASK_CANCELLED
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -185,6 +195,8 @@ class RegistrationEngine:
|
||||
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
|
||||
self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录)
|
||||
self.phase_history: list[PhaseResult] = []
|
||||
self.check_cancelled: Optional[Callable[[], bool]] = None
|
||||
self._cancel_logged = False
|
||||
|
||||
def _log(self, message: str, level: str = "info"):
|
||||
"""记录日志"""
|
||||
@@ -274,6 +286,35 @@ class RegistrationEngine:
|
||||
self.phase_history.append(phase_result)
|
||||
return phase_result
|
||||
|
||||
def _is_cancelled_requested(self) -> bool:
|
||||
"""检查是否收到外部取消信号。"""
|
||||
callback = self.check_cancelled
|
||||
if not callable(callback):
|
||||
return False
|
||||
try:
|
||||
return bool(callback())
|
||||
except Exception as e:
|
||||
logger.warning(f"检查任务取消状态失败: {e}")
|
||||
return False
|
||||
|
||||
def _raise_if_cancelled(self, message: str = "任务已取消"):
|
||||
"""在可协作阶段检查取消请求。"""
|
||||
if not self._is_cancelled_requested():
|
||||
return
|
||||
if not self._cancel_logged:
|
||||
self._log(message, "warning")
|
||||
self._cancel_logged = True
|
||||
raise TaskCancelledError(message)
|
||||
|
||||
def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2):
|
||||
"""可响应取消的短分片休眠。"""
|
||||
remaining = max(0.0, float(seconds))
|
||||
while remaining > 0:
|
||||
self._raise_if_cancelled()
|
||||
sleep_for = min(chunk_seconds, remaining)
|
||||
time.sleep(sleep_for)
|
||||
remaining -= sleep_for
|
||||
|
||||
def _get_phase_result(self, phase_name: str) -> Optional[PhaseResult]:
|
||||
for phase_result in reversed(self.phase_history):
|
||||
if phase_result.phase == phase_name:
|
||||
@@ -427,7 +468,7 @@ class RegistrationEngine:
|
||||
)
|
||||
|
||||
if attempt < max_attempts:
|
||||
time.sleep(attempt)
|
||||
self._sleep_with_cancel(attempt)
|
||||
self.http_client.close()
|
||||
self.session = self.http_client.session
|
||||
|
||||
@@ -656,6 +697,7 @@ class RegistrationEngine:
|
||||
otp_phase: Optional[PhaseResult] = None
|
||||
|
||||
while True:
|
||||
self._raise_if_cancelled("等待验证码重试时任务已取消")
|
||||
code, otp_phase = self._phase_otp_secondary(
|
||||
PhaseContext(otp_sent_at=self._otp_sent_at),
|
||||
started_at=otp_phase_started_at,
|
||||
@@ -698,6 +740,7 @@ class RegistrationEngine:
|
||||
**emit_kwargs,
|
||||
)
|
||||
|
||||
self._raise_if_cancelled("等待验证码重试时任务已取消")
|
||||
if not resend_callback():
|
||||
self._log("重新发送验证码失败,跳过本次重试", "warning")
|
||||
|
||||
@@ -712,76 +755,79 @@ class RegistrationEngine:
|
||||
) -> Tuple[Optional[str], PhaseResult]:
|
||||
"""等待二次验证码邮件并做超时归因。"""
|
||||
try:
|
||||
self._raise_if_cancelled("等待验证码时任务已取消")
|
||||
self._log(f"正在等待邮箱 {self.email} 的验证码...")
|
||||
|
||||
email_id = self.email_info.get("service_id") if self.email_info else None
|
||||
settings = get_settings()
|
||||
budget = Budget(
|
||||
timeout_seconds=get_settings().email_code_timeout,
|
||||
timeout_seconds=settings.email_code_timeout,
|
||||
started_at=started_at if started_at is not None else time.time(),
|
||||
)
|
||||
remaining_timeout = budget.remaining_seconds()
|
||||
poll_interval = max(1, int(settings.email_code_poll_interval or 1))
|
||||
|
||||
if remaining_timeout <= 0:
|
||||
while True:
|
||||
self._raise_if_cancelled("等待验证码时任务已取消")
|
||||
remaining_timeout = budget.remaining_seconds()
|
||||
|
||||
if remaining_timeout <= 0:
|
||||
phase_result = self._record_phase_result(
|
||||
PhaseResult(
|
||||
phase=PHASE_OTP_SECONDARY,
|
||||
success=False,
|
||||
error_message="等待验证码超时",
|
||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
retryable=True,
|
||||
next_action="await_email",
|
||||
metadata={
|
||||
"budget_started_at": budget.started_at,
|
||||
"budget_timeout_seconds": budget.timeout_seconds,
|
||||
"otp_sent_at": context.otp_sent_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._log(phase_result.error_message, "error")
|
||||
return None, phase_result
|
||||
|
||||
attempt_timeout = max(1, min(remaining_timeout, poll_interval))
|
||||
code = self.email_service.get_verification_code(
|
||||
email=self.email,
|
||||
email_id=email_id,
|
||||
timeout=attempt_timeout,
|
||||
pattern=OTP_CODE_PATTERN,
|
||||
otp_sent_at=context.otp_sent_at,
|
||||
)
|
||||
|
||||
if code:
|
||||
self._log(f"成功获取验证码: {code}")
|
||||
phase_result = self._record_phase_result(
|
||||
PhaseResult(
|
||||
phase=PHASE_OTP_SECONDARY,
|
||||
success=True,
|
||||
metadata={
|
||||
"budget_started_at": budget.started_at,
|
||||
"budget_timeout_seconds": budget.timeout_seconds,
|
||||
"otp_sent_at": context.otp_sent_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
return code, phase_result
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, (TaskCancelledError, EmailServiceCancelledError)):
|
||||
phase_result = self._record_phase_result(
|
||||
PhaseResult(
|
||||
phase=PHASE_OTP_SECONDARY,
|
||||
success=False,
|
||||
error_message="等待验证码超时",
|
||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
retryable=True,
|
||||
next_action="await_email",
|
||||
metadata={
|
||||
"budget_started_at": budget.started_at,
|
||||
"budget_timeout_seconds": budget.timeout_seconds,
|
||||
"otp_sent_at": context.otp_sent_at,
|
||||
},
|
||||
error_message=str(e),
|
||||
error_code=getattr(e, "error_code", ERROR_TASK_CANCELLED),
|
||||
retryable=False,
|
||||
next_action="cancelled",
|
||||
metadata={"otp_sent_at": context.otp_sent_at},
|
||||
)
|
||||
)
|
||||
self._log(phase_result.error_message, "error")
|
||||
return None, phase_result
|
||||
|
||||
code = self.email_service.get_verification_code(
|
||||
email=self.email,
|
||||
email_id=email_id,
|
||||
timeout=remaining_timeout,
|
||||
pattern=OTP_CODE_PATTERN,
|
||||
otp_sent_at=context.otp_sent_at,
|
||||
)
|
||||
|
||||
if code:
|
||||
self._log(f"成功获取验证码: {code}")
|
||||
phase_result = self._record_phase_result(
|
||||
PhaseResult(
|
||||
phase=PHASE_OTP_SECONDARY,
|
||||
success=True,
|
||||
metadata={
|
||||
"budget_started_at": budget.started_at,
|
||||
"budget_timeout_seconds": budget.timeout_seconds,
|
||||
"otp_sent_at": context.otp_sent_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
return code, phase_result
|
||||
|
||||
phase_result = self._record_phase_result(
|
||||
PhaseResult(
|
||||
phase=PHASE_OTP_SECONDARY,
|
||||
success=False,
|
||||
error_message="等待验证码超时",
|
||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
retryable=True,
|
||||
next_action="await_email",
|
||||
metadata={
|
||||
"budget_started_at": budget.started_at,
|
||||
"budget_timeout_seconds": budget.timeout_seconds,
|
||||
"otp_sent_at": context.otp_sent_at,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._log(phase_result.error_message, "error")
|
||||
return None, phase_result
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
|
||||
self._log(str(e), "warning")
|
||||
phase_result = self._record_phase_result(
|
||||
@@ -1404,6 +1450,8 @@ class RegistrationEngine:
|
||||
non_openai_retry_status_template="重新触发登录验证码(非 OpenAI 发件人,第 {attempt} 次)",
|
||||
)
|
||||
if not code:
|
||||
if otp_phase and otp_phase.error_code == ERROR_TASK_CANCELLED:
|
||||
raise TaskCancelledError(otp_phase.error_message or "登录流程已取消")
|
||||
self._log(
|
||||
otp_phase.error_message if otp_phase and otp_phase.error_message else "登录流程获取验证码失败",
|
||||
"warning",
|
||||
@@ -1539,11 +1587,13 @@ class RegistrationEngine:
|
||||
result = RegistrationResult(success=False, logs=self.logs)
|
||||
|
||||
try:
|
||||
self._raise_if_cancelled()
|
||||
self._log("=" * 60)
|
||||
self._log("开始注册流程")
|
||||
self._log("=" * 60)
|
||||
|
||||
# 1. 检查 IP 地理位置
|
||||
self._raise_if_cancelled()
|
||||
self._log("1. 检查 IP 地理位置...")
|
||||
self._emit_status("ip_check", "检查 IP 地理位置", step_index=1)
|
||||
ip_ok, location = self._check_ip_location()
|
||||
@@ -1555,6 +1605,7 @@ class RegistrationEngine:
|
||||
self._log(f"IP 位置: {location}")
|
||||
|
||||
# 2. 创建邮箱
|
||||
self._raise_if_cancelled()
|
||||
self._log("2. 创建邮箱...")
|
||||
self._emit_status("email_prepare", "创建邮箱地址", step_index=2)
|
||||
if not self._phase_email_prepare():
|
||||
@@ -1570,6 +1621,7 @@ class RegistrationEngine:
|
||||
result.email = self.email
|
||||
|
||||
# 3. 初始化会话
|
||||
self._raise_if_cancelled()
|
||||
self._log("3. 初始化会话...")
|
||||
self._emit_status("session_init", "初始化 HTTP 会话", step_index=3)
|
||||
if not self._init_session():
|
||||
@@ -1577,6 +1629,7 @@ class RegistrationEngine:
|
||||
return result
|
||||
|
||||
# 4. 开始 OAuth 流程
|
||||
self._raise_if_cancelled()
|
||||
self._log("4. 开始 OAuth 授权流程...")
|
||||
self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4)
|
||||
if not self._start_oauth():
|
||||
@@ -1584,6 +1637,7 @@ class RegistrationEngine:
|
||||
return result
|
||||
|
||||
# 5. 获取 Device ID
|
||||
self._raise_if_cancelled()
|
||||
self._log("5. 获取 Device ID...")
|
||||
self._emit_status("oauth_device_id", "获取 Device ID", step_index=5)
|
||||
did = self._get_device_id()
|
||||
@@ -1592,6 +1646,7 @@ class RegistrationEngine:
|
||||
return result
|
||||
|
||||
# 6. 检查 Sentinel 拦截
|
||||
self._raise_if_cancelled()
|
||||
self._log("6. 检查 Sentinel 拦截...")
|
||||
self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6)
|
||||
sen_token = self._check_sentinel(did)
|
||||
@@ -1601,6 +1656,7 @@ class RegistrationEngine:
|
||||
self._log("Sentinel 检查失败或未启用", "warning")
|
||||
|
||||
# 7. 提交注册表单 + 解析响应判断账号状态
|
||||
self._raise_if_cancelled()
|
||||
self._log("7. 提交注册表单...")
|
||||
self._emit_status("signup_submit", "提交注册表单", step_index=7)
|
||||
signup_result = self._submit_signup_form(did, sen_token)
|
||||
@@ -1612,6 +1668,7 @@ class RegistrationEngine:
|
||||
if self._is_existing_account:
|
||||
self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送")
|
||||
else:
|
||||
self._raise_if_cancelled()
|
||||
self._log("8. 注册密码...")
|
||||
self._emit_status("signup_password", "提交注册密码", step_index=8)
|
||||
password_ok, password = self._register_password()
|
||||
@@ -1625,6 +1682,7 @@ class RegistrationEngine:
|
||||
# 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳
|
||||
self._otp_sent_at = time.time()
|
||||
else:
|
||||
self._raise_if_cancelled()
|
||||
self._log("9. 发送验证码...")
|
||||
self._emit_status("otp_send", "发送验证码", step_index=9)
|
||||
if not self._send_verification_code():
|
||||
@@ -1632,6 +1690,7 @@ class RegistrationEngine:
|
||||
return result
|
||||
|
||||
# 10. 获取验证码(支持重发重试)
|
||||
self._raise_if_cancelled()
|
||||
self._log("10. 等待验证码...")
|
||||
self._emit_status("otp_secondary", "等待验证码邮件", step_index=10)
|
||||
code, otp_phase = self._await_verification_code_with_resends(
|
||||
@@ -1650,6 +1709,7 @@ class RegistrationEngine:
|
||||
return result
|
||||
|
||||
# 11. 验证验证码
|
||||
self._raise_if_cancelled()
|
||||
self._log("11. 验证验证码...")
|
||||
self._emit_status("otp_validate", "校验验证码", step_index=11)
|
||||
if not self._validate_verification_code(code):
|
||||
@@ -1660,6 +1720,7 @@ class RegistrationEngine:
|
||||
if self._is_existing_account:
|
||||
self._log("12. [已注册账号] 跳过创建用户账户")
|
||||
else:
|
||||
self._raise_if_cancelled()
|
||||
self._log("12. 创建用户账户...")
|
||||
self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12)
|
||||
if not self._create_user_account():
|
||||
@@ -1670,6 +1731,7 @@ class RegistrationEngine:
|
||||
callback_url = None
|
||||
|
||||
if not self._is_existing_account:
|
||||
self._raise_if_cancelled()
|
||||
self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...")
|
||||
self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step)
|
||||
workspace_id, callback_url = self._advance_login_authorization()
|
||||
@@ -1679,6 +1741,7 @@ class RegistrationEngine:
|
||||
|
||||
if not result.workspace_id:
|
||||
# 获取 Workspace ID
|
||||
self._raise_if_cancelled()
|
||||
self._log(f"{next_step}. 获取 Workspace ID...")
|
||||
self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step)
|
||||
workspace_id = self._get_workspace_id()
|
||||
@@ -1691,6 +1754,7 @@ class RegistrationEngine:
|
||||
next_step += 1
|
||||
|
||||
# 选择 Workspace
|
||||
self._raise_if_cancelled()
|
||||
self._log(f"{next_step}. 选择 Workspace...")
|
||||
self._emit_status("workspace_select", "选择 Workspace", step_index=next_step)
|
||||
continue_url = self._select_workspace(result.workspace_id)
|
||||
@@ -1701,6 +1765,7 @@ class RegistrationEngine:
|
||||
next_step += 1
|
||||
|
||||
# 跟随重定向链
|
||||
self._raise_if_cancelled()
|
||||
self._log(f"{next_step}. 跟随重定向链...")
|
||||
self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step)
|
||||
callback_url = self._follow_redirects(continue_url)
|
||||
@@ -1711,6 +1776,7 @@ class RegistrationEngine:
|
||||
next_step += 1
|
||||
|
||||
# 处理 OAuth 回调
|
||||
self._raise_if_cancelled()
|
||||
self._log(f"{next_step}. 处理 OAuth 回调...")
|
||||
self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step)
|
||||
token_info = self._handle_oauth_callback(callback_url)
|
||||
@@ -1757,6 +1823,11 @@ class RegistrationEngine:
|
||||
|
||||
return result
|
||||
|
||||
except TaskCancelledError as e:
|
||||
result.error_message = str(e)
|
||||
result.error_code = getattr(e, "error_code", ERROR_TASK_CANCELLED)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
||||
result.error_message = str(e)
|
||||
|
||||
@@ -9,7 +9,7 @@ import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
from enum import Enum
|
||||
|
||||
from ..config.constants import EmailServiceType, OPENAI_EMAIL_SENDERS, OTP_CODE_PATTERN, OTP_CODE_SEMANTIC_PATTERN
|
||||
@@ -139,6 +139,10 @@ class OTPNoOpenAISenderEmailServiceError(EmailServiceError):
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class EmailServiceCancelledError(EmailServiceError):
|
||||
"""邮箱服务在轮询过程中收到取消信号。"""
|
||||
|
||||
|
||||
class EmailServiceStatus(Enum):
|
||||
"""邮箱服务状态"""
|
||||
HEALTHY = "healthy"
|
||||
@@ -168,6 +172,7 @@ class BaseEmailService(abc.ABC):
|
||||
self._provider_backoff = reset_adaptive_backoff()
|
||||
self._used_verification_codes: Dict[str, set] = {}
|
||||
self._seen_verification_messages: Dict[str, set] = {}
|
||||
self.check_cancelled: Optional[Callable[[], bool]] = None
|
||||
|
||||
_EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
|
||||
|
||||
@@ -190,6 +195,35 @@ class BaseEmailService(abc.ABC):
|
||||
"""注入外部持久化的邮箱供应商退避状态"""
|
||||
self._provider_backoff = state or reset_adaptive_backoff()
|
||||
|
||||
def set_check_cancelled(self, callback: Optional[Callable[[], bool]]) -> None:
|
||||
"""注入外部取消检查回调。"""
|
||||
self.check_cancelled = callback if callable(callback) else None
|
||||
|
||||
def _is_cancelled_requested(self) -> bool:
|
||||
"""检查邮箱服务是否收到取消请求。"""
|
||||
callback = self.check_cancelled
|
||||
if not callable(callback):
|
||||
return False
|
||||
try:
|
||||
return bool(callback())
|
||||
except Exception as e:
|
||||
logger.warning(f"检查邮箱服务取消状态失败: {e}")
|
||||
return False
|
||||
|
||||
def _raise_if_cancelled(self, message: str = "任务已取消") -> None:
|
||||
"""在轮询/等待阶段响应取消请求。"""
|
||||
if self._is_cancelled_requested():
|
||||
raise EmailServiceCancelledError(message)
|
||||
|
||||
def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2) -> None:
|
||||
"""可响应取消的短分片休眠。"""
|
||||
remaining = max(0.0, float(seconds))
|
||||
while remaining > 0:
|
||||
self._raise_if_cancelled()
|
||||
sleep_for = min(chunk_seconds, remaining)
|
||||
time.sleep(sleep_for)
|
||||
remaining -= sleep_for
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -520,6 +554,7 @@ class BaseEmailService(abc.ABC):
|
||||
last_email_id = None
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待邮件时任务已取消")
|
||||
try:
|
||||
emails = self.list_emails()
|
||||
for email_info in emails:
|
||||
@@ -562,7 +597,7 @@ class BaseEmailService(abc.ABC):
|
||||
except Exception as e:
|
||||
logger.warning(f"等待邮件时出错: {e}")
|
||||
|
||||
time.sleep(check_interval)
|
||||
self._sleep_with_cancel(check_interval)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -236,6 +236,7 @@ class CloudMailService(BaseEmailService):
|
||||
seen_mail_ids: set = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待 Cloud Mail 验证码时任务已取消")
|
||||
try:
|
||||
token = self._get_public_token()
|
||||
mails = self._make_request(
|
||||
@@ -253,7 +254,7 @@ class CloudMailService(BaseEmailService):
|
||||
mails = mails["list"]
|
||||
|
||||
if not isinstance(mails, list):
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
if mails:
|
||||
@@ -302,7 +303,7 @@ class CloudMailService(BaseEmailService):
|
||||
raise
|
||||
logger.debug(f"检查 Cloud Mail 邮件时出错: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
logger.warning(f"等待 Cloud Mail 验证码超时: {email}")
|
||||
return None
|
||||
|
||||
@@ -263,6 +263,7 @@ class DuckMailService(BaseEmailService):
|
||||
seen_message_ids = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待 DuckMail 验证码时任务已取消")
|
||||
try:
|
||||
response = self._make_request(
|
||||
"GET",
|
||||
@@ -324,7 +325,7 @@ class DuckMailService(BaseEmailService):
|
||||
raise
|
||||
logger.debug(f"DuckMail 轮询验证码失败: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -214,10 +214,11 @@ class FreemailService(BaseEmailService):
|
||||
seen_mail_ids: set = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待 Freemail 验证码时任务已取消")
|
||||
try:
|
||||
mails = self._make_request("GET", "/api/emails", params={"mailbox": email, "limit": 20})
|
||||
if not isinstance(mails, list):
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
ordered_mails = self._sort_items_by_message_time(
|
||||
@@ -299,7 +300,7 @@ class FreemailService(BaseEmailService):
|
||||
raise
|
||||
logger.debug(f"检查 Freemail 邮件时出错: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
logger.warning(f"等待 Freemail 验证码超时: {email}")
|
||||
return None
|
||||
|
||||
@@ -12,7 +12,7 @@ import logging
|
||||
from email.header import decode_header
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, OTPNoOpenAISenderEmailServiceError, get_email_code_settings
|
||||
from .base import BaseEmailService, OTPNoOpenAISenderEmailServiceError, get_email_code_settings
|
||||
from ..config.constants import (
|
||||
EmailServiceType,
|
||||
OTP_CODE_SEMANTIC_PATTERN,
|
||||
@@ -124,11 +124,12 @@ class ImapMailService(BaseEmailService):
|
||||
mail.select("INBOX")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待 IMAP 验证码时任务已取消")
|
||||
try:
|
||||
# 搜索所有未读邮件
|
||||
status, data = mail.search(None, "UNSEEN")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
msg_ids = data[0].split()
|
||||
@@ -181,13 +182,13 @@ class ImapMailService(BaseEmailService):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
|
||||
raise
|
||||
logger.warning(f"IMAP 连接/轮询失败: {e}")
|
||||
self.update_status(False, str(e))
|
||||
self.update_status(False, e)
|
||||
finally:
|
||||
if mail:
|
||||
try:
|
||||
|
||||
@@ -308,13 +308,14 @@ class MeoMailEmailService(BaseEmailService):
|
||||
seen_message_ids = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待自定义域名邮箱验证码时任务已取消")
|
||||
try:
|
||||
# 获取邮件列表
|
||||
response = self._make_request("GET", f"/api/emails/{target_email_id}")
|
||||
|
||||
messages = response.get("messages", [])
|
||||
if not isinstance(messages, list):
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
ordered_messages = self._sort_items_by_message_time(
|
||||
@@ -380,7 +381,7 @@ class MeoMailEmailService(BaseEmailService):
|
||||
logger.debug(f"检查邮件时出错: {e}")
|
||||
|
||||
# 等待一段时间再检查
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
logger.warning(f"等待验证码超时: {email}")
|
||||
return None
|
||||
|
||||
@@ -337,6 +337,7 @@ class OutlookService(BaseEmailService):
|
||||
poll_count = 0
|
||||
|
||||
while time.time() - start_time < actual_timeout:
|
||||
self._raise_if_cancelled("等待 Outlook 验证码时任务已取消")
|
||||
poll_count += 1
|
||||
|
||||
# 渐进式邮件检查:前 3 次只检查未读
|
||||
@@ -387,7 +388,7 @@ class OutlookService(BaseEmailService):
|
||||
logger.warning(f"[{email}] 检查出错: {e}")
|
||||
|
||||
# 等待下次轮询
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次")
|
||||
|
||||
@@ -309,6 +309,7 @@ class TempMailService(BaseEmailService):
|
||||
# jwt = cached.get("jwt")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待 TempMail 验证码时任务已取消")
|
||||
try:
|
||||
# if jwt:
|
||||
# response = self._make_request(
|
||||
@@ -327,7 +328,7 @@ class TempMailService(BaseEmailService):
|
||||
# /user_api/mails 和 /admin/mails 返回格式相同: {"results": [...], "total": N}
|
||||
mails = response.get("results", [])
|
||||
if not isinstance(mails, list):
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
ordered_mails = self._sort_items_by_message_time(
|
||||
@@ -389,7 +390,7 @@ class TempMailService(BaseEmailService):
|
||||
raise
|
||||
logger.debug(f"检查 TempMail 邮件时出错: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
logger.warning(f"等待 TempMail 验证码超时: {email}")
|
||||
return None
|
||||
|
||||
@@ -217,6 +217,7 @@ class TempmailService(BaseEmailService):
|
||||
seen_ids = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
self._raise_if_cancelled("等待 Tempmail 验证码时任务已取消")
|
||||
try:
|
||||
# 获取邮件列表
|
||||
response = self.http_client.get(
|
||||
@@ -226,7 +227,7 @@ class TempmailService(BaseEmailService):
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
data = response.json()
|
||||
@@ -239,7 +240,7 @@ class TempmailService(BaseEmailService):
|
||||
email_list = data.get("emails", []) if isinstance(data, dict) else []
|
||||
|
||||
if not isinstance(email_list, list):
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
continue
|
||||
|
||||
ordered_emails = self._sort_items_by_message_time(
|
||||
@@ -302,7 +303,7 @@ class TempmailService(BaseEmailService):
|
||||
logger.debug(f"检查邮件时出错: {e}")
|
||||
|
||||
# 等待一段时间再检查
|
||||
time.sleep(poll_interval)
|
||||
self._sleep_with_cancel(poll_interval)
|
||||
|
||||
logger.warning(f"等待验证码超时: {email}")
|
||||
return None
|
||||
|
||||
@@ -20,6 +20,7 @@ from ...database.session import get_db
|
||||
from ...database.models import RegistrationTask, Proxy
|
||||
from ...core.register import (
|
||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
ERROR_TASK_CANCELLED,
|
||||
RegistrationEngine,
|
||||
RegistrationResult,
|
||||
)
|
||||
@@ -395,6 +396,11 @@ def _run_registration_engine_attempt(
|
||||
status_callback=status_callback,
|
||||
task_uuid=task_uuid,
|
||||
)
|
||||
create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None)
|
||||
if callable(create_cancel_callback):
|
||||
setattr(engine, "check_cancelled", create_cancel_callback(task_uuid))
|
||||
else:
|
||||
setattr(engine, "check_cancelled", lambda: False)
|
||||
|
||||
try:
|
||||
result = engine.run()
|
||||
@@ -435,6 +441,52 @@ def _get_batch_snapshot(batch_id: str) -> Optional[dict]:
|
||||
return task_manager.get_batch_status(batch_id)
|
||||
|
||||
|
||||
def _finalize_task_cancelled(
|
||||
db,
|
||||
task_uuid: str,
|
||||
message: str = "任务已取消",
|
||||
*,
|
||||
email_service: Optional[str] = None,
|
||||
):
|
||||
"""将任务统一收敛为已取消状态。"""
|
||||
crud.update_registration_task(
|
||||
db,
|
||||
task_uuid,
|
||||
status="cancelled",
|
||||
completed_at=datetime.utcnow(),
|
||||
error_message=message,
|
||||
)
|
||||
status_kwargs = {"error": message, "cancelled": True}
|
||||
if email_service:
|
||||
status_kwargs["email_service"] = email_service
|
||||
task_manager.update_status(task_uuid, "cancelled", **status_kwargs)
|
||||
|
||||
|
||||
def _request_task_cancellation(db, task_uuid: str, message: str = "任务已取消"):
|
||||
"""向运行中的任务发送取消信号,并持久化取消状态。"""
|
||||
task_manager.cancel_task(task_uuid)
|
||||
_finalize_task_cancelled(db, task_uuid, message)
|
||||
|
||||
|
||||
def _request_batch_cancellation(batch_id: str, message: str = "批量任务取消请求已提交"):
|
||||
"""向批量任务及其成员任务传播取消信号。"""
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
if batch.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
task_manager.cancel_batch(batch_id)
|
||||
task_manager.update_batch_status(batch_id, status="cancelled", cancelled=True)
|
||||
|
||||
with get_db() as db:
|
||||
for task_uuid in batch.get("task_uuids", []):
|
||||
task_manager.cancel_task(task_uuid)
|
||||
task = crud.get_registration_task(db, task_uuid)
|
||||
if task and task.status in ["pending", "running"]:
|
||||
_finalize_task_cancelled(db, task_uuid, "任务已取消")
|
||||
|
||||
return {"success": True, "message": message}
|
||||
|
||||
|
||||
def _require_batch_snapshot(batch_id: str) -> dict:
|
||||
batch = _get_batch_snapshot(batch_id)
|
||||
if batch is None:
|
||||
@@ -560,8 +612,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
"""
|
||||
with get_db() as db:
|
||||
try:
|
||||
def cancellation_requested() -> bool:
|
||||
return task_manager.is_cancelled(task_uuid)
|
||||
|
||||
if task_manager.is_cancelled(task_uuid):
|
||||
logger.info(f"任务 {task_uuid} 已取消,跳过执行")
|
||||
_finalize_task_cancelled(db, task_uuid)
|
||||
return
|
||||
|
||||
task = crud.update_registration_task(
|
||||
@@ -583,6 +639,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
proxy_id = None
|
||||
|
||||
while True:
|
||||
if cancellation_requested():
|
||||
_finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value)
|
||||
return
|
||||
|
||||
actual_proxy_url = requested_proxy
|
||||
proxy_id = None
|
||||
if not actual_proxy_url:
|
||||
@@ -605,6 +665,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
should_retry_with_new_proxy = False
|
||||
|
||||
for attempt_index, candidate in enumerate(service_candidates, start=1):
|
||||
if cancellation_requested():
|
||||
_finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value)
|
||||
return
|
||||
|
||||
selected_service_type = candidate["service_type"]
|
||||
candidate_config = candidate["config"]
|
||||
db_service = candidate.get("db_service")
|
||||
@@ -630,6 +694,11 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
candidate_config,
|
||||
name=db_service.name if db_service is not None else None,
|
||||
)
|
||||
create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None)
|
||||
if callable(create_cancel_callback):
|
||||
set_cancel_callback = getattr(email_service, "set_check_cancelled", None)
|
||||
if callable(set_cancel_callback):
|
||||
set_cancel_callback(create_cancel_callback(task_uuid))
|
||||
(
|
||||
engine,
|
||||
result,
|
||||
@@ -645,6 +714,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
status_callback=status_callback,
|
||||
)
|
||||
|
||||
if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED:
|
||||
_finalize_task_cancelled(
|
||||
db,
|
||||
task_uuid,
|
||||
result.error_message or "任务已取消",
|
||||
email_service=active_service_type.value,
|
||||
)
|
||||
return
|
||||
|
||||
if result.success:
|
||||
break
|
||||
|
||||
@@ -707,6 +785,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
break
|
||||
|
||||
if result.success:
|
||||
if cancellation_requested():
|
||||
_finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value)
|
||||
return
|
||||
|
||||
# 更新代理使用时间
|
||||
update_proxy_usage(db, proxy_id)
|
||||
|
||||
@@ -836,6 +918,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
except Exception as na_err:
|
||||
log_callback(f"[NEWAPI] 上传异常: {na_err}")
|
||||
|
||||
if cancellation_requested():
|
||||
_finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value)
|
||||
return
|
||||
|
||||
# 更新任务状态
|
||||
crud.update_registration_task(
|
||||
db, task_uuid,
|
||||
@@ -857,6 +943,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
|
||||
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
|
||||
else:
|
||||
if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED:
|
||||
_finalize_task_cancelled(
|
||||
db,
|
||||
task_uuid,
|
||||
result.error_message or "任务已取消",
|
||||
email_service=active_service_type.value,
|
||||
)
|
||||
return
|
||||
|
||||
# 更新任务状态为失败
|
||||
crud.update_registration_task(
|
||||
db, task_uuid,
|
||||
@@ -880,15 +975,18 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
crud.update_registration_task(
|
||||
db, task_uuid,
|
||||
status="failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
error_message=str(e)
|
||||
)
|
||||
if task_manager.is_cancelled(task_uuid):
|
||||
_finalize_task_cancelled(db, task_uuid)
|
||||
else:
|
||||
crud.update_registration_task(
|
||||
db, task_uuid,
|
||||
status="failed",
|
||||
completed_at=datetime.utcnow(),
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 更新 TaskManager 状态
|
||||
task_manager.update_status(task_uuid, "failed", error=str(e))
|
||||
# 更新 TaskManager 状态
|
||||
task_manager.update_status(task_uuid, "failed", error=str(e))
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -904,6 +1002,11 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager.set_loop(loop)
|
||||
|
||||
if task_manager.is_cancelled(task_uuid):
|
||||
with get_db() as db:
|
||||
_finalize_task_cancelled(db, task_uuid)
|
||||
return
|
||||
|
||||
# 初始化 TaskManager 状态
|
||||
task_manager.update_status(task_uuid, "pending", email_service=email_service_type)
|
||||
task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列")
|
||||
@@ -1635,12 +1738,7 @@ async def get_batch_status(batch_id: str):
|
||||
@router.post("/batch/{batch_id}/cancel")
|
||||
async def cancel_batch(batch_id: str):
|
||||
"""取消批量任务"""
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
if batch.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
task_manager.cancel_batch(batch_id)
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
return _request_batch_cancellation(batch_id)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=TaskListResponse)
|
||||
@@ -1703,7 +1801,7 @@ async def cancel_task(task_uuid: str):
|
||||
if task.status not in ["pending", "running"]:
|
||||
raise HTTPException(status_code=400, detail="任务已完成或已取消")
|
||||
|
||||
task = crud.update_registration_task(db, task_uuid, status="cancelled")
|
||||
_request_task_cancellation(db, task_uuid)
|
||||
|
||||
return {"success": True, "message": "任务已取消"}
|
||||
|
||||
@@ -2207,10 +2305,4 @@ async def get_outlook_batch_status(batch_id: str):
|
||||
@router.post("/outlook-batch/{batch_id}/cancel")
|
||||
async def cancel_outlook_batch(batch_id: str):
|
||||
"""取消 Outlook 批量任务"""
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
if batch.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
task_manager.cancel_batch(batch_id)
|
||||
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
return _request_batch_cancellation(batch_id)
|
||||
|
||||
@@ -86,7 +86,14 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
|
||||
|
||||
# 处理取消请求
|
||||
elif data.get("type") == "cancel":
|
||||
task_manager.cancel_task(task_uuid)
|
||||
from . import registration as registration_routes
|
||||
|
||||
with get_db() as db:
|
||||
task = crud.get_registration_task(db, task_uuid)
|
||||
if task and task.status in ["pending", "running"]:
|
||||
registration_routes._request_task_cancellation(db, task_uuid)
|
||||
else:
|
||||
task_manager.cancel_task(task_uuid)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
@@ -163,7 +170,9 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
|
||||
|
||||
# 处理取消请求
|
||||
elif data.get("type") == "cancel":
|
||||
task_manager.cancel_batch(batch_id)
|
||||
from . import registration as registration_routes
|
||||
|
||||
registration_routes._request_batch_cancellation(batch_id)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert/strict');
|
||||
const fs = require('node:fs');
|
||||
const path = require('node:path');
|
||||
const vm = require('node:vm');
|
||||
|
||||
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
|
||||
const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js');
|
||||
|
||||
function createElementStub() {
|
||||
return {
|
||||
|
||||
363
tests/test_registration_task_cancellation.py
Normal file
363
tests/test_registration_task_cancellation.py
Normal file
@@ -0,0 +1,363 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.core.register import ERROR_TASK_CANCELLED, RegistrationResult
|
||||
from src.database.models import Base, RegistrationTask
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.services import EmailServiceType
|
||||
from src.services.base import BaseEmailService, EmailServiceCancelledError
|
||||
from src.web.routes import registration as registration_routes
|
||||
|
||||
|
||||
class FakeTaskManager:
|
||||
def __init__(self):
|
||||
self.cancelled = set()
|
||||
self.status_updates = []
|
||||
self.logs = {}
|
||||
self.batch_status = {}
|
||||
|
||||
def is_cancelled(self, task_uuid):
|
||||
return task_uuid in self.cancelled
|
||||
|
||||
def cancel_task(self, task_uuid):
|
||||
self.cancelled.add(task_uuid)
|
||||
|
||||
def update_status(self, task_uuid, status, **kwargs):
|
||||
self.status_updates.append((task_uuid, status, kwargs))
|
||||
|
||||
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
|
||||
def callback(message):
|
||||
full_message = f"{prefix} {message}" if prefix else message
|
||||
self.logs.setdefault(task_uuid, []).append(full_message)
|
||||
return callback
|
||||
|
||||
def create_check_cancelled_callback(self, task_uuid):
|
||||
return lambda: self.is_cancelled(task_uuid)
|
||||
|
||||
def get_batch_status(self, batch_id):
|
||||
snapshot = self.batch_status.get(batch_id)
|
||||
return dict(snapshot) if snapshot else None
|
||||
|
||||
def cancel_batch(self, batch_id):
|
||||
snapshot = self.batch_status.setdefault(batch_id, {})
|
||||
snapshot["cancelled"] = True
|
||||
snapshot["status"] = "cancelling"
|
||||
|
||||
def update_batch_status(self, batch_id, **kwargs):
|
||||
snapshot = self.batch_status.setdefault(batch_id, {})
|
||||
snapshot.update(kwargs)
|
||||
|
||||
|
||||
class DummySettings:
|
||||
proxy_dynamic_enabled = False
|
||||
proxy_dynamic_api_url = ""
|
||||
email_code_timeout = 10
|
||||
email_code_poll_interval = 1
|
||||
email_code_resend_max_retries = 0
|
||||
email_code_non_openai_sender_resend_max_retries = 0
|
||||
openai_client_id = "client-id"
|
||||
openai_auth_url = "https://auth.example.test"
|
||||
openai_token_url = "https://token.example.test"
|
||||
openai_redirect_uri = "https://callback.example.test"
|
||||
openai_scope = "openid profile email"
|
||||
|
||||
def get_proxy_url(self):
|
||||
return None
|
||||
|
||||
|
||||
def _build_fake_get_db(manager):
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
with manager.session_scope() as session:
|
||||
yield session
|
||||
|
||||
return fake_get_db
|
||||
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
started_event = None
|
||||
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
self.check_cancelled = None
|
||||
self.callback_logger = callback_logger or (lambda _msg: None)
|
||||
self.task_uuid = task_uuid
|
||||
|
||||
def run(self):
|
||||
if self.started_event is not None:
|
||||
self.started_event.set()
|
||||
while True:
|
||||
if callable(self.check_cancelled) and self.check_cancelled():
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="任务已取消",
|
||||
error_code=ERROR_TASK_CANCELLED,
|
||||
logs=[],
|
||||
)
|
||||
time.sleep(0.01)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
|
||||
class FakePollingEmailService(BaseEmailService):
|
||||
def __init__(self, started_event=None):
|
||||
super().__init__(EmailServiceType.TEMPMAIL, "fake-polling-email")
|
||||
self.started_event = started_event
|
||||
|
||||
def create_email(self, config=None):
|
||||
return {"email": "poll@example.test", "service_id": "poll-service"}
|
||||
|
||||
def get_verification_code(self, email: str, email_id: str = None, timeout: int = 120, pattern: str = r"(?<!\d)(\d{6})(?!\d)", otp_sent_at=None):
|
||||
if self.started_event is not None:
|
||||
self.started_event.set()
|
||||
while True:
|
||||
self._raise_if_cancelled("任务已取消")
|
||||
self._sleep_with_cancel(0.05)
|
||||
|
||||
def list_emails(self, **kwargs):
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str):
|
||||
return True
|
||||
|
||||
def check_health(self):
|
||||
return True
|
||||
|
||||
|
||||
class FakeEmailPollingRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
self.check_cancelled = None
|
||||
self.callback_logger = callback_logger or (lambda _msg: None)
|
||||
self.task_uuid = task_uuid
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.email_service.get_verification_code("poll@example.test", timeout=60)
|
||||
except EmailServiceCancelledError as exc:
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message=str(exc),
|
||||
error_code=ERROR_TASK_CANCELLED,
|
||||
logs=[],
|
||||
)
|
||||
|
||||
return RegistrationResult(success=False, error_message="邮箱轮询未被取消", logs=[])
|
||||
|
||||
def save_to_database(self, result):
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
|
||||
def test_cancel_task_route_marks_task_manager_and_db_cancelled(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_cancel_route.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-cancel-route"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="running"))
|
||||
|
||||
fake_task_manager = FakeTaskManager()
|
||||
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||
|
||||
response = asyncio.run(registration_routes.cancel_task(task_uuid))
|
||||
|
||||
assert response == {"success": True, "message": "任务已取消"}
|
||||
assert task_uuid in fake_task_manager.cancelled
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
assert task.status == "cancelled"
|
||||
assert task.error_message == "任务已取消"
|
||||
|
||||
|
||||
def test_run_sync_registration_task_stops_after_cancel_request(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_cancel_runtime.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-cancel-runtime"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
|
||||
fake_task_manager = FakeTaskManager()
|
||||
start_event = threading.Event()
|
||||
FakeRegistrationEngine.started_event = start_event
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes,
|
||||
"_build_email_service_candidates",
|
||||
lambda db, service_type, actual_proxy_url, email_service_id, email_service_config: [
|
||||
{
|
||||
"service_type": EmailServiceType.TEMPMAIL,
|
||||
"config": {"proxy_url": actual_proxy_url},
|
||||
"db_service": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: SimpleNamespace(
|
||||
service_type=service_type,
|
||||
name=name or service_type.value,
|
||||
config=config,
|
||||
),
|
||||
)
|
||||
|
||||
worker = threading.Thread(
|
||||
target=registration_routes._run_sync_registration_task,
|
||||
kwargs={
|
||||
"task_uuid": task_uuid,
|
||||
"email_service_type": EmailServiceType.TEMPMAIL.value,
|
||||
"proxy": None,
|
||||
"email_service_config": {},
|
||||
},
|
||||
)
|
||||
worker.start()
|
||||
assert start_event.wait(timeout=1.0), "registration engine did not start in time"
|
||||
|
||||
response = asyncio.run(registration_routes.cancel_task(task_uuid))
|
||||
assert response == {"success": True, "message": "任务已取消"}
|
||||
|
||||
worker.join(timeout=2.0)
|
||||
assert not worker.is_alive(), "registration worker should stop after cancellation"
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
assert task.status == "cancelled"
|
||||
assert task.error_message == "任务已取消"
|
||||
|
||||
statuses = [status for current_uuid, status, _kwargs in fake_task_manager.status_updates if current_uuid == task_uuid]
|
||||
assert "cancelled" in statuses
|
||||
|
||||
|
||||
def test_cancel_batch_propagates_to_member_tasks(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_cancel_batch.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuids = ["batch-cancel-1", "batch-cancel-2"]
|
||||
with manager.session_scope() as session:
|
||||
session.add_all([
|
||||
RegistrationTask(task_uuid=task_uuids[0], status="running"),
|
||||
RegistrationTask(task_uuid=task_uuids[1], status="pending"),
|
||||
])
|
||||
|
||||
fake_task_manager = FakeTaskManager()
|
||||
fake_task_manager.batch_status["batch-1"] = {
|
||||
"finished": False,
|
||||
"cancelled": False,
|
||||
"task_uuids": task_uuids,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||
|
||||
response = asyncio.run(registration_routes.cancel_batch("batch-1"))
|
||||
|
||||
assert response["success"] is True
|
||||
assert fake_task_manager.batch_status["batch-1"]["cancelled"] is True
|
||||
assert fake_task_manager.cancelled == set(task_uuids)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
tasks = session.query(RegistrationTask).order_by(RegistrationTask.task_uuid.asc()).all()
|
||||
assert [task.status for task in tasks] == ["cancelled", "cancelled"]
|
||||
|
||||
|
||||
def test_run_sync_registration_task_stops_while_email_service_polling(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_cancel_email_polling.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-cancel-email-polling"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
|
||||
fake_task_manager = FakeTaskManager()
|
||||
start_event = threading.Event()
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeEmailPollingRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes,
|
||||
"_build_email_service_candidates",
|
||||
lambda db, service_type, actual_proxy_url, email_service_id, email_service_config: [
|
||||
{
|
||||
"service_type": EmailServiceType.TEMPMAIL,
|
||||
"config": {"proxy_url": actual_proxy_url},
|
||||
"db_service": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: FakePollingEmailService(start_event),
|
||||
)
|
||||
|
||||
worker = threading.Thread(
|
||||
target=registration_routes._run_sync_registration_task,
|
||||
kwargs={
|
||||
"task_uuid": task_uuid,
|
||||
"email_service_type": EmailServiceType.TEMPMAIL.value,
|
||||
"proxy": None,
|
||||
"email_service_config": {},
|
||||
},
|
||||
)
|
||||
worker.start()
|
||||
assert start_event.wait(timeout=1.0), "email polling did not start in time"
|
||||
|
||||
response = asyncio.run(registration_routes.cancel_task(task_uuid))
|
||||
assert response == {"success": True, "message": "任务已取消"}
|
||||
|
||||
worker.join(timeout=2.0)
|
||||
assert not worker.is_alive(), "registration worker should stop while email service is polling"
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
assert task.status == "cancelled"
|
||||
assert task.error_message == "任务已取消"
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert/strict');
|
||||
const fs = require('node:fs');
|
||||
const path = require('node:path');
|
||||
const vm = require('node:vm');
|
||||
|
||||
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
|
||||
const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js');
|
||||
|
||||
function createElementStub() {
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user