From 744b1f4c1b49ac92c8af87fceb67968f9a723567 Mon Sep 17 00:00:00 2001 From: cnlimiter Date: Sat, 28 Mar 2026 14:11:00 +0800 Subject: [PATCH] feat(email): implement cancellation handling in email services --- src/core/register.py | 183 +++++++--- src/services/base.py | 39 +- src/services/cloud_mail.py | 5 +- src/services/duck_mail.py | 3 +- src/services/freemail.py | 5 +- src/services/imap_mail.py | 9 +- src/services/moe_mail.py | 5 +- src/services/outlook/service.py | 3 +- src/services/temp_mail.py | 5 +- src/services/tempmail.py | 7 +- src/web/routes/registration.py | 136 +++++-- src/web/routes/websocket.py | 13 +- tests/test_batch_websocket_fallback.cjs | 3 +- tests/test_registration_task_cancellation.py | 363 +++++++++++++++++++ tests/test_single_task_websocket_status.cjs | 3 +- 15 files changed, 681 insertions(+), 101 deletions(-) create mode 100644 tests/test_registration_task_cancellation.py diff --git a/src/core/register.py b/src/core/register.py index cd8b27e..69e3f34 100644 --- a/src/core/register.py +++ b/src/core/register.py @@ -31,6 +31,7 @@ from ..database import crud from ..database.session import get_db from ..services import BaseEmailService from ..services.base import ( + EmailServiceCancelledError, EmailProviderBackoffState, OTP_NO_OPENAI_SENDER_ERROR, OTPNoOpenAISenderEmailServiceError, @@ -43,6 +44,15 @@ PHASE_EMAIL_PREPARE = "email_prepare" PHASE_OTP_SECONDARY = "otp_secondary" ERROR_EMAIL_PROVIDER_RATE_LIMITED = "EMAIL_PROVIDER_RATE_LIMITED" ERROR_OTP_TIMEOUT_SECONDARY = "OTP_TIMEOUT_SECONDARY" +ERROR_TASK_CANCELLED = "TASK_CANCELLED" + + +class TaskCancelledError(Exception): + """注册任务被主动取消。""" + + def __init__(self, message: str = "任务已取消"): + super().__init__(message) + self.error_code = ERROR_TASK_CANCELLED @dataclass @@ -185,6 +195,8 @@ class RegistrationEngine: self._otp_sent_at: Optional[float] = None # OTP 发送时间戳 self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录) self.phase_history: list[PhaseResult] = [] + self.check_cancelled: Optional[Callable[[], bool]] = None + self._cancel_logged = False def _log(self, message: str, level: str = "info"): """记录日志""" @@ -274,6 +286,35 @@ class RegistrationEngine: self.phase_history.append(phase_result) return phase_result + def _is_cancelled_requested(self) -> bool: + """检查是否收到外部取消信号。""" + callback = self.check_cancelled + if not callable(callback): + return False + try: + return bool(callback()) + except Exception as e: + logger.warning(f"检查任务取消状态失败: {e}") + return False + + def _raise_if_cancelled(self, message: str = "任务已取消"): + """在可协作阶段检查取消请求。""" + if not self._is_cancelled_requested(): + return + if not self._cancel_logged: + self._log(message, "warning") + self._cancel_logged = True + raise TaskCancelledError(message) + + def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2): + """可响应取消的短分片休眠。""" + remaining = max(0.0, float(seconds)) + while remaining > 0: + self._raise_if_cancelled() + sleep_for = min(chunk_seconds, remaining) + time.sleep(sleep_for) + remaining -= sleep_for + def _get_phase_result(self, phase_name: str) -> Optional[PhaseResult]: for phase_result in reversed(self.phase_history): if phase_result.phase == phase_name: @@ -427,7 +468,7 @@ class RegistrationEngine: ) if attempt < max_attempts: - time.sleep(attempt) + self._sleep_with_cancel(attempt) self.http_client.close() self.session = self.http_client.session @@ -656,6 +697,7 @@ class RegistrationEngine: otp_phase: Optional[PhaseResult] = None while True: + self._raise_if_cancelled("等待验证码重试时任务已取消") code, otp_phase = self._phase_otp_secondary( PhaseContext(otp_sent_at=self._otp_sent_at), started_at=otp_phase_started_at, @@ -698,6 +740,7 @@ class RegistrationEngine: **emit_kwargs, ) + self._raise_if_cancelled("等待验证码重试时任务已取消") if not resend_callback(): self._log("重新发送验证码失败,跳过本次重试", "warning") @@ -712,76 +755,79 @@ class RegistrationEngine: ) -> Tuple[Optional[str], PhaseResult]: """等待二次验证码邮件并做超时归因。""" try: + self._raise_if_cancelled("等待验证码时任务已取消") self._log(f"正在等待邮箱 {self.email} 的验证码...") email_id = self.email_info.get("service_id") if self.email_info else None + settings = get_settings() budget = Budget( - timeout_seconds=get_settings().email_code_timeout, + timeout_seconds=settings.email_code_timeout, started_at=started_at if started_at is not None else time.time(), ) - remaining_timeout = budget.remaining_seconds() + poll_interval = max(1, int(settings.email_code_poll_interval or 1)) - if remaining_timeout <= 0: + while True: + self._raise_if_cancelled("等待验证码时任务已取消") + remaining_timeout = budget.remaining_seconds() + + if remaining_timeout <= 0: + phase_result = self._record_phase_result( + PhaseResult( + phase=PHASE_OTP_SECONDARY, + success=False, + error_message="等待验证码超时", + error_code=ERROR_OTP_TIMEOUT_SECONDARY, + retryable=True, + next_action="await_email", + metadata={ + "budget_started_at": budget.started_at, + "budget_timeout_seconds": budget.timeout_seconds, + "otp_sent_at": context.otp_sent_at, + }, + ) + ) + self._log(phase_result.error_message, "error") + return None, phase_result + + attempt_timeout = max(1, min(remaining_timeout, poll_interval)) + code = self.email_service.get_verification_code( + email=self.email, + email_id=email_id, + timeout=attempt_timeout, + pattern=OTP_CODE_PATTERN, + otp_sent_at=context.otp_sent_at, + ) + + if code: + self._log(f"成功获取验证码: {code}") + phase_result = self._record_phase_result( + PhaseResult( + phase=PHASE_OTP_SECONDARY, + success=True, + metadata={ + "budget_started_at": budget.started_at, + "budget_timeout_seconds": budget.timeout_seconds, + "otp_sent_at": context.otp_sent_at, + }, + ) + ) + return code, phase_result + + except Exception as e: + if isinstance(e, (TaskCancelledError, EmailServiceCancelledError)): phase_result = self._record_phase_result( PhaseResult( phase=PHASE_OTP_SECONDARY, success=False, - error_message="等待验证码超时", - error_code=ERROR_OTP_TIMEOUT_SECONDARY, - retryable=True, - next_action="await_email", - metadata={ - "budget_started_at": budget.started_at, - "budget_timeout_seconds": budget.timeout_seconds, - "otp_sent_at": context.otp_sent_at, - }, + error_message=str(e), + error_code=getattr(e, "error_code", ERROR_TASK_CANCELLED), + retryable=False, + next_action="cancelled", + metadata={"otp_sent_at": context.otp_sent_at}, ) ) - self._log(phase_result.error_message, "error") return None, phase_result - code = self.email_service.get_verification_code( - email=self.email, - email_id=email_id, - timeout=remaining_timeout, - pattern=OTP_CODE_PATTERN, - otp_sent_at=context.otp_sent_at, - ) - - if code: - self._log(f"成功获取验证码: {code}") - phase_result = self._record_phase_result( - PhaseResult( - phase=PHASE_OTP_SECONDARY, - success=True, - metadata={ - "budget_started_at": budget.started_at, - "budget_timeout_seconds": budget.timeout_seconds, - "otp_sent_at": context.otp_sent_at, - }, - ) - ) - return code, phase_result - - phase_result = self._record_phase_result( - PhaseResult( - phase=PHASE_OTP_SECONDARY, - success=False, - error_message="等待验证码超时", - error_code=ERROR_OTP_TIMEOUT_SECONDARY, - retryable=True, - next_action="await_email", - metadata={ - "budget_started_at": budget.started_at, - "budget_timeout_seconds": budget.timeout_seconds, - "otp_sent_at": context.otp_sent_at, - }, - ) - ) - self._log(phase_result.error_message, "error") - return None, phase_result - - except Exception as e: if isinstance(e, OTPNoOpenAISenderEmailServiceError): self._log(str(e), "warning") phase_result = self._record_phase_result( @@ -1404,6 +1450,8 @@ class RegistrationEngine: non_openai_retry_status_template="重新触发登录验证码(非 OpenAI 发件人,第 {attempt} 次)", ) if not code: + if otp_phase and otp_phase.error_code == ERROR_TASK_CANCELLED: + raise TaskCancelledError(otp_phase.error_message or "登录流程已取消") self._log( otp_phase.error_message if otp_phase and otp_phase.error_message else "登录流程获取验证码失败", "warning", @@ -1539,11 +1587,13 @@ class RegistrationEngine: result = RegistrationResult(success=False, logs=self.logs) try: + self._raise_if_cancelled() self._log("=" * 60) self._log("开始注册流程") self._log("=" * 60) # 1. 检查 IP 地理位置 + self._raise_if_cancelled() self._log("1. 检查 IP 地理位置...") self._emit_status("ip_check", "检查 IP 地理位置", step_index=1) ip_ok, location = self._check_ip_location() @@ -1555,6 +1605,7 @@ class RegistrationEngine: self._log(f"IP 位置: {location}") # 2. 创建邮箱 + self._raise_if_cancelled() self._log("2. 创建邮箱...") self._emit_status("email_prepare", "创建邮箱地址", step_index=2) if not self._phase_email_prepare(): @@ -1570,6 +1621,7 @@ class RegistrationEngine: result.email = self.email # 3. 初始化会话 + self._raise_if_cancelled() self._log("3. 初始化会话...") self._emit_status("session_init", "初始化 HTTP 会话", step_index=3) if not self._init_session(): @@ -1577,6 +1629,7 @@ class RegistrationEngine: return result # 4. 开始 OAuth 流程 + self._raise_if_cancelled() self._log("4. 开始 OAuth 授权流程...") self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4) if not self._start_oauth(): @@ -1584,6 +1637,7 @@ class RegistrationEngine: return result # 5. 获取 Device ID + self._raise_if_cancelled() self._log("5. 获取 Device ID...") self._emit_status("oauth_device_id", "获取 Device ID", step_index=5) did = self._get_device_id() @@ -1592,6 +1646,7 @@ class RegistrationEngine: return result # 6. 检查 Sentinel 拦截 + self._raise_if_cancelled() self._log("6. 检查 Sentinel 拦截...") self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6) sen_token = self._check_sentinel(did) @@ -1601,6 +1656,7 @@ class RegistrationEngine: self._log("Sentinel 检查失败或未启用", "warning") # 7. 提交注册表单 + 解析响应判断账号状态 + self._raise_if_cancelled() self._log("7. 提交注册表单...") self._emit_status("signup_submit", "提交注册表单", step_index=7) signup_result = self._submit_signup_form(did, sen_token) @@ -1612,6 +1668,7 @@ class RegistrationEngine: if self._is_existing_account: self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送") else: + self._raise_if_cancelled() self._log("8. 注册密码...") self._emit_status("signup_password", "提交注册密码", step_index=8) password_ok, password = self._register_password() @@ -1625,6 +1682,7 @@ class RegistrationEngine: # 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳 self._otp_sent_at = time.time() else: + self._raise_if_cancelled() self._log("9. 发送验证码...") self._emit_status("otp_send", "发送验证码", step_index=9) if not self._send_verification_code(): @@ -1632,6 +1690,7 @@ class RegistrationEngine: return result # 10. 获取验证码(支持重发重试) + self._raise_if_cancelled() self._log("10. 等待验证码...") self._emit_status("otp_secondary", "等待验证码邮件", step_index=10) code, otp_phase = self._await_verification_code_with_resends( @@ -1650,6 +1709,7 @@ class RegistrationEngine: return result # 11. 验证验证码 + self._raise_if_cancelled() self._log("11. 验证验证码...") self._emit_status("otp_validate", "校验验证码", step_index=11) if not self._validate_verification_code(code): @@ -1660,6 +1720,7 @@ class RegistrationEngine: if self._is_existing_account: self._log("12. [已注册账号] 跳过创建用户账户") else: + self._raise_if_cancelled() self._log("12. 创建用户账户...") self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12) if not self._create_user_account(): @@ -1670,6 +1731,7 @@ class RegistrationEngine: callback_url = None if not self._is_existing_account: + self._raise_if_cancelled() self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...") self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step) workspace_id, callback_url = self._advance_login_authorization() @@ -1679,6 +1741,7 @@ class RegistrationEngine: if not result.workspace_id: # 获取 Workspace ID + self._raise_if_cancelled() self._log(f"{next_step}. 获取 Workspace ID...") self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step) workspace_id = self._get_workspace_id() @@ -1691,6 +1754,7 @@ class RegistrationEngine: next_step += 1 # 选择 Workspace + self._raise_if_cancelled() self._log(f"{next_step}. 选择 Workspace...") self._emit_status("workspace_select", "选择 Workspace", step_index=next_step) continue_url = self._select_workspace(result.workspace_id) @@ -1701,6 +1765,7 @@ class RegistrationEngine: next_step += 1 # 跟随重定向链 + self._raise_if_cancelled() self._log(f"{next_step}. 跟随重定向链...") self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step) callback_url = self._follow_redirects(continue_url) @@ -1711,6 +1776,7 @@ class RegistrationEngine: next_step += 1 # 处理 OAuth 回调 + self._raise_if_cancelled() self._log(f"{next_step}. 处理 OAuth 回调...") self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step) token_info = self._handle_oauth_callback(callback_url) @@ -1757,6 +1823,11 @@ class RegistrationEngine: return result + except TaskCancelledError as e: + result.error_message = str(e) + result.error_code = getattr(e, "error_code", ERROR_TASK_CANCELLED) + return result + except Exception as e: self._log(f"注册过程中发生未预期错误: {e}", "error") result.error_message = str(e) diff --git a/src/services/base.py b/src/services/base.py index ca58866..bba62a0 100644 --- a/src/services/base.py +++ b/src/services/base.py @@ -9,7 +9,7 @@ import re import time from dataclasses import dataclass from datetime import datetime -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Callable from enum import Enum from ..config.constants import EmailServiceType, OPENAI_EMAIL_SENDERS, OTP_CODE_PATTERN, OTP_CODE_SEMANTIC_PATTERN @@ -139,6 +139,10 @@ class OTPNoOpenAISenderEmailServiceError(EmailServiceError): self.error_code = error_code +class EmailServiceCancelledError(EmailServiceError): + """邮箱服务在轮询过程中收到取消信号。""" + + class EmailServiceStatus(Enum): """邮箱服务状态""" HEALTHY = "healthy" @@ -168,6 +172,7 @@ class BaseEmailService(abc.ABC): self._provider_backoff = reset_adaptive_backoff() self._used_verification_codes: Dict[str, set] = {} self._seen_verification_messages: Dict[str, set] = {} + self.check_cancelled: Optional[Callable[[], bool]] = None _EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}") @@ -190,6 +195,35 @@ class BaseEmailService(abc.ABC): """注入外部持久化的邮箱供应商退避状态""" self._provider_backoff = state or reset_adaptive_backoff() + def set_check_cancelled(self, callback: Optional[Callable[[], bool]]) -> None: + """注入外部取消检查回调。""" + self.check_cancelled = callback if callable(callback) else None + + def _is_cancelled_requested(self) -> bool: + """检查邮箱服务是否收到取消请求。""" + callback = self.check_cancelled + if not callable(callback): + return False + try: + return bool(callback()) + except Exception as e: + logger.warning(f"检查邮箱服务取消状态失败: {e}") + return False + + def _raise_if_cancelled(self, message: str = "任务已取消") -> None: + """在轮询/等待阶段响应取消请求。""" + if self._is_cancelled_requested(): + raise EmailServiceCancelledError(message) + + def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2) -> None: + """可响应取消的短分片休眠。""" + remaining = max(0.0, float(seconds)) + while remaining > 0: + self._raise_if_cancelled() + sleep_for = min(chunk_seconds, remaining) + time.sleep(sleep_for) + remaining -= sleep_for + @abc.abstractmethod def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]: """ @@ -520,6 +554,7 @@ class BaseEmailService(abc.ABC): last_email_id = None while time.time() - start_time < timeout: + self._raise_if_cancelled("等待邮件时任务已取消") try: emails = self.list_emails() for email_info in emails: @@ -562,7 +597,7 @@ class BaseEmailService(abc.ABC): except Exception as e: logger.warning(f"等待邮件时出错: {e}") - time.sleep(check_interval) + self._sleep_with_cancel(check_interval) return None diff --git a/src/services/cloud_mail.py b/src/services/cloud_mail.py index 9e595b2..2badc68 100644 --- a/src/services/cloud_mail.py +++ b/src/services/cloud_mail.py @@ -236,6 +236,7 @@ class CloudMailService(BaseEmailService): seen_mail_ids: set = set() while time.time() - start_time < timeout: + self._raise_if_cancelled("等待 Cloud Mail 验证码时任务已取消") try: token = self._get_public_token() mails = self._make_request( @@ -253,7 +254,7 @@ class CloudMailService(BaseEmailService): mails = mails["list"] if not isinstance(mails, list): - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue if mails: @@ -302,7 +303,7 @@ class CloudMailService(BaseEmailService): raise logger.debug(f"检查 Cloud Mail 邮件时出错: {e}") - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) logger.warning(f"等待 Cloud Mail 验证码超时: {email}") return None diff --git a/src/services/duck_mail.py b/src/services/duck_mail.py index a28e853..93a3e3d 100644 --- a/src/services/duck_mail.py +++ b/src/services/duck_mail.py @@ -263,6 +263,7 @@ class DuckMailService(BaseEmailService): seen_message_ids = set() while time.time() - start_time < timeout: + self._raise_if_cancelled("等待 DuckMail 验证码时任务已取消") try: response = self._make_request( "GET", @@ -324,7 +325,7 @@ class DuckMailService(BaseEmailService): raise logger.debug(f"DuckMail 轮询验证码失败: {e}") - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) return None diff --git a/src/services/freemail.py b/src/services/freemail.py index 5d7e204..f3de90a 100644 --- a/src/services/freemail.py +++ b/src/services/freemail.py @@ -214,10 +214,11 @@ class FreemailService(BaseEmailService): seen_mail_ids: set = set() while time.time() - start_time < timeout: + self._raise_if_cancelled("等待 Freemail 验证码时任务已取消") try: mails = self._make_request("GET", "/api/emails", params={"mailbox": email, "limit": 20}) if not isinstance(mails, list): - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue ordered_mails = self._sort_items_by_message_time( @@ -299,7 +300,7 @@ class FreemailService(BaseEmailService): raise logger.debug(f"检查 Freemail 邮件时出错: {e}") - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) logger.warning(f"等待 Freemail 验证码超时: {email}") return None diff --git a/src/services/imap_mail.py b/src/services/imap_mail.py index c65b3da..c38a5d3 100644 --- a/src/services/imap_mail.py +++ b/src/services/imap_mail.py @@ -12,7 +12,7 @@ import logging from email.header import decode_header from typing import Any, Dict, Optional -from .base import BaseEmailService, EmailServiceError, OTPNoOpenAISenderEmailServiceError, get_email_code_settings +from .base import BaseEmailService, OTPNoOpenAISenderEmailServiceError, get_email_code_settings from ..config.constants import ( EmailServiceType, OTP_CODE_SEMANTIC_PATTERN, @@ -124,11 +124,12 @@ class ImapMailService(BaseEmailService): mail.select("INBOX") while time.time() - start_time < timeout: + self._raise_if_cancelled("等待 IMAP 验证码时任务已取消") try: # 搜索所有未读邮件 status, data = mail.search(None, "UNSEEN") if status != "OK" or not data or not data[0]: - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue msg_ids = data[0].split() @@ -181,13 +182,13 @@ class ImapMailService(BaseEmailService): except Exception: pass - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) except Exception as e: if isinstance(e, OTPNoOpenAISenderEmailServiceError): raise logger.warning(f"IMAP 连接/轮询失败: {e}") - self.update_status(False, str(e)) + self.update_status(False, e) finally: if mail: try: diff --git a/src/services/moe_mail.py b/src/services/moe_mail.py index e4eaa97..2ce6735 100644 --- a/src/services/moe_mail.py +++ b/src/services/moe_mail.py @@ -308,13 +308,14 @@ class MeoMailEmailService(BaseEmailService): seen_message_ids = set() while time.time() - start_time < timeout: + self._raise_if_cancelled("等待自定义域名邮箱验证码时任务已取消") try: # 获取邮件列表 response = self._make_request("GET", f"/api/emails/{target_email_id}") messages = response.get("messages", []) if not isinstance(messages, list): - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue ordered_messages = self._sort_items_by_message_time( @@ -380,7 +381,7 @@ class MeoMailEmailService(BaseEmailService): logger.debug(f"检查邮件时出错: {e}") # 等待一段时间再检查 - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) logger.warning(f"等待验证码超时: {email}") return None diff --git a/src/services/outlook/service.py b/src/services/outlook/service.py index f2141e4..4aa2d28 100644 --- a/src/services/outlook/service.py +++ b/src/services/outlook/service.py @@ -337,6 +337,7 @@ class OutlookService(BaseEmailService): poll_count = 0 while time.time() - start_time < actual_timeout: + self._raise_if_cancelled("等待 Outlook 验证码时任务已取消") poll_count += 1 # 渐进式邮件检查:前 3 次只检查未读 @@ -387,7 +388,7 @@ class OutlookService(BaseEmailService): logger.warning(f"[{email}] 检查出错: {e}") # 等待下次轮询 - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) elapsed = int(time.time() - start_time) logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次") diff --git a/src/services/temp_mail.py b/src/services/temp_mail.py index e8660cc..09bd31c 100644 --- a/src/services/temp_mail.py +++ b/src/services/temp_mail.py @@ -309,6 +309,7 @@ class TempMailService(BaseEmailService): # jwt = cached.get("jwt") while time.time() - start_time < timeout: + self._raise_if_cancelled("等待 TempMail 验证码时任务已取消") try: # if jwt: # response = self._make_request( @@ -327,7 +328,7 @@ class TempMailService(BaseEmailService): # /user_api/mails 和 /admin/mails 返回格式相同: {"results": [...], "total": N} mails = response.get("results", []) if not isinstance(mails, list): - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue ordered_mails = self._sort_items_by_message_time( @@ -389,7 +390,7 @@ class TempMailService(BaseEmailService): raise logger.debug(f"检查 TempMail 邮件时出错: {e}") - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) logger.warning(f"等待 TempMail 验证码超时: {email}") return None diff --git a/src/services/tempmail.py b/src/services/tempmail.py index ee5884c..49390a0 100644 --- a/src/services/tempmail.py +++ b/src/services/tempmail.py @@ -217,6 +217,7 @@ class TempmailService(BaseEmailService): seen_ids = set() while time.time() - start_time < timeout: + self._raise_if_cancelled("等待 Tempmail 验证码时任务已取消") try: # 获取邮件列表 response = self.http_client.get( @@ -226,7 +227,7 @@ class TempmailService(BaseEmailService): ) if response.status_code != 200: - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue data = response.json() @@ -239,7 +240,7 @@ class TempmailService(BaseEmailService): email_list = data.get("emails", []) if isinstance(data, dict) else [] if not isinstance(email_list, list): - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) continue ordered_emails = self._sort_items_by_message_time( @@ -302,7 +303,7 @@ class TempmailService(BaseEmailService): logger.debug(f"检查邮件时出错: {e}") # 等待一段时间再检查 - time.sleep(poll_interval) + self._sleep_with_cancel(poll_interval) logger.warning(f"等待验证码超时: {email}") return None diff --git a/src/web/routes/registration.py b/src/web/routes/registration.py index 4d6136b..5a8069d 100644 --- a/src/web/routes/registration.py +++ b/src/web/routes/registration.py @@ -20,6 +20,7 @@ from ...database.session import get_db from ...database.models import RegistrationTask, Proxy from ...core.register import ( ERROR_OTP_TIMEOUT_SECONDARY, + ERROR_TASK_CANCELLED, RegistrationEngine, RegistrationResult, ) @@ -395,6 +396,11 @@ def _run_registration_engine_attempt( status_callback=status_callback, task_uuid=task_uuid, ) + create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None) + if callable(create_cancel_callback): + setattr(engine, "check_cancelled", create_cancel_callback(task_uuid)) + else: + setattr(engine, "check_cancelled", lambda: False) try: result = engine.run() @@ -435,6 +441,52 @@ def _get_batch_snapshot(batch_id: str) -> Optional[dict]: return task_manager.get_batch_status(batch_id) +def _finalize_task_cancelled( + db, + task_uuid: str, + message: str = "任务已取消", + *, + email_service: Optional[str] = None, +): + """将任务统一收敛为已取消状态。""" + crud.update_registration_task( + db, + task_uuid, + status="cancelled", + completed_at=datetime.utcnow(), + error_message=message, + ) + status_kwargs = {"error": message, "cancelled": True} + if email_service: + status_kwargs["email_service"] = email_service + task_manager.update_status(task_uuid, "cancelled", **status_kwargs) + + +def _request_task_cancellation(db, task_uuid: str, message: str = "任务已取消"): + """向运行中的任务发送取消信号,并持久化取消状态。""" + task_manager.cancel_task(task_uuid) + _finalize_task_cancelled(db, task_uuid, message) + + +def _request_batch_cancellation(batch_id: str, message: str = "批量任务取消请求已提交"): + """向批量任务及其成员任务传播取消信号。""" + batch = _require_batch_snapshot(batch_id) + if batch.get("finished"): + raise HTTPException(status_code=400, detail="批量任务已完成") + + task_manager.cancel_batch(batch_id) + task_manager.update_batch_status(batch_id, status="cancelled", cancelled=True) + + with get_db() as db: + for task_uuid in batch.get("task_uuids", []): + task_manager.cancel_task(task_uuid) + task = crud.get_registration_task(db, task_uuid) + if task and task.status in ["pending", "running"]: + _finalize_task_cancelled(db, task_uuid, "任务已取消") + + return {"success": True, "message": message} + + def _require_batch_snapshot(batch_id: str) -> dict: batch = _get_batch_snapshot(batch_id) if batch is None: @@ -560,8 +612,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: """ with get_db() as db: try: + def cancellation_requested() -> bool: + return task_manager.is_cancelled(task_uuid) + if task_manager.is_cancelled(task_uuid): logger.info(f"任务 {task_uuid} 已取消,跳过执行") + _finalize_task_cancelled(db, task_uuid) return task = crud.update_registration_task( @@ -583,6 +639,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: proxy_id = None while True: + if cancellation_requested(): + _finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value) + return + actual_proxy_url = requested_proxy proxy_id = None if not actual_proxy_url: @@ -605,6 +665,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: should_retry_with_new_proxy = False for attempt_index, candidate in enumerate(service_candidates, start=1): + if cancellation_requested(): + _finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value) + return + selected_service_type = candidate["service_type"] candidate_config = candidate["config"] db_service = candidate.get("db_service") @@ -630,6 +694,11 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: candidate_config, name=db_service.name if db_service is not None else None, ) + create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None) + if callable(create_cancel_callback): + set_cancel_callback = getattr(email_service, "set_check_cancelled", None) + if callable(set_cancel_callback): + set_cancel_callback(create_cancel_callback(task_uuid)) ( engine, result, @@ -645,6 +714,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: status_callback=status_callback, ) + if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED: + _finalize_task_cancelled( + db, + task_uuid, + result.error_message or "任务已取消", + email_service=active_service_type.value, + ) + return + if result.success: break @@ -707,6 +785,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: break if result.success: + if cancellation_requested(): + _finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value) + return + # 更新代理使用时间 update_proxy_usage(db, proxy_id) @@ -836,6 +918,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: except Exception as na_err: log_callback(f"[NEWAPI] 上传异常: {na_err}") + if cancellation_requested(): + _finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value) + return + # 更新任务状态 crud.update_registration_task( db, task_uuid, @@ -857,6 +943,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}") else: + if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED: + _finalize_task_cancelled( + db, + task_uuid, + result.error_message or "任务已取消", + email_service=active_service_type.value, + ) + return + # 更新任务状态为失败 crud.update_registration_task( db, task_uuid, @@ -880,15 +975,18 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: try: with get_db() as db: - crud.update_registration_task( - db, task_uuid, - status="failed", - completed_at=datetime.utcnow(), - error_message=str(e) - ) + if task_manager.is_cancelled(task_uuid): + _finalize_task_cancelled(db, task_uuid) + else: + crud.update_registration_task( + db, task_uuid, + status="failed", + completed_at=datetime.utcnow(), + error_message=str(e) + ) - # 更新 TaskManager 状态 - task_manager.update_status(task_uuid, "failed", error=str(e)) + # 更新 TaskManager 状态 + task_manager.update_status(task_uuid, "failed", error=str(e)) except: pass @@ -904,6 +1002,11 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: loop = asyncio.get_event_loop() task_manager.set_loop(loop) + if task_manager.is_cancelled(task_uuid): + with get_db() as db: + _finalize_task_cancelled(db, task_uuid) + return + # 初始化 TaskManager 状态 task_manager.update_status(task_uuid, "pending", email_service=email_service_type) task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列") @@ -1635,12 +1738,7 @@ async def get_batch_status(batch_id: str): @router.post("/batch/{batch_id}/cancel") async def cancel_batch(batch_id: str): """取消批量任务""" - batch = _require_batch_snapshot(batch_id) - if batch.get("finished"): - raise HTTPException(status_code=400, detail="批量任务已完成") - - task_manager.cancel_batch(batch_id) - return {"success": True, "message": "批量任务取消请求已提交"} + return _request_batch_cancellation(batch_id) @router.get("/tasks", response_model=TaskListResponse) @@ -1703,7 +1801,7 @@ async def cancel_task(task_uuid: str): if task.status not in ["pending", "running"]: raise HTTPException(status_code=400, detail="任务已完成或已取消") - task = crud.update_registration_task(db, task_uuid, status="cancelled") + _request_task_cancellation(db, task_uuid) return {"success": True, "message": "任务已取消"} @@ -2207,10 +2305,4 @@ async def get_outlook_batch_status(batch_id: str): @router.post("/outlook-batch/{batch_id}/cancel") async def cancel_outlook_batch(batch_id: str): """取消 Outlook 批量任务""" - batch = _require_batch_snapshot(batch_id) - if batch.get("finished"): - raise HTTPException(status_code=400, detail="批量任务已完成") - - task_manager.cancel_batch(batch_id) - - return {"success": True, "message": "批量任务取消请求已提交"} + return _request_batch_cancellation(batch_id) diff --git a/src/web/routes/websocket.py b/src/web/routes/websocket.py index caa0c0a..9938878 100644 --- a/src/web/routes/websocket.py +++ b/src/web/routes/websocket.py @@ -86,7 +86,14 @@ async def task_websocket(websocket: WebSocket, task_uuid: str): # 处理取消请求 elif data.get("type") == "cancel": - task_manager.cancel_task(task_uuid) + from . import registration as registration_routes + + with get_db() as db: + task = crud.get_registration_task(db, task_uuid) + if task and task.status in ["pending", "running"]: + registration_routes._request_task_cancellation(db, task_uuid) + else: + task_manager.cancel_task(task_uuid) await websocket.send_json({ "type": "status", "task_uuid": task_uuid, @@ -163,7 +170,9 @@ async def batch_websocket(websocket: WebSocket, batch_id: str): # 处理取消请求 elif data.get("type") == "cancel": - task_manager.cancel_batch(batch_id) + from . import registration as registration_routes + + registration_routes._request_batch_cancellation(batch_id) await websocket.send_json({ "type": "status", "batch_id": batch_id, diff --git a/tests/test_batch_websocket_fallback.cjs b/tests/test_batch_websocket_fallback.cjs index f6bd068..221777c 100644 --- a/tests/test_batch_websocket_fallback.cjs +++ b/tests/test_batch_websocket_fallback.cjs @@ -1,9 +1,10 @@ const test = require('node:test'); const assert = require('node:assert/strict'); const fs = require('node:fs'); +const path = require('node:path'); const vm = require('node:vm'); -const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js'; +const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js'); function createElementStub() { return { diff --git a/tests/test_registration_task_cancellation.py b/tests/test_registration_task_cancellation.py new file mode 100644 index 0000000..88601e5 --- /dev/null +++ b/tests/test_registration_task_cancellation.py @@ -0,0 +1,363 @@ +import asyncio +import threading +import time +from contextlib import contextmanager +from pathlib import Path +from types import SimpleNamespace + +from src.core.register import ERROR_TASK_CANCELLED, RegistrationResult +from src.database.models import Base, RegistrationTask +from src.database.session import DatabaseSessionManager +from src.services import EmailServiceType +from src.services.base import BaseEmailService, EmailServiceCancelledError +from src.web.routes import registration as registration_routes + + +class FakeTaskManager: + def __init__(self): + self.cancelled = set() + self.status_updates = [] + self.logs = {} + self.batch_status = {} + + def is_cancelled(self, task_uuid): + return task_uuid in self.cancelled + + def cancel_task(self, task_uuid): + self.cancelled.add(task_uuid) + + def update_status(self, task_uuid, status, **kwargs): + self.status_updates.append((task_uuid, status, kwargs)) + + def create_log_callback(self, task_uuid, prefix="", batch_id=""): + def callback(message): + full_message = f"{prefix} {message}" if prefix else message + self.logs.setdefault(task_uuid, []).append(full_message) + return callback + + def create_check_cancelled_callback(self, task_uuid): + return lambda: self.is_cancelled(task_uuid) + + def get_batch_status(self, batch_id): + snapshot = self.batch_status.get(batch_id) + return dict(snapshot) if snapshot else None + + def cancel_batch(self, batch_id): + snapshot = self.batch_status.setdefault(batch_id, {}) + snapshot["cancelled"] = True + snapshot["status"] = "cancelling" + + def update_batch_status(self, batch_id, **kwargs): + snapshot = self.batch_status.setdefault(batch_id, {}) + snapshot.update(kwargs) + + +class DummySettings: + proxy_dynamic_enabled = False + proxy_dynamic_api_url = "" + email_code_timeout = 10 + email_code_poll_interval = 1 + email_code_resend_max_retries = 0 + email_code_non_openai_sender_resend_max_retries = 0 + openai_client_id = "client-id" + openai_auth_url = "https://auth.example.test" + openai_token_url = "https://token.example.test" + openai_redirect_uri = "https://callback.example.test" + openai_scope = "openid profile email" + + def get_proxy_url(self): + return None + + +def _build_fake_get_db(manager): + @contextmanager + def fake_get_db(): + with manager.session_scope() as session: + yield session + + return fake_get_db + + +class FakeRegistrationEngine: + started_event = None + + def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None): + self.email_service = email_service + self.phase_history = [] + self.check_cancelled = None + self.callback_logger = callback_logger or (lambda _msg: None) + self.task_uuid = task_uuid + + def run(self): + if self.started_event is not None: + self.started_event.set() + while True: + if callable(self.check_cancelled) and self.check_cancelled(): + return RegistrationResult( + success=False, + error_message="任务已取消", + error_code=ERROR_TASK_CANCELLED, + logs=[], + ) + time.sleep(0.01) + + def save_to_database(self, result): + return False + + def close(self): + return None + + +class FakePollingEmailService(BaseEmailService): + def __init__(self, started_event=None): + super().__init__(EmailServiceType.TEMPMAIL, "fake-polling-email") + self.started_event = started_event + + def create_email(self, config=None): + return {"email": "poll@example.test", "service_id": "poll-service"} + + def get_verification_code(self, email: str, email_id: str = None, timeout: int = 120, pattern: str = r"(?