feat(email): implement cancellation handling in email services

This commit is contained in:
cnlimiter
2026-03-28 14:11:00 +08:00
parent 8b8ef7c6c0
commit 744b1f4c1b
15 changed files with 681 additions and 101 deletions

View File

@@ -31,6 +31,7 @@ from ..database import crud
from ..database.session import get_db
from ..services import BaseEmailService
from ..services.base import (
EmailServiceCancelledError,
EmailProviderBackoffState,
OTP_NO_OPENAI_SENDER_ERROR,
OTPNoOpenAISenderEmailServiceError,
@@ -43,6 +44,15 @@ PHASE_EMAIL_PREPARE = "email_prepare"
PHASE_OTP_SECONDARY = "otp_secondary"
ERROR_EMAIL_PROVIDER_RATE_LIMITED = "EMAIL_PROVIDER_RATE_LIMITED"
ERROR_OTP_TIMEOUT_SECONDARY = "OTP_TIMEOUT_SECONDARY"
ERROR_TASK_CANCELLED = "TASK_CANCELLED"
class TaskCancelledError(Exception):
"""注册任务被主动取消。"""
def __init__(self, message: str = "任务已取消"):
super().__init__(message)
self.error_code = ERROR_TASK_CANCELLED
@dataclass
@@ -185,6 +195,8 @@ class RegistrationEngine:
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录)
self.phase_history: list[PhaseResult] = []
self.check_cancelled: Optional[Callable[[], bool]] = None
self._cancel_logged = False
def _log(self, message: str, level: str = "info"):
"""记录日志"""
@@ -274,6 +286,35 @@ class RegistrationEngine:
self.phase_history.append(phase_result)
return phase_result
def _is_cancelled_requested(self) -> bool:
"""检查是否收到外部取消信号。"""
callback = self.check_cancelled
if not callable(callback):
return False
try:
return bool(callback())
except Exception as e:
logger.warning(f"检查任务取消状态失败: {e}")
return False
def _raise_if_cancelled(self, message: str = "任务已取消"):
"""在可协作阶段检查取消请求。"""
if not self._is_cancelled_requested():
return
if not self._cancel_logged:
self._log(message, "warning")
self._cancel_logged = True
raise TaskCancelledError(message)
def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2):
"""可响应取消的短分片休眠。"""
remaining = max(0.0, float(seconds))
while remaining > 0:
self._raise_if_cancelled()
sleep_for = min(chunk_seconds, remaining)
time.sleep(sleep_for)
remaining -= sleep_for
def _get_phase_result(self, phase_name: str) -> Optional[PhaseResult]:
for phase_result in reversed(self.phase_history):
if phase_result.phase == phase_name:
@@ -427,7 +468,7 @@ class RegistrationEngine:
)
if attempt < max_attempts:
time.sleep(attempt)
self._sleep_with_cancel(attempt)
self.http_client.close()
self.session = self.http_client.session
@@ -656,6 +697,7 @@ class RegistrationEngine:
otp_phase: Optional[PhaseResult] = None
while True:
self._raise_if_cancelled("等待验证码重试时任务已取消")
code, otp_phase = self._phase_otp_secondary(
PhaseContext(otp_sent_at=self._otp_sent_at),
started_at=otp_phase_started_at,
@@ -698,6 +740,7 @@ class RegistrationEngine:
**emit_kwargs,
)
self._raise_if_cancelled("等待验证码重试时任务已取消")
if not resend_callback():
self._log("重新发送验证码失败,跳过本次重试", "warning")
@@ -712,76 +755,79 @@ class RegistrationEngine:
) -> Tuple[Optional[str], PhaseResult]:
"""等待二次验证码邮件并做超时归因。"""
try:
self._raise_if_cancelled("等待验证码时任务已取消")
self._log(f"正在等待邮箱 {self.email} 的验证码...")
email_id = self.email_info.get("service_id") if self.email_info else None
settings = get_settings()
budget = Budget(
timeout_seconds=get_settings().email_code_timeout,
timeout_seconds=settings.email_code_timeout,
started_at=started_at if started_at is not None else time.time(),
)
remaining_timeout = budget.remaining_seconds()
poll_interval = max(1, int(settings.email_code_poll_interval or 1))
if remaining_timeout <= 0:
while True:
self._raise_if_cancelled("等待验证码时任务已取消")
remaining_timeout = budget.remaining_seconds()
if remaining_timeout <= 0:
phase_result = self._record_phase_result(
PhaseResult(
phase=PHASE_OTP_SECONDARY,
success=False,
error_message="等待验证码超时",
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
retryable=True,
next_action="await_email",
metadata={
"budget_started_at": budget.started_at,
"budget_timeout_seconds": budget.timeout_seconds,
"otp_sent_at": context.otp_sent_at,
},
)
)
self._log(phase_result.error_message, "error")
return None, phase_result
attempt_timeout = max(1, min(remaining_timeout, poll_interval))
code = self.email_service.get_verification_code(
email=self.email,
email_id=email_id,
timeout=attempt_timeout,
pattern=OTP_CODE_PATTERN,
otp_sent_at=context.otp_sent_at,
)
if code:
self._log(f"成功获取验证码: {code}")
phase_result = self._record_phase_result(
PhaseResult(
phase=PHASE_OTP_SECONDARY,
success=True,
metadata={
"budget_started_at": budget.started_at,
"budget_timeout_seconds": budget.timeout_seconds,
"otp_sent_at": context.otp_sent_at,
},
)
)
return code, phase_result
except Exception as e:
if isinstance(e, (TaskCancelledError, EmailServiceCancelledError)):
phase_result = self._record_phase_result(
PhaseResult(
phase=PHASE_OTP_SECONDARY,
success=False,
error_message="等待验证码超时",
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
retryable=True,
next_action="await_email",
metadata={
"budget_started_at": budget.started_at,
"budget_timeout_seconds": budget.timeout_seconds,
"otp_sent_at": context.otp_sent_at,
},
error_message=str(e),
error_code=getattr(e, "error_code", ERROR_TASK_CANCELLED),
retryable=False,
next_action="cancelled",
metadata={"otp_sent_at": context.otp_sent_at},
)
)
self._log(phase_result.error_message, "error")
return None, phase_result
code = self.email_service.get_verification_code(
email=self.email,
email_id=email_id,
timeout=remaining_timeout,
pattern=OTP_CODE_PATTERN,
otp_sent_at=context.otp_sent_at,
)
if code:
self._log(f"成功获取验证码: {code}")
phase_result = self._record_phase_result(
PhaseResult(
phase=PHASE_OTP_SECONDARY,
success=True,
metadata={
"budget_started_at": budget.started_at,
"budget_timeout_seconds": budget.timeout_seconds,
"otp_sent_at": context.otp_sent_at,
},
)
)
return code, phase_result
phase_result = self._record_phase_result(
PhaseResult(
phase=PHASE_OTP_SECONDARY,
success=False,
error_message="等待验证码超时",
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
retryable=True,
next_action="await_email",
metadata={
"budget_started_at": budget.started_at,
"budget_timeout_seconds": budget.timeout_seconds,
"otp_sent_at": context.otp_sent_at,
},
)
)
self._log(phase_result.error_message, "error")
return None, phase_result
except Exception as e:
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
self._log(str(e), "warning")
phase_result = self._record_phase_result(
@@ -1404,6 +1450,8 @@ class RegistrationEngine:
non_openai_retry_status_template="重新触发登录验证码(非 OpenAI 发件人,第 {attempt} 次)",
)
if not code:
if otp_phase and otp_phase.error_code == ERROR_TASK_CANCELLED:
raise TaskCancelledError(otp_phase.error_message or "登录流程已取消")
self._log(
otp_phase.error_message if otp_phase and otp_phase.error_message else "登录流程获取验证码失败",
"warning",
@@ -1539,11 +1587,13 @@ class RegistrationEngine:
result = RegistrationResult(success=False, logs=self.logs)
try:
self._raise_if_cancelled()
self._log("=" * 60)
self._log("开始注册流程")
self._log("=" * 60)
# 1. 检查 IP 地理位置
self._raise_if_cancelled()
self._log("1. 检查 IP 地理位置...")
self._emit_status("ip_check", "检查 IP 地理位置", step_index=1)
ip_ok, location = self._check_ip_location()
@@ -1555,6 +1605,7 @@ class RegistrationEngine:
self._log(f"IP 位置: {location}")
# 2. 创建邮箱
self._raise_if_cancelled()
self._log("2. 创建邮箱...")
self._emit_status("email_prepare", "创建邮箱地址", step_index=2)
if not self._phase_email_prepare():
@@ -1570,6 +1621,7 @@ class RegistrationEngine:
result.email = self.email
# 3. 初始化会话
self._raise_if_cancelled()
self._log("3. 初始化会话...")
self._emit_status("session_init", "初始化 HTTP 会话", step_index=3)
if not self._init_session():
@@ -1577,6 +1629,7 @@ class RegistrationEngine:
return result
# 4. 开始 OAuth 流程
self._raise_if_cancelled()
self._log("4. 开始 OAuth 授权流程...")
self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4)
if not self._start_oauth():
@@ -1584,6 +1637,7 @@ class RegistrationEngine:
return result
# 5. 获取 Device ID
self._raise_if_cancelled()
self._log("5. 获取 Device ID...")
self._emit_status("oauth_device_id", "获取 Device ID", step_index=5)
did = self._get_device_id()
@@ -1592,6 +1646,7 @@ class RegistrationEngine:
return result
# 6. 检查 Sentinel 拦截
self._raise_if_cancelled()
self._log("6. 检查 Sentinel 拦截...")
self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6)
sen_token = self._check_sentinel(did)
@@ -1601,6 +1656,7 @@ class RegistrationEngine:
self._log("Sentinel 检查失败或未启用", "warning")
# 7. 提交注册表单 + 解析响应判断账号状态
self._raise_if_cancelled()
self._log("7. 提交注册表单...")
self._emit_status("signup_submit", "提交注册表单", step_index=7)
signup_result = self._submit_signup_form(did, sen_token)
@@ -1612,6 +1668,7 @@ class RegistrationEngine:
if self._is_existing_account:
self._log("8. [已注册账号] 跳过密码设置OTP 已自动发送")
else:
self._raise_if_cancelled()
self._log("8. 注册密码...")
self._emit_status("signup_password", "提交注册密码", step_index=8)
password_ok, password = self._register_password()
@@ -1625,6 +1682,7 @@ class RegistrationEngine:
# 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳
self._otp_sent_at = time.time()
else:
self._raise_if_cancelled()
self._log("9. 发送验证码...")
self._emit_status("otp_send", "发送验证码", step_index=9)
if not self._send_verification_code():
@@ -1632,6 +1690,7 @@ class RegistrationEngine:
return result
# 10. 获取验证码(支持重发重试)
self._raise_if_cancelled()
self._log("10. 等待验证码...")
self._emit_status("otp_secondary", "等待验证码邮件", step_index=10)
code, otp_phase = self._await_verification_code_with_resends(
@@ -1650,6 +1709,7 @@ class RegistrationEngine:
return result
# 11. 验证验证码
self._raise_if_cancelled()
self._log("11. 验证验证码...")
self._emit_status("otp_validate", "校验验证码", step_index=11)
if not self._validate_verification_code(code):
@@ -1660,6 +1720,7 @@ class RegistrationEngine:
if self._is_existing_account:
self._log("12. [已注册账号] 跳过创建用户账户")
else:
self._raise_if_cancelled()
self._log("12. 创建用户账户...")
self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12)
if not self._create_user_account():
@@ -1670,6 +1731,7 @@ class RegistrationEngine:
callback_url = None
if not self._is_existing_account:
self._raise_if_cancelled()
self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...")
self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step)
workspace_id, callback_url = self._advance_login_authorization()
@@ -1679,6 +1741,7 @@ class RegistrationEngine:
if not result.workspace_id:
# 获取 Workspace ID
self._raise_if_cancelled()
self._log(f"{next_step}. 获取 Workspace ID...")
self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step)
workspace_id = self._get_workspace_id()
@@ -1691,6 +1754,7 @@ class RegistrationEngine:
next_step += 1
# 选择 Workspace
self._raise_if_cancelled()
self._log(f"{next_step}. 选择 Workspace...")
self._emit_status("workspace_select", "选择 Workspace", step_index=next_step)
continue_url = self._select_workspace(result.workspace_id)
@@ -1701,6 +1765,7 @@ class RegistrationEngine:
next_step += 1
# 跟随重定向链
self._raise_if_cancelled()
self._log(f"{next_step}. 跟随重定向链...")
self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step)
callback_url = self._follow_redirects(continue_url)
@@ -1711,6 +1776,7 @@ class RegistrationEngine:
next_step += 1
# 处理 OAuth 回调
self._raise_if_cancelled()
self._log(f"{next_step}. 处理 OAuth 回调...")
self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step)
token_info = self._handle_oauth_callback(callback_url)
@@ -1757,6 +1823,11 @@ class RegistrationEngine:
return result
except TaskCancelledError as e:
result.error_message = str(e)
result.error_code = getattr(e, "error_code", ERROR_TASK_CANCELLED)
return result
except Exception as e:
self._log(f"注册过程中发生未预期错误: {e}", "error")
result.error_message = str(e)

View File

@@ -9,7 +9,7 @@ import re
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Callable
from enum import Enum
from ..config.constants import EmailServiceType, OPENAI_EMAIL_SENDERS, OTP_CODE_PATTERN, OTP_CODE_SEMANTIC_PATTERN
@@ -139,6 +139,10 @@ class OTPNoOpenAISenderEmailServiceError(EmailServiceError):
self.error_code = error_code
class EmailServiceCancelledError(EmailServiceError):
"""邮箱服务在轮询过程中收到取消信号。"""
class EmailServiceStatus(Enum):
"""邮箱服务状态"""
HEALTHY = "healthy"
@@ -168,6 +172,7 @@ class BaseEmailService(abc.ABC):
self._provider_backoff = reset_adaptive_backoff()
self._used_verification_codes: Dict[str, set] = {}
self._seen_verification_messages: Dict[str, set] = {}
self.check_cancelled: Optional[Callable[[], bool]] = None
_EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
@@ -190,6 +195,35 @@ class BaseEmailService(abc.ABC):
"""注入外部持久化的邮箱供应商退避状态"""
self._provider_backoff = state or reset_adaptive_backoff()
def set_check_cancelled(self, callback: Optional[Callable[[], bool]]) -> None:
"""注入外部取消检查回调。"""
self.check_cancelled = callback if callable(callback) else None
def _is_cancelled_requested(self) -> bool:
"""检查邮箱服务是否收到取消请求。"""
callback = self.check_cancelled
if not callable(callback):
return False
try:
return bool(callback())
except Exception as e:
logger.warning(f"检查邮箱服务取消状态失败: {e}")
return False
def _raise_if_cancelled(self, message: str = "任务已取消") -> None:
"""在轮询/等待阶段响应取消请求。"""
if self._is_cancelled_requested():
raise EmailServiceCancelledError(message)
def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2) -> None:
"""可响应取消的短分片休眠。"""
remaining = max(0.0, float(seconds))
while remaining > 0:
self._raise_if_cancelled()
sleep_for = min(chunk_seconds, remaining)
time.sleep(sleep_for)
remaining -= sleep_for
@abc.abstractmethod
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
@@ -520,6 +554,7 @@ class BaseEmailService(abc.ABC):
last_email_id = None
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待邮件时任务已取消")
try:
emails = self.list_emails()
for email_info in emails:
@@ -562,7 +597,7 @@ class BaseEmailService(abc.ABC):
except Exception as e:
logger.warning(f"等待邮件时出错: {e}")
time.sleep(check_interval)
self._sleep_with_cancel(check_interval)
return None

View File

@@ -236,6 +236,7 @@ class CloudMailService(BaseEmailService):
seen_mail_ids: set = set()
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待 Cloud Mail 验证码时任务已取消")
try:
token = self._get_public_token()
mails = self._make_request(
@@ -253,7 +254,7 @@ class CloudMailService(BaseEmailService):
mails = mails["list"]
if not isinstance(mails, list):
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
if mails:
@@ -302,7 +303,7 @@ class CloudMailService(BaseEmailService):
raise
logger.debug(f"检查 Cloud Mail 邮件时出错: {e}")
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
logger.warning(f"等待 Cloud Mail 验证码超时: {email}")
return None

View File

@@ -263,6 +263,7 @@ class DuckMailService(BaseEmailService):
seen_message_ids = set()
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待 DuckMail 验证码时任务已取消")
try:
response = self._make_request(
"GET",
@@ -324,7 +325,7 @@ class DuckMailService(BaseEmailService):
raise
logger.debug(f"DuckMail 轮询验证码失败: {e}")
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
return None

View File

@@ -214,10 +214,11 @@ class FreemailService(BaseEmailService):
seen_mail_ids: set = set()
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待 Freemail 验证码时任务已取消")
try:
mails = self._make_request("GET", "/api/emails", params={"mailbox": email, "limit": 20})
if not isinstance(mails, list):
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
ordered_mails = self._sort_items_by_message_time(
@@ -299,7 +300,7 @@ class FreemailService(BaseEmailService):
raise
logger.debug(f"检查 Freemail 邮件时出错: {e}")
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
logger.warning(f"等待 Freemail 验证码超时: {email}")
return None

View File

@@ -12,7 +12,7 @@ import logging
from email.header import decode_header
from typing import Any, Dict, Optional
from .base import BaseEmailService, EmailServiceError, OTPNoOpenAISenderEmailServiceError, get_email_code_settings
from .base import BaseEmailService, OTPNoOpenAISenderEmailServiceError, get_email_code_settings
from ..config.constants import (
EmailServiceType,
OTP_CODE_SEMANTIC_PATTERN,
@@ -124,11 +124,12 @@ class ImapMailService(BaseEmailService):
mail.select("INBOX")
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待 IMAP 验证码时任务已取消")
try:
# 搜索所有未读邮件
status, data = mail.search(None, "UNSEEN")
if status != "OK" or not data or not data[0]:
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
msg_ids = data[0].split()
@@ -181,13 +182,13 @@ class ImapMailService(BaseEmailService):
except Exception:
pass
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
except Exception as e:
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
raise
logger.warning(f"IMAP 连接/轮询失败: {e}")
self.update_status(False, str(e))
self.update_status(False, e)
finally:
if mail:
try:

View File

@@ -308,13 +308,14 @@ class MeoMailEmailService(BaseEmailService):
seen_message_ids = set()
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待自定义域名邮箱验证码时任务已取消")
try:
# 获取邮件列表
response = self._make_request("GET", f"/api/emails/{target_email_id}")
messages = response.get("messages", [])
if not isinstance(messages, list):
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
ordered_messages = self._sort_items_by_message_time(
@@ -380,7 +381,7 @@ class MeoMailEmailService(BaseEmailService):
logger.debug(f"检查邮件时出错: {e}")
# 等待一段时间再检查
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
logger.warning(f"等待验证码超时: {email}")
return None

View File

@@ -337,6 +337,7 @@ class OutlookService(BaseEmailService):
poll_count = 0
while time.time() - start_time < actual_timeout:
self._raise_if_cancelled("等待 Outlook 验证码时任务已取消")
poll_count += 1
# 渐进式邮件检查:前 3 次只检查未读
@@ -387,7 +388,7 @@ class OutlookService(BaseEmailService):
logger.warning(f"[{email}] 检查出错: {e}")
# 等待下次轮询
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
elapsed = int(time.time() - start_time)
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count}")

View File

@@ -309,6 +309,7 @@ class TempMailService(BaseEmailService):
# jwt = cached.get("jwt")
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待 TempMail 验证码时任务已取消")
try:
# if jwt:
# response = self._make_request(
@@ -327,7 +328,7 @@ class TempMailService(BaseEmailService):
# /user_api/mails 和 /admin/mails 返回格式相同: {"results": [...], "total": N}
mails = response.get("results", [])
if not isinstance(mails, list):
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
ordered_mails = self._sort_items_by_message_time(
@@ -389,7 +390,7 @@ class TempMailService(BaseEmailService):
raise
logger.debug(f"检查 TempMail 邮件时出错: {e}")
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
logger.warning(f"等待 TempMail 验证码超时: {email}")
return None

View File

@@ -217,6 +217,7 @@ class TempmailService(BaseEmailService):
seen_ids = set()
while time.time() - start_time < timeout:
self._raise_if_cancelled("等待 Tempmail 验证码时任务已取消")
try:
# 获取邮件列表
response = self.http_client.get(
@@ -226,7 +227,7 @@ class TempmailService(BaseEmailService):
)
if response.status_code != 200:
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
data = response.json()
@@ -239,7 +240,7 @@ class TempmailService(BaseEmailService):
email_list = data.get("emails", []) if isinstance(data, dict) else []
if not isinstance(email_list, list):
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
continue
ordered_emails = self._sort_items_by_message_time(
@@ -302,7 +303,7 @@ class TempmailService(BaseEmailService):
logger.debug(f"检查邮件时出错: {e}")
# 等待一段时间再检查
time.sleep(poll_interval)
self._sleep_with_cancel(poll_interval)
logger.warning(f"等待验证码超时: {email}")
return None

View File

@@ -20,6 +20,7 @@ from ...database.session import get_db
from ...database.models import RegistrationTask, Proxy
from ...core.register import (
ERROR_OTP_TIMEOUT_SECONDARY,
ERROR_TASK_CANCELLED,
RegistrationEngine,
RegistrationResult,
)
@@ -395,6 +396,11 @@ def _run_registration_engine_attempt(
status_callback=status_callback,
task_uuid=task_uuid,
)
create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None)
if callable(create_cancel_callback):
setattr(engine, "check_cancelled", create_cancel_callback(task_uuid))
else:
setattr(engine, "check_cancelled", lambda: False)
try:
result = engine.run()
@@ -435,6 +441,52 @@ def _get_batch_snapshot(batch_id: str) -> Optional[dict]:
return task_manager.get_batch_status(batch_id)
def _finalize_task_cancelled(
db,
task_uuid: str,
message: str = "任务已取消",
*,
email_service: Optional[str] = None,
):
"""将任务统一收敛为已取消状态。"""
crud.update_registration_task(
db,
task_uuid,
status="cancelled",
completed_at=datetime.utcnow(),
error_message=message,
)
status_kwargs = {"error": message, "cancelled": True}
if email_service:
status_kwargs["email_service"] = email_service
task_manager.update_status(task_uuid, "cancelled", **status_kwargs)
def _request_task_cancellation(db, task_uuid: str, message: str = "任务已取消"):
"""向运行中的任务发送取消信号,并持久化取消状态。"""
task_manager.cancel_task(task_uuid)
_finalize_task_cancelled(db, task_uuid, message)
def _request_batch_cancellation(batch_id: str, message: str = "批量任务取消请求已提交"):
"""向批量任务及其成员任务传播取消信号。"""
batch = _require_batch_snapshot(batch_id)
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
task_manager.cancel_batch(batch_id)
task_manager.update_batch_status(batch_id, status="cancelled", cancelled=True)
with get_db() as db:
for task_uuid in batch.get("task_uuids", []):
task_manager.cancel_task(task_uuid)
task = crud.get_registration_task(db, task_uuid)
if task and task.status in ["pending", "running"]:
_finalize_task_cancelled(db, task_uuid, "任务已取消")
return {"success": True, "message": message}
def _require_batch_snapshot(batch_id: str) -> dict:
batch = _get_batch_snapshot(batch_id)
if batch is None:
@@ -560,8 +612,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
"""
with get_db() as db:
try:
def cancellation_requested() -> bool:
return task_manager.is_cancelled(task_uuid)
if task_manager.is_cancelled(task_uuid):
logger.info(f"任务 {task_uuid} 已取消,跳过执行")
_finalize_task_cancelled(db, task_uuid)
return
task = crud.update_registration_task(
@@ -583,6 +639,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
proxy_id = None
while True:
if cancellation_requested():
_finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value)
return
actual_proxy_url = requested_proxy
proxy_id = None
if not actual_proxy_url:
@@ -605,6 +665,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
should_retry_with_new_proxy = False
for attempt_index, candidate in enumerate(service_candidates, start=1):
if cancellation_requested():
_finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value)
return
selected_service_type = candidate["service_type"]
candidate_config = candidate["config"]
db_service = candidate.get("db_service")
@@ -630,6 +694,11 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
candidate_config,
name=db_service.name if db_service is not None else None,
)
create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None)
if callable(create_cancel_callback):
set_cancel_callback = getattr(email_service, "set_check_cancelled", None)
if callable(set_cancel_callback):
set_cancel_callback(create_cancel_callback(task_uuid))
(
engine,
result,
@@ -645,6 +714,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
status_callback=status_callback,
)
if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED:
_finalize_task_cancelled(
db,
task_uuid,
result.error_message or "任务已取消",
email_service=active_service_type.value,
)
return
if result.success:
break
@@ -707,6 +785,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
break
if result.success:
if cancellation_requested():
_finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value)
return
# 更新代理使用时间
update_proxy_usage(db, proxy_id)
@@ -836,6 +918,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
except Exception as na_err:
log_callback(f"[NEWAPI] 上传异常: {na_err}")
if cancellation_requested():
_finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value)
return
# 更新任务状态
crud.update_registration_task(
db, task_uuid,
@@ -857,6 +943,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
else:
if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED:
_finalize_task_cancelled(
db,
task_uuid,
result.error_message or "任务已取消",
email_service=active_service_type.value,
)
return
# 更新任务状态为失败
crud.update_registration_task(
db, task_uuid,
@@ -880,15 +975,18 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
try:
with get_db() as db:
crud.update_registration_task(
db, task_uuid,
status="failed",
completed_at=datetime.utcnow(),
error_message=str(e)
)
if task_manager.is_cancelled(task_uuid):
_finalize_task_cancelled(db, task_uuid)
else:
crud.update_registration_task(
db, task_uuid,
status="failed",
completed_at=datetime.utcnow(),
error_message=str(e)
)
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "failed", error=str(e))
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "failed", error=str(e))
except:
pass
@@ -904,6 +1002,11 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
if task_manager.is_cancelled(task_uuid):
with get_db() as db:
_finalize_task_cancelled(db, task_uuid)
return
# 初始化 TaskManager 状态
task_manager.update_status(task_uuid, "pending", email_service=email_service_type)
task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列")
@@ -1635,12 +1738,7 @@ async def get_batch_status(batch_id: str):
@router.post("/batch/{batch_id}/cancel")
async def cancel_batch(batch_id: str):
"""取消批量任务"""
batch = _require_batch_snapshot(batch_id)
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}
return _request_batch_cancellation(batch_id)
@router.get("/tasks", response_model=TaskListResponse)
@@ -1703,7 +1801,7 @@ async def cancel_task(task_uuid: str):
if task.status not in ["pending", "running"]:
raise HTTPException(status_code=400, detail="任务已完成或已取消")
task = crud.update_registration_task(db, task_uuid, status="cancelled")
_request_task_cancellation(db, task_uuid)
return {"success": True, "message": "任务已取消"}
@@ -2207,10 +2305,4 @@ async def get_outlook_batch_status(batch_id: str):
@router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务"""
batch = _require_batch_snapshot(batch_id)
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}
return _request_batch_cancellation(batch_id)

View File

@@ -86,7 +86,14 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
# 处理取消请求
elif data.get("type") == "cancel":
task_manager.cancel_task(task_uuid)
from . import registration as registration_routes
with get_db() as db:
task = crud.get_registration_task(db, task_uuid)
if task and task.status in ["pending", "running"]:
registration_routes._request_task_cancellation(db, task_uuid)
else:
task_manager.cancel_task(task_uuid)
await websocket.send_json({
"type": "status",
"task_uuid": task_uuid,
@@ -163,7 +170,9 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
# 处理取消请求
elif data.get("type") == "cancel":
task_manager.cancel_batch(batch_id)
from . import registration as registration_routes
registration_routes._request_batch_cancellation(batch_id)
await websocket.send_json({
"type": "status",
"batch_id": batch_id,

View File

@@ -1,9 +1,10 @@
const test = require('node:test');
const assert = require('node:assert/strict');
const fs = require('node:fs');
const path = require('node:path');
const vm = require('node:vm');
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js');
function createElementStub() {
return {

View File

@@ -0,0 +1,363 @@
import asyncio
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
from src.core.register import ERROR_TASK_CANCELLED, RegistrationResult
from src.database.models import Base, RegistrationTask
from src.database.session import DatabaseSessionManager
from src.services import EmailServiceType
from src.services.base import BaseEmailService, EmailServiceCancelledError
from src.web.routes import registration as registration_routes
class FakeTaskManager:
def __init__(self):
self.cancelled = set()
self.status_updates = []
self.logs = {}
self.batch_status = {}
def is_cancelled(self, task_uuid):
return task_uuid in self.cancelled
def cancel_task(self, task_uuid):
self.cancelled.add(task_uuid)
def update_status(self, task_uuid, status, **kwargs):
self.status_updates.append((task_uuid, status, kwargs))
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
def callback(message):
full_message = f"{prefix} {message}" if prefix else message
self.logs.setdefault(task_uuid, []).append(full_message)
return callback
def create_check_cancelled_callback(self, task_uuid):
return lambda: self.is_cancelled(task_uuid)
def get_batch_status(self, batch_id):
snapshot = self.batch_status.get(batch_id)
return dict(snapshot) if snapshot else None
def cancel_batch(self, batch_id):
snapshot = self.batch_status.setdefault(batch_id, {})
snapshot["cancelled"] = True
snapshot["status"] = "cancelling"
def update_batch_status(self, batch_id, **kwargs):
snapshot = self.batch_status.setdefault(batch_id, {})
snapshot.update(kwargs)
class DummySettings:
proxy_dynamic_enabled = False
proxy_dynamic_api_url = ""
email_code_timeout = 10
email_code_poll_interval = 1
email_code_resend_max_retries = 0
email_code_non_openai_sender_resend_max_retries = 0
openai_client_id = "client-id"
openai_auth_url = "https://auth.example.test"
openai_token_url = "https://token.example.test"
openai_redirect_uri = "https://callback.example.test"
openai_scope = "openid profile email"
def get_proxy_url(self):
return None
def _build_fake_get_db(manager):
@contextmanager
def fake_get_db():
with manager.session_scope() as session:
yield session
return fake_get_db
class FakeRegistrationEngine:
started_event = None
def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
self.check_cancelled = None
self.callback_logger = callback_logger or (lambda _msg: None)
self.task_uuid = task_uuid
def run(self):
if self.started_event is not None:
self.started_event.set()
while True:
if callable(self.check_cancelled) and self.check_cancelled():
return RegistrationResult(
success=False,
error_message="任务已取消",
error_code=ERROR_TASK_CANCELLED,
logs=[],
)
time.sleep(0.01)
def save_to_database(self, result):
return False
def close(self):
return None
class FakePollingEmailService(BaseEmailService):
def __init__(self, started_event=None):
super().__init__(EmailServiceType.TEMPMAIL, "fake-polling-email")
self.started_event = started_event
def create_email(self, config=None):
return {"email": "poll@example.test", "service_id": "poll-service"}
def get_verification_code(self, email: str, email_id: str = None, timeout: int = 120, pattern: str = r"(?<!\d)(\d{6})(?!\d)", otp_sent_at=None):
if self.started_event is not None:
self.started_event.set()
while True:
self._raise_if_cancelled("任务已取消")
self._sleep_with_cancel(0.05)
def list_emails(self, **kwargs):
return []
def delete_email(self, email_id: str):
return True
def check_health(self):
return True
class FakeEmailPollingRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
self.check_cancelled = None
self.callback_logger = callback_logger or (lambda _msg: None)
self.task_uuid = task_uuid
def run(self):
try:
self.email_service.get_verification_code("poll@example.test", timeout=60)
except EmailServiceCancelledError as exc:
return RegistrationResult(
success=False,
error_message=str(exc),
error_code=ERROR_TASK_CANCELLED,
logs=[],
)
return RegistrationResult(success=False, error_message="邮箱轮询未被取消", logs=[])
def save_to_database(self, result):
return False
def close(self):
return None
def test_cancel_task_route_marks_task_manager_and_db_cancelled(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_cancel_route.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-cancel-route"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="running"))
fake_task_manager = FakeTaskManager()
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
response = asyncio.run(registration_routes.cancel_task(task_uuid))
assert response == {"success": True, "message": "任务已取消"}
assert task_uuid in fake_task_manager.cancelled
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
assert task.status == "cancelled"
assert task.error_message == "任务已取消"
def test_run_sync_registration_task_stops_after_cancel_request(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_cancel_runtime.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-cancel-runtime"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
fake_task_manager = FakeTaskManager()
start_event = threading.Event()
FakeRegistrationEngine.started_event = start_event
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes,
"_build_email_service_candidates",
lambda db, service_type, actual_proxy_url, email_service_id, email_service_config: [
{
"service_type": EmailServiceType.TEMPMAIL,
"config": {"proxy_url": actual_proxy_url},
"db_service": None,
}
],
)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: SimpleNamespace(
service_type=service_type,
name=name or service_type.value,
config=config,
),
)
worker = threading.Thread(
target=registration_routes._run_sync_registration_task,
kwargs={
"task_uuid": task_uuid,
"email_service_type": EmailServiceType.TEMPMAIL.value,
"proxy": None,
"email_service_config": {},
},
)
worker.start()
assert start_event.wait(timeout=1.0), "registration engine did not start in time"
response = asyncio.run(registration_routes.cancel_task(task_uuid))
assert response == {"success": True, "message": "任务已取消"}
worker.join(timeout=2.0)
assert not worker.is_alive(), "registration worker should stop after cancellation"
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
assert task.status == "cancelled"
assert task.error_message == "任务已取消"
statuses = [status for current_uuid, status, _kwargs in fake_task_manager.status_updates if current_uuid == task_uuid]
assert "cancelled" in statuses
def test_cancel_batch_propagates_to_member_tasks(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_cancel_batch.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuids = ["batch-cancel-1", "batch-cancel-2"]
with manager.session_scope() as session:
session.add_all([
RegistrationTask(task_uuid=task_uuids[0], status="running"),
RegistrationTask(task_uuid=task_uuids[1], status="pending"),
])
fake_task_manager = FakeTaskManager()
fake_task_manager.batch_status["batch-1"] = {
"finished": False,
"cancelled": False,
"task_uuids": task_uuids,
}
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
response = asyncio.run(registration_routes.cancel_batch("batch-1"))
assert response["success"] is True
assert fake_task_manager.batch_status["batch-1"]["cancelled"] is True
assert fake_task_manager.cancelled == set(task_uuids)
with manager.session_scope() as session:
tasks = session.query(RegistrationTask).order_by(RegistrationTask.task_uuid.asc()).all()
assert [task.status for task in tasks] == ["cancelled", "cancelled"]
def test_run_sync_registration_task_stops_while_email_service_polling(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_cancel_email_polling.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-cancel-email-polling"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
fake_task_manager = FakeTaskManager()
start_event = threading.Event()
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeEmailPollingRegistrationEngine)
monkeypatch.setattr(
registration_routes,
"_build_email_service_candidates",
lambda db, service_type, actual_proxy_url, email_service_id, email_service_config: [
{
"service_type": EmailServiceType.TEMPMAIL,
"config": {"proxy_url": actual_proxy_url},
"db_service": None,
}
],
)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: FakePollingEmailService(start_event),
)
worker = threading.Thread(
target=registration_routes._run_sync_registration_task,
kwargs={
"task_uuid": task_uuid,
"email_service_type": EmailServiceType.TEMPMAIL.value,
"proxy": None,
"email_service_config": {},
},
)
worker.start()
assert start_event.wait(timeout=1.0), "email polling did not start in time"
response = asyncio.run(registration_routes.cancel_task(task_uuid))
assert response == {"success": True, "message": "任务已取消"}
worker.join(timeout=2.0)
assert not worker.is_alive(), "registration worker should stop while email service is polling"
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
assert task.status == "cancelled"
assert task.error_message == "任务已取消"

View File

@@ -1,9 +1,10 @@
const test = require('node:test');
const assert = require('node:assert/strict');
const fs = require('node:fs');
const path = require('node:path');
const vm = require('node:vm');
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js');
function createElementStub() {
return {