mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-06 20:02:51 +08:00
feat(email): implement cancellation handling in email services
This commit is contained in:
@@ -31,6 +31,7 @@ from ..database import crud
|
|||||||
from ..database.session import get_db
|
from ..database.session import get_db
|
||||||
from ..services import BaseEmailService
|
from ..services import BaseEmailService
|
||||||
from ..services.base import (
|
from ..services.base import (
|
||||||
|
EmailServiceCancelledError,
|
||||||
EmailProviderBackoffState,
|
EmailProviderBackoffState,
|
||||||
OTP_NO_OPENAI_SENDER_ERROR,
|
OTP_NO_OPENAI_SENDER_ERROR,
|
||||||
OTPNoOpenAISenderEmailServiceError,
|
OTPNoOpenAISenderEmailServiceError,
|
||||||
@@ -43,6 +44,15 @@ PHASE_EMAIL_PREPARE = "email_prepare"
|
|||||||
PHASE_OTP_SECONDARY = "otp_secondary"
|
PHASE_OTP_SECONDARY = "otp_secondary"
|
||||||
ERROR_EMAIL_PROVIDER_RATE_LIMITED = "EMAIL_PROVIDER_RATE_LIMITED"
|
ERROR_EMAIL_PROVIDER_RATE_LIMITED = "EMAIL_PROVIDER_RATE_LIMITED"
|
||||||
ERROR_OTP_TIMEOUT_SECONDARY = "OTP_TIMEOUT_SECONDARY"
|
ERROR_OTP_TIMEOUT_SECONDARY = "OTP_TIMEOUT_SECONDARY"
|
||||||
|
ERROR_TASK_CANCELLED = "TASK_CANCELLED"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskCancelledError(Exception):
|
||||||
|
"""注册任务被主动取消。"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "任务已取消"):
|
||||||
|
super().__init__(message)
|
||||||
|
self.error_code = ERROR_TASK_CANCELLED
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -185,6 +195,8 @@ class RegistrationEngine:
|
|||||||
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
|
self._otp_sent_at: Optional[float] = None # OTP 发送时间戳
|
||||||
self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录)
|
self._is_existing_account: bool = False # 是否为已注册账号(用于自动登录)
|
||||||
self.phase_history: list[PhaseResult] = []
|
self.phase_history: list[PhaseResult] = []
|
||||||
|
self.check_cancelled: Optional[Callable[[], bool]] = None
|
||||||
|
self._cancel_logged = False
|
||||||
|
|
||||||
def _log(self, message: str, level: str = "info"):
|
def _log(self, message: str, level: str = "info"):
|
||||||
"""记录日志"""
|
"""记录日志"""
|
||||||
@@ -274,6 +286,35 @@ class RegistrationEngine:
|
|||||||
self.phase_history.append(phase_result)
|
self.phase_history.append(phase_result)
|
||||||
return phase_result
|
return phase_result
|
||||||
|
|
||||||
|
def _is_cancelled_requested(self) -> bool:
|
||||||
|
"""检查是否收到外部取消信号。"""
|
||||||
|
callback = self.check_cancelled
|
||||||
|
if not callable(callback):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(callback())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"检查任务取消状态失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _raise_if_cancelled(self, message: str = "任务已取消"):
|
||||||
|
"""在可协作阶段检查取消请求。"""
|
||||||
|
if not self._is_cancelled_requested():
|
||||||
|
return
|
||||||
|
if not self._cancel_logged:
|
||||||
|
self._log(message, "warning")
|
||||||
|
self._cancel_logged = True
|
||||||
|
raise TaskCancelledError(message)
|
||||||
|
|
||||||
|
def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2):
|
||||||
|
"""可响应取消的短分片休眠。"""
|
||||||
|
remaining = max(0.0, float(seconds))
|
||||||
|
while remaining > 0:
|
||||||
|
self._raise_if_cancelled()
|
||||||
|
sleep_for = min(chunk_seconds, remaining)
|
||||||
|
time.sleep(sleep_for)
|
||||||
|
remaining -= sleep_for
|
||||||
|
|
||||||
def _get_phase_result(self, phase_name: str) -> Optional[PhaseResult]:
|
def _get_phase_result(self, phase_name: str) -> Optional[PhaseResult]:
|
||||||
for phase_result in reversed(self.phase_history):
|
for phase_result in reversed(self.phase_history):
|
||||||
if phase_result.phase == phase_name:
|
if phase_result.phase == phase_name:
|
||||||
@@ -427,7 +468,7 @@ class RegistrationEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attempt < max_attempts:
|
if attempt < max_attempts:
|
||||||
time.sleep(attempt)
|
self._sleep_with_cancel(attempt)
|
||||||
self.http_client.close()
|
self.http_client.close()
|
||||||
self.session = self.http_client.session
|
self.session = self.http_client.session
|
||||||
|
|
||||||
@@ -656,6 +697,7 @@ class RegistrationEngine:
|
|||||||
otp_phase: Optional[PhaseResult] = None
|
otp_phase: Optional[PhaseResult] = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
self._raise_if_cancelled("等待验证码重试时任务已取消")
|
||||||
code, otp_phase = self._phase_otp_secondary(
|
code, otp_phase = self._phase_otp_secondary(
|
||||||
PhaseContext(otp_sent_at=self._otp_sent_at),
|
PhaseContext(otp_sent_at=self._otp_sent_at),
|
||||||
started_at=otp_phase_started_at,
|
started_at=otp_phase_started_at,
|
||||||
@@ -698,6 +740,7 @@ class RegistrationEngine:
|
|||||||
**emit_kwargs,
|
**emit_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._raise_if_cancelled("等待验证码重试时任务已取消")
|
||||||
if not resend_callback():
|
if not resend_callback():
|
||||||
self._log("重新发送验证码失败,跳过本次重试", "warning")
|
self._log("重新发送验证码失败,跳过本次重试", "warning")
|
||||||
|
|
||||||
@@ -712,76 +755,79 @@ class RegistrationEngine:
|
|||||||
) -> Tuple[Optional[str], PhaseResult]:
|
) -> Tuple[Optional[str], PhaseResult]:
|
||||||
"""等待二次验证码邮件并做超时归因。"""
|
"""等待二次验证码邮件并做超时归因。"""
|
||||||
try:
|
try:
|
||||||
|
self._raise_if_cancelled("等待验证码时任务已取消")
|
||||||
self._log(f"正在等待邮箱 {self.email} 的验证码...")
|
self._log(f"正在等待邮箱 {self.email} 的验证码...")
|
||||||
|
|
||||||
email_id = self.email_info.get("service_id") if self.email_info else None
|
email_id = self.email_info.get("service_id") if self.email_info else None
|
||||||
|
settings = get_settings()
|
||||||
budget = Budget(
|
budget = Budget(
|
||||||
timeout_seconds=get_settings().email_code_timeout,
|
timeout_seconds=settings.email_code_timeout,
|
||||||
started_at=started_at if started_at is not None else time.time(),
|
started_at=started_at if started_at is not None else time.time(),
|
||||||
)
|
)
|
||||||
remaining_timeout = budget.remaining_seconds()
|
poll_interval = max(1, int(settings.email_code_poll_interval or 1))
|
||||||
|
|
||||||
if remaining_timeout <= 0:
|
while True:
|
||||||
|
self._raise_if_cancelled("等待验证码时任务已取消")
|
||||||
|
remaining_timeout = budget.remaining_seconds()
|
||||||
|
|
||||||
|
if remaining_timeout <= 0:
|
||||||
|
phase_result = self._record_phase_result(
|
||||||
|
PhaseResult(
|
||||||
|
phase=PHASE_OTP_SECONDARY,
|
||||||
|
success=False,
|
||||||
|
error_message="等待验证码超时",
|
||||||
|
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
||||||
|
retryable=True,
|
||||||
|
next_action="await_email",
|
||||||
|
metadata={
|
||||||
|
"budget_started_at": budget.started_at,
|
||||||
|
"budget_timeout_seconds": budget.timeout_seconds,
|
||||||
|
"otp_sent_at": context.otp_sent_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._log(phase_result.error_message, "error")
|
||||||
|
return None, phase_result
|
||||||
|
|
||||||
|
attempt_timeout = max(1, min(remaining_timeout, poll_interval))
|
||||||
|
code = self.email_service.get_verification_code(
|
||||||
|
email=self.email,
|
||||||
|
email_id=email_id,
|
||||||
|
timeout=attempt_timeout,
|
||||||
|
pattern=OTP_CODE_PATTERN,
|
||||||
|
otp_sent_at=context.otp_sent_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
if code:
|
||||||
|
self._log(f"成功获取验证码: {code}")
|
||||||
|
phase_result = self._record_phase_result(
|
||||||
|
PhaseResult(
|
||||||
|
phase=PHASE_OTP_SECONDARY,
|
||||||
|
success=True,
|
||||||
|
metadata={
|
||||||
|
"budget_started_at": budget.started_at,
|
||||||
|
"budget_timeout_seconds": budget.timeout_seconds,
|
||||||
|
"otp_sent_at": context.otp_sent_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return code, phase_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, (TaskCancelledError, EmailServiceCancelledError)):
|
||||||
phase_result = self._record_phase_result(
|
phase_result = self._record_phase_result(
|
||||||
PhaseResult(
|
PhaseResult(
|
||||||
phase=PHASE_OTP_SECONDARY,
|
phase=PHASE_OTP_SECONDARY,
|
||||||
success=False,
|
success=False,
|
||||||
error_message="等待验证码超时",
|
error_message=str(e),
|
||||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
error_code=getattr(e, "error_code", ERROR_TASK_CANCELLED),
|
||||||
retryable=True,
|
retryable=False,
|
||||||
next_action="await_email",
|
next_action="cancelled",
|
||||||
metadata={
|
metadata={"otp_sent_at": context.otp_sent_at},
|
||||||
"budget_started_at": budget.started_at,
|
|
||||||
"budget_timeout_seconds": budget.timeout_seconds,
|
|
||||||
"otp_sent_at": context.otp_sent_at,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self._log(phase_result.error_message, "error")
|
|
||||||
return None, phase_result
|
return None, phase_result
|
||||||
|
|
||||||
code = self.email_service.get_verification_code(
|
|
||||||
email=self.email,
|
|
||||||
email_id=email_id,
|
|
||||||
timeout=remaining_timeout,
|
|
||||||
pattern=OTP_CODE_PATTERN,
|
|
||||||
otp_sent_at=context.otp_sent_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
if code:
|
|
||||||
self._log(f"成功获取验证码: {code}")
|
|
||||||
phase_result = self._record_phase_result(
|
|
||||||
PhaseResult(
|
|
||||||
phase=PHASE_OTP_SECONDARY,
|
|
||||||
success=True,
|
|
||||||
metadata={
|
|
||||||
"budget_started_at": budget.started_at,
|
|
||||||
"budget_timeout_seconds": budget.timeout_seconds,
|
|
||||||
"otp_sent_at": context.otp_sent_at,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return code, phase_result
|
|
||||||
|
|
||||||
phase_result = self._record_phase_result(
|
|
||||||
PhaseResult(
|
|
||||||
phase=PHASE_OTP_SECONDARY,
|
|
||||||
success=False,
|
|
||||||
error_message="等待验证码超时",
|
|
||||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
|
||||||
retryable=True,
|
|
||||||
next_action="await_email",
|
|
||||||
metadata={
|
|
||||||
"budget_started_at": budget.started_at,
|
|
||||||
"budget_timeout_seconds": budget.timeout_seconds,
|
|
||||||
"otp_sent_at": context.otp_sent_at,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._log(phase_result.error_message, "error")
|
|
||||||
return None, phase_result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
|
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
|
||||||
self._log(str(e), "warning")
|
self._log(str(e), "warning")
|
||||||
phase_result = self._record_phase_result(
|
phase_result = self._record_phase_result(
|
||||||
@@ -1404,6 +1450,8 @@ class RegistrationEngine:
|
|||||||
non_openai_retry_status_template="重新触发登录验证码(非 OpenAI 发件人,第 {attempt} 次)",
|
non_openai_retry_status_template="重新触发登录验证码(非 OpenAI 发件人,第 {attempt} 次)",
|
||||||
)
|
)
|
||||||
if not code:
|
if not code:
|
||||||
|
if otp_phase and otp_phase.error_code == ERROR_TASK_CANCELLED:
|
||||||
|
raise TaskCancelledError(otp_phase.error_message or "登录流程已取消")
|
||||||
self._log(
|
self._log(
|
||||||
otp_phase.error_message if otp_phase and otp_phase.error_message else "登录流程获取验证码失败",
|
otp_phase.error_message if otp_phase and otp_phase.error_message else "登录流程获取验证码失败",
|
||||||
"warning",
|
"warning",
|
||||||
@@ -1539,11 +1587,13 @@ class RegistrationEngine:
|
|||||||
result = RegistrationResult(success=False, logs=self.logs)
|
result = RegistrationResult(success=False, logs=self.logs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("=" * 60)
|
self._log("=" * 60)
|
||||||
self._log("开始注册流程")
|
self._log("开始注册流程")
|
||||||
self._log("=" * 60)
|
self._log("=" * 60)
|
||||||
|
|
||||||
# 1. 检查 IP 地理位置
|
# 1. 检查 IP 地理位置
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("1. 检查 IP 地理位置...")
|
self._log("1. 检查 IP 地理位置...")
|
||||||
self._emit_status("ip_check", "检查 IP 地理位置", step_index=1)
|
self._emit_status("ip_check", "检查 IP 地理位置", step_index=1)
|
||||||
ip_ok, location = self._check_ip_location()
|
ip_ok, location = self._check_ip_location()
|
||||||
@@ -1555,6 +1605,7 @@ class RegistrationEngine:
|
|||||||
self._log(f"IP 位置: {location}")
|
self._log(f"IP 位置: {location}")
|
||||||
|
|
||||||
# 2. 创建邮箱
|
# 2. 创建邮箱
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("2. 创建邮箱...")
|
self._log("2. 创建邮箱...")
|
||||||
self._emit_status("email_prepare", "创建邮箱地址", step_index=2)
|
self._emit_status("email_prepare", "创建邮箱地址", step_index=2)
|
||||||
if not self._phase_email_prepare():
|
if not self._phase_email_prepare():
|
||||||
@@ -1570,6 +1621,7 @@ class RegistrationEngine:
|
|||||||
result.email = self.email
|
result.email = self.email
|
||||||
|
|
||||||
# 3. 初始化会话
|
# 3. 初始化会话
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("3. 初始化会话...")
|
self._log("3. 初始化会话...")
|
||||||
self._emit_status("session_init", "初始化 HTTP 会话", step_index=3)
|
self._emit_status("session_init", "初始化 HTTP 会话", step_index=3)
|
||||||
if not self._init_session():
|
if not self._init_session():
|
||||||
@@ -1577,6 +1629,7 @@ class RegistrationEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# 4. 开始 OAuth 流程
|
# 4. 开始 OAuth 流程
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("4. 开始 OAuth 授权流程...")
|
self._log("4. 开始 OAuth 授权流程...")
|
||||||
self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4)
|
self._emit_status("oauth_start", "开始 OAuth 授权流程", step_index=4)
|
||||||
if not self._start_oauth():
|
if not self._start_oauth():
|
||||||
@@ -1584,6 +1637,7 @@ class RegistrationEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# 5. 获取 Device ID
|
# 5. 获取 Device ID
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("5. 获取 Device ID...")
|
self._log("5. 获取 Device ID...")
|
||||||
self._emit_status("oauth_device_id", "获取 Device ID", step_index=5)
|
self._emit_status("oauth_device_id", "获取 Device ID", step_index=5)
|
||||||
did = self._get_device_id()
|
did = self._get_device_id()
|
||||||
@@ -1592,6 +1646,7 @@ class RegistrationEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# 6. 检查 Sentinel 拦截
|
# 6. 检查 Sentinel 拦截
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("6. 检查 Sentinel 拦截...")
|
self._log("6. 检查 Sentinel 拦截...")
|
||||||
self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6)
|
self._emit_status("sentinel", "检查 Sentinel 拦截", step_index=6)
|
||||||
sen_token = self._check_sentinel(did)
|
sen_token = self._check_sentinel(did)
|
||||||
@@ -1601,6 +1656,7 @@ class RegistrationEngine:
|
|||||||
self._log("Sentinel 检查失败或未启用", "warning")
|
self._log("Sentinel 检查失败或未启用", "warning")
|
||||||
|
|
||||||
# 7. 提交注册表单 + 解析响应判断账号状态
|
# 7. 提交注册表单 + 解析响应判断账号状态
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("7. 提交注册表单...")
|
self._log("7. 提交注册表单...")
|
||||||
self._emit_status("signup_submit", "提交注册表单", step_index=7)
|
self._emit_status("signup_submit", "提交注册表单", step_index=7)
|
||||||
signup_result = self._submit_signup_form(did, sen_token)
|
signup_result = self._submit_signup_form(did, sen_token)
|
||||||
@@ -1612,6 +1668,7 @@ class RegistrationEngine:
|
|||||||
if self._is_existing_account:
|
if self._is_existing_account:
|
||||||
self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送")
|
self._log("8. [已注册账号] 跳过密码设置,OTP 已自动发送")
|
||||||
else:
|
else:
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("8. 注册密码...")
|
self._log("8. 注册密码...")
|
||||||
self._emit_status("signup_password", "提交注册密码", step_index=8)
|
self._emit_status("signup_password", "提交注册密码", step_index=8)
|
||||||
password_ok, password = self._register_password()
|
password_ok, password = self._register_password()
|
||||||
@@ -1625,6 +1682,7 @@ class RegistrationEngine:
|
|||||||
# 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳
|
# 已注册账号的 OTP 在提交表单时已自动发送,记录时间戳
|
||||||
self._otp_sent_at = time.time()
|
self._otp_sent_at = time.time()
|
||||||
else:
|
else:
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("9. 发送验证码...")
|
self._log("9. 发送验证码...")
|
||||||
self._emit_status("otp_send", "发送验证码", step_index=9)
|
self._emit_status("otp_send", "发送验证码", step_index=9)
|
||||||
if not self._send_verification_code():
|
if not self._send_verification_code():
|
||||||
@@ -1632,6 +1690,7 @@ class RegistrationEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# 10. 获取验证码(支持重发重试)
|
# 10. 获取验证码(支持重发重试)
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("10. 等待验证码...")
|
self._log("10. 等待验证码...")
|
||||||
self._emit_status("otp_secondary", "等待验证码邮件", step_index=10)
|
self._emit_status("otp_secondary", "等待验证码邮件", step_index=10)
|
||||||
code, otp_phase = self._await_verification_code_with_resends(
|
code, otp_phase = self._await_verification_code_with_resends(
|
||||||
@@ -1650,6 +1709,7 @@ class RegistrationEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# 11. 验证验证码
|
# 11. 验证验证码
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("11. 验证验证码...")
|
self._log("11. 验证验证码...")
|
||||||
self._emit_status("otp_validate", "校验验证码", step_index=11)
|
self._emit_status("otp_validate", "校验验证码", step_index=11)
|
||||||
if not self._validate_verification_code(code):
|
if not self._validate_verification_code(code):
|
||||||
@@ -1660,6 +1720,7 @@ class RegistrationEngine:
|
|||||||
if self._is_existing_account:
|
if self._is_existing_account:
|
||||||
self._log("12. [已注册账号] 跳过创建用户账户")
|
self._log("12. [已注册账号] 跳过创建用户账户")
|
||||||
else:
|
else:
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log("12. 创建用户账户...")
|
self._log("12. 创建用户账户...")
|
||||||
self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12)
|
self._emit_status("account_create", "创建 OpenAI 账户资料", step_index=12)
|
||||||
if not self._create_user_account():
|
if not self._create_user_account():
|
||||||
@@ -1670,6 +1731,7 @@ class RegistrationEngine:
|
|||||||
callback_url = None
|
callback_url = None
|
||||||
|
|
||||||
if not self._is_existing_account:
|
if not self._is_existing_account:
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...")
|
self._log(f"{next_step}. [新账号] 推进 Codex 授权流程...")
|
||||||
self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step)
|
self._emit_status("oauth_reentry", "推进 Codex 授权流程", step_index=next_step)
|
||||||
workspace_id, callback_url = self._advance_login_authorization()
|
workspace_id, callback_url = self._advance_login_authorization()
|
||||||
@@ -1679,6 +1741,7 @@ class RegistrationEngine:
|
|||||||
|
|
||||||
if not result.workspace_id:
|
if not result.workspace_id:
|
||||||
# 获取 Workspace ID
|
# 获取 Workspace ID
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log(f"{next_step}. 获取 Workspace ID...")
|
self._log(f"{next_step}. 获取 Workspace ID...")
|
||||||
self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step)
|
self._emit_status("workspace_extract", "从授权态提取 Workspace ID", step_index=next_step)
|
||||||
workspace_id = self._get_workspace_id()
|
workspace_id = self._get_workspace_id()
|
||||||
@@ -1691,6 +1754,7 @@ class RegistrationEngine:
|
|||||||
next_step += 1
|
next_step += 1
|
||||||
|
|
||||||
# 选择 Workspace
|
# 选择 Workspace
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log(f"{next_step}. 选择 Workspace...")
|
self._log(f"{next_step}. 选择 Workspace...")
|
||||||
self._emit_status("workspace_select", "选择 Workspace", step_index=next_step)
|
self._emit_status("workspace_select", "选择 Workspace", step_index=next_step)
|
||||||
continue_url = self._select_workspace(result.workspace_id)
|
continue_url = self._select_workspace(result.workspace_id)
|
||||||
@@ -1701,6 +1765,7 @@ class RegistrationEngine:
|
|||||||
next_step += 1
|
next_step += 1
|
||||||
|
|
||||||
# 跟随重定向链
|
# 跟随重定向链
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log(f"{next_step}. 跟随重定向链...")
|
self._log(f"{next_step}. 跟随重定向链...")
|
||||||
self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step)
|
self._emit_status("redirect_chain", "跟随授权重定向链", step_index=next_step)
|
||||||
callback_url = self._follow_redirects(continue_url)
|
callback_url = self._follow_redirects(continue_url)
|
||||||
@@ -1711,6 +1776,7 @@ class RegistrationEngine:
|
|||||||
next_step += 1
|
next_step += 1
|
||||||
|
|
||||||
# 处理 OAuth 回调
|
# 处理 OAuth 回调
|
||||||
|
self._raise_if_cancelled()
|
||||||
self._log(f"{next_step}. 处理 OAuth 回调...")
|
self._log(f"{next_step}. 处理 OAuth 回调...")
|
||||||
self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step)
|
self._emit_status("oauth_callback", "处理 OAuth 回调", step_index=next_step)
|
||||||
token_info = self._handle_oauth_callback(callback_url)
|
token_info = self._handle_oauth_callback(callback_url)
|
||||||
@@ -1757,6 +1823,11 @@ class RegistrationEngine:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
except TaskCancelledError as e:
|
||||||
|
result.error_message = str(e)
|
||||||
|
result.error_code = getattr(e, "error_code", ERROR_TASK_CANCELLED)
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
||||||
result.error_message = str(e)
|
result.error_message = str(e)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import re
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List, Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from ..config.constants import EmailServiceType, OPENAI_EMAIL_SENDERS, OTP_CODE_PATTERN, OTP_CODE_SEMANTIC_PATTERN
|
from ..config.constants import EmailServiceType, OPENAI_EMAIL_SENDERS, OTP_CODE_PATTERN, OTP_CODE_SEMANTIC_PATTERN
|
||||||
@@ -139,6 +139,10 @@ class OTPNoOpenAISenderEmailServiceError(EmailServiceError):
|
|||||||
self.error_code = error_code
|
self.error_code = error_code
|
||||||
|
|
||||||
|
|
||||||
|
class EmailServiceCancelledError(EmailServiceError):
|
||||||
|
"""邮箱服务在轮询过程中收到取消信号。"""
|
||||||
|
|
||||||
|
|
||||||
class EmailServiceStatus(Enum):
|
class EmailServiceStatus(Enum):
|
||||||
"""邮箱服务状态"""
|
"""邮箱服务状态"""
|
||||||
HEALTHY = "healthy"
|
HEALTHY = "healthy"
|
||||||
@@ -168,6 +172,7 @@ class BaseEmailService(abc.ABC):
|
|||||||
self._provider_backoff = reset_adaptive_backoff()
|
self._provider_backoff = reset_adaptive_backoff()
|
||||||
self._used_verification_codes: Dict[str, set] = {}
|
self._used_verification_codes: Dict[str, set] = {}
|
||||||
self._seen_verification_messages: Dict[str, set] = {}
|
self._seen_verification_messages: Dict[str, set] = {}
|
||||||
|
self.check_cancelled: Optional[Callable[[], bool]] = None
|
||||||
|
|
||||||
_EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
|
_EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
|
||||||
|
|
||||||
@@ -190,6 +195,35 @@ class BaseEmailService(abc.ABC):
|
|||||||
"""注入外部持久化的邮箱供应商退避状态"""
|
"""注入外部持久化的邮箱供应商退避状态"""
|
||||||
self._provider_backoff = state or reset_adaptive_backoff()
|
self._provider_backoff = state or reset_adaptive_backoff()
|
||||||
|
|
||||||
|
def set_check_cancelled(self, callback: Optional[Callable[[], bool]]) -> None:
|
||||||
|
"""注入外部取消检查回调。"""
|
||||||
|
self.check_cancelled = callback if callable(callback) else None
|
||||||
|
|
||||||
|
def _is_cancelled_requested(self) -> bool:
|
||||||
|
"""检查邮箱服务是否收到取消请求。"""
|
||||||
|
callback = self.check_cancelled
|
||||||
|
if not callable(callback):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(callback())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"检查邮箱服务取消状态失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _raise_if_cancelled(self, message: str = "任务已取消") -> None:
|
||||||
|
"""在轮询/等待阶段响应取消请求。"""
|
||||||
|
if self._is_cancelled_requested():
|
||||||
|
raise EmailServiceCancelledError(message)
|
||||||
|
|
||||||
|
def _sleep_with_cancel(self, seconds: float, chunk_seconds: float = 0.2) -> None:
|
||||||
|
"""可响应取消的短分片休眠。"""
|
||||||
|
remaining = max(0.0, float(seconds))
|
||||||
|
while remaining > 0:
|
||||||
|
self._raise_if_cancelled()
|
||||||
|
sleep_for = min(chunk_seconds, remaining)
|
||||||
|
time.sleep(sleep_for)
|
||||||
|
remaining -= sleep_for
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -520,6 +554,7 @@ class BaseEmailService(abc.ABC):
|
|||||||
last_email_id = None
|
last_email_id = None
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待邮件时任务已取消")
|
||||||
try:
|
try:
|
||||||
emails = self.list_emails()
|
emails = self.list_emails()
|
||||||
for email_info in emails:
|
for email_info in emails:
|
||||||
@@ -562,7 +597,7 @@ class BaseEmailService(abc.ABC):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"等待邮件时出错: {e}")
|
logger.warning(f"等待邮件时出错: {e}")
|
||||||
|
|
||||||
time.sleep(check_interval)
|
self._sleep_with_cancel(check_interval)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -236,6 +236,7 @@ class CloudMailService(BaseEmailService):
|
|||||||
seen_mail_ids: set = set()
|
seen_mail_ids: set = set()
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待 Cloud Mail 验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
token = self._get_public_token()
|
token = self._get_public_token()
|
||||||
mails = self._make_request(
|
mails = self._make_request(
|
||||||
@@ -253,7 +254,7 @@ class CloudMailService(BaseEmailService):
|
|||||||
mails = mails["list"]
|
mails = mails["list"]
|
||||||
|
|
||||||
if not isinstance(mails, list):
|
if not isinstance(mails, list):
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if mails:
|
if mails:
|
||||||
@@ -302,7 +303,7 @@ class CloudMailService(BaseEmailService):
|
|||||||
raise
|
raise
|
||||||
logger.debug(f"检查 Cloud Mail 邮件时出错: {e}")
|
logger.debug(f"检查 Cloud Mail 邮件时出错: {e}")
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
logger.warning(f"等待 Cloud Mail 验证码超时: {email}")
|
logger.warning(f"等待 Cloud Mail 验证码超时: {email}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ class DuckMailService(BaseEmailService):
|
|||||||
seen_message_ids = set()
|
seen_message_ids = set()
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待 DuckMail 验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
response = self._make_request(
|
response = self._make_request(
|
||||||
"GET",
|
"GET",
|
||||||
@@ -324,7 +325,7 @@ class DuckMailService(BaseEmailService):
|
|||||||
raise
|
raise
|
||||||
logger.debug(f"DuckMail 轮询验证码失败: {e}")
|
logger.debug(f"DuckMail 轮询验证码失败: {e}")
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -214,10 +214,11 @@ class FreemailService(BaseEmailService):
|
|||||||
seen_mail_ids: set = set()
|
seen_mail_ids: set = set()
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待 Freemail 验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
mails = self._make_request("GET", "/api/emails", params={"mailbox": email, "limit": 20})
|
mails = self._make_request("GET", "/api/emails", params={"mailbox": email, "limit": 20})
|
||||||
if not isinstance(mails, list):
|
if not isinstance(mails, list):
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ordered_mails = self._sort_items_by_message_time(
|
ordered_mails = self._sort_items_by_message_time(
|
||||||
@@ -299,7 +300,7 @@ class FreemailService(BaseEmailService):
|
|||||||
raise
|
raise
|
||||||
logger.debug(f"检查 Freemail 邮件时出错: {e}")
|
logger.debug(f"检查 Freemail 邮件时出错: {e}")
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
logger.warning(f"等待 Freemail 验证码超时: {email}")
|
logger.warning(f"等待 Freemail 验证码超时: {email}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import logging
|
|||||||
from email.header import decode_header
|
from email.header import decode_header
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from .base import BaseEmailService, EmailServiceError, OTPNoOpenAISenderEmailServiceError, get_email_code_settings
|
from .base import BaseEmailService, OTPNoOpenAISenderEmailServiceError, get_email_code_settings
|
||||||
from ..config.constants import (
|
from ..config.constants import (
|
||||||
EmailServiceType,
|
EmailServiceType,
|
||||||
OTP_CODE_SEMANTIC_PATTERN,
|
OTP_CODE_SEMANTIC_PATTERN,
|
||||||
@@ -124,11 +124,12 @@ class ImapMailService(BaseEmailService):
|
|||||||
mail.select("INBOX")
|
mail.select("INBOX")
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待 IMAP 验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
# 搜索所有未读邮件
|
# 搜索所有未读邮件
|
||||||
status, data = mail.search(None, "UNSEEN")
|
status, data = mail.search(None, "UNSEEN")
|
||||||
if status != "OK" or not data or not data[0]:
|
if status != "OK" or not data or not data[0]:
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
msg_ids = data[0].split()
|
msg_ids = data[0].split()
|
||||||
@@ -181,13 +182,13 @@ class ImapMailService(BaseEmailService):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
|
if isinstance(e, OTPNoOpenAISenderEmailServiceError):
|
||||||
raise
|
raise
|
||||||
logger.warning(f"IMAP 连接/轮询失败: {e}")
|
logger.warning(f"IMAP 连接/轮询失败: {e}")
|
||||||
self.update_status(False, str(e))
|
self.update_status(False, e)
|
||||||
finally:
|
finally:
|
||||||
if mail:
|
if mail:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -308,13 +308,14 @@ class MeoMailEmailService(BaseEmailService):
|
|||||||
seen_message_ids = set()
|
seen_message_ids = set()
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待自定义域名邮箱验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
# 获取邮件列表
|
# 获取邮件列表
|
||||||
response = self._make_request("GET", f"/api/emails/{target_email_id}")
|
response = self._make_request("GET", f"/api/emails/{target_email_id}")
|
||||||
|
|
||||||
messages = response.get("messages", [])
|
messages = response.get("messages", [])
|
||||||
if not isinstance(messages, list):
|
if not isinstance(messages, list):
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ordered_messages = self._sort_items_by_message_time(
|
ordered_messages = self._sort_items_by_message_time(
|
||||||
@@ -380,7 +381,7 @@ class MeoMailEmailService(BaseEmailService):
|
|||||||
logger.debug(f"检查邮件时出错: {e}")
|
logger.debug(f"检查邮件时出错: {e}")
|
||||||
|
|
||||||
# 等待一段时间再检查
|
# 等待一段时间再检查
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
logger.warning(f"等待验证码超时: {email}")
|
logger.warning(f"等待验证码超时: {email}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -337,6 +337,7 @@ class OutlookService(BaseEmailService):
|
|||||||
poll_count = 0
|
poll_count = 0
|
||||||
|
|
||||||
while time.time() - start_time < actual_timeout:
|
while time.time() - start_time < actual_timeout:
|
||||||
|
self._raise_if_cancelled("等待 Outlook 验证码时任务已取消")
|
||||||
poll_count += 1
|
poll_count += 1
|
||||||
|
|
||||||
# 渐进式邮件检查:前 3 次只检查未读
|
# 渐进式邮件检查:前 3 次只检查未读
|
||||||
@@ -387,7 +388,7 @@ class OutlookService(BaseEmailService):
|
|||||||
logger.warning(f"[{email}] 检查出错: {e}")
|
logger.warning(f"[{email}] 检查出错: {e}")
|
||||||
|
|
||||||
# 等待下次轮询
|
# 等待下次轮询
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
elapsed = int(time.time() - start_time)
|
elapsed = int(time.time() - start_time)
|
||||||
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次")
|
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次")
|
||||||
|
|||||||
@@ -309,6 +309,7 @@ class TempMailService(BaseEmailService):
|
|||||||
# jwt = cached.get("jwt")
|
# jwt = cached.get("jwt")
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待 TempMail 验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
# if jwt:
|
# if jwt:
|
||||||
# response = self._make_request(
|
# response = self._make_request(
|
||||||
@@ -327,7 +328,7 @@ class TempMailService(BaseEmailService):
|
|||||||
# /user_api/mails 和 /admin/mails 返回格式相同: {"results": [...], "total": N}
|
# /user_api/mails 和 /admin/mails 返回格式相同: {"results": [...], "total": N}
|
||||||
mails = response.get("results", [])
|
mails = response.get("results", [])
|
||||||
if not isinstance(mails, list):
|
if not isinstance(mails, list):
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ordered_mails = self._sort_items_by_message_time(
|
ordered_mails = self._sort_items_by_message_time(
|
||||||
@@ -389,7 +390,7 @@ class TempMailService(BaseEmailService):
|
|||||||
raise
|
raise
|
||||||
logger.debug(f"检查 TempMail 邮件时出错: {e}")
|
logger.debug(f"检查 TempMail 邮件时出错: {e}")
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
logger.warning(f"等待 TempMail 验证码超时: {email}")
|
logger.warning(f"等待 TempMail 验证码超时: {email}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ class TempmailService(BaseEmailService):
|
|||||||
seen_ids = set()
|
seen_ids = set()
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
|
self._raise_if_cancelled("等待 Tempmail 验证码时任务已取消")
|
||||||
try:
|
try:
|
||||||
# 获取邮件列表
|
# 获取邮件列表
|
||||||
response = self.http_client.get(
|
response = self.http_client.get(
|
||||||
@@ -226,7 +227,7 @@ class TempmailService(BaseEmailService):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -239,7 +240,7 @@ class TempmailService(BaseEmailService):
|
|||||||
email_list = data.get("emails", []) if isinstance(data, dict) else []
|
email_list = data.get("emails", []) if isinstance(data, dict) else []
|
||||||
|
|
||||||
if not isinstance(email_list, list):
|
if not isinstance(email_list, list):
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ordered_emails = self._sort_items_by_message_time(
|
ordered_emails = self._sort_items_by_message_time(
|
||||||
@@ -302,7 +303,7 @@ class TempmailService(BaseEmailService):
|
|||||||
logger.debug(f"检查邮件时出错: {e}")
|
logger.debug(f"检查邮件时出错: {e}")
|
||||||
|
|
||||||
# 等待一段时间再检查
|
# 等待一段时间再检查
|
||||||
time.sleep(poll_interval)
|
self._sleep_with_cancel(poll_interval)
|
||||||
|
|
||||||
logger.warning(f"等待验证码超时: {email}")
|
logger.warning(f"等待验证码超时: {email}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from ...database.session import get_db
|
|||||||
from ...database.models import RegistrationTask, Proxy
|
from ...database.models import RegistrationTask, Proxy
|
||||||
from ...core.register import (
|
from ...core.register import (
|
||||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||||
|
ERROR_TASK_CANCELLED,
|
||||||
RegistrationEngine,
|
RegistrationEngine,
|
||||||
RegistrationResult,
|
RegistrationResult,
|
||||||
)
|
)
|
||||||
@@ -395,6 +396,11 @@ def _run_registration_engine_attempt(
|
|||||||
status_callback=status_callback,
|
status_callback=status_callback,
|
||||||
task_uuid=task_uuid,
|
task_uuid=task_uuid,
|
||||||
)
|
)
|
||||||
|
create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None)
|
||||||
|
if callable(create_cancel_callback):
|
||||||
|
setattr(engine, "check_cancelled", create_cancel_callback(task_uuid))
|
||||||
|
else:
|
||||||
|
setattr(engine, "check_cancelled", lambda: False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = engine.run()
|
result = engine.run()
|
||||||
@@ -435,6 +441,52 @@ def _get_batch_snapshot(batch_id: str) -> Optional[dict]:
|
|||||||
return task_manager.get_batch_status(batch_id)
|
return task_manager.get_batch_status(batch_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _finalize_task_cancelled(
|
||||||
|
db,
|
||||||
|
task_uuid: str,
|
||||||
|
message: str = "任务已取消",
|
||||||
|
*,
|
||||||
|
email_service: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""将任务统一收敛为已取消状态。"""
|
||||||
|
crud.update_registration_task(
|
||||||
|
db,
|
||||||
|
task_uuid,
|
||||||
|
status="cancelled",
|
||||||
|
completed_at=datetime.utcnow(),
|
||||||
|
error_message=message,
|
||||||
|
)
|
||||||
|
status_kwargs = {"error": message, "cancelled": True}
|
||||||
|
if email_service:
|
||||||
|
status_kwargs["email_service"] = email_service
|
||||||
|
task_manager.update_status(task_uuid, "cancelled", **status_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_task_cancellation(db, task_uuid: str, message: str = "任务已取消"):
|
||||||
|
"""向运行中的任务发送取消信号,并持久化取消状态。"""
|
||||||
|
task_manager.cancel_task(task_uuid)
|
||||||
|
_finalize_task_cancelled(db, task_uuid, message)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_batch_cancellation(batch_id: str, message: str = "批量任务取消请求已提交"):
|
||||||
|
"""向批量任务及其成员任务传播取消信号。"""
|
||||||
|
batch = _require_batch_snapshot(batch_id)
|
||||||
|
if batch.get("finished"):
|
||||||
|
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||||
|
|
||||||
|
task_manager.cancel_batch(batch_id)
|
||||||
|
task_manager.update_batch_status(batch_id, status="cancelled", cancelled=True)
|
||||||
|
|
||||||
|
with get_db() as db:
|
||||||
|
for task_uuid in batch.get("task_uuids", []):
|
||||||
|
task_manager.cancel_task(task_uuid)
|
||||||
|
task = crud.get_registration_task(db, task_uuid)
|
||||||
|
if task and task.status in ["pending", "running"]:
|
||||||
|
_finalize_task_cancelled(db, task_uuid, "任务已取消")
|
||||||
|
|
||||||
|
return {"success": True, "message": message}
|
||||||
|
|
||||||
|
|
||||||
def _require_batch_snapshot(batch_id: str) -> dict:
|
def _require_batch_snapshot(batch_id: str) -> dict:
|
||||||
batch = _get_batch_snapshot(batch_id)
|
batch = _get_batch_snapshot(batch_id)
|
||||||
if batch is None:
|
if batch is None:
|
||||||
@@ -560,8 +612,12 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
"""
|
"""
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
try:
|
try:
|
||||||
|
def cancellation_requested() -> bool:
|
||||||
|
return task_manager.is_cancelled(task_uuid)
|
||||||
|
|
||||||
if task_manager.is_cancelled(task_uuid):
|
if task_manager.is_cancelled(task_uuid):
|
||||||
logger.info(f"任务 {task_uuid} 已取消,跳过执行")
|
logger.info(f"任务 {task_uuid} 已取消,跳过执行")
|
||||||
|
_finalize_task_cancelled(db, task_uuid)
|
||||||
return
|
return
|
||||||
|
|
||||||
task = crud.update_registration_task(
|
task = crud.update_registration_task(
|
||||||
@@ -583,6 +639,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
proxy_id = None
|
proxy_id = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
if cancellation_requested():
|
||||||
|
_finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value)
|
||||||
|
return
|
||||||
|
|
||||||
actual_proxy_url = requested_proxy
|
actual_proxy_url = requested_proxy
|
||||||
proxy_id = None
|
proxy_id = None
|
||||||
if not actual_proxy_url:
|
if not actual_proxy_url:
|
||||||
@@ -605,6 +665,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
should_retry_with_new_proxy = False
|
should_retry_with_new_proxy = False
|
||||||
|
|
||||||
for attempt_index, candidate in enumerate(service_candidates, start=1):
|
for attempt_index, candidate in enumerate(service_candidates, start=1):
|
||||||
|
if cancellation_requested():
|
||||||
|
_finalize_task_cancelled(db, task_uuid, email_service=active_service_type.value)
|
||||||
|
return
|
||||||
|
|
||||||
selected_service_type = candidate["service_type"]
|
selected_service_type = candidate["service_type"]
|
||||||
candidate_config = candidate["config"]
|
candidate_config = candidate["config"]
|
||||||
db_service = candidate.get("db_service")
|
db_service = candidate.get("db_service")
|
||||||
@@ -630,6 +694,11 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
candidate_config,
|
candidate_config,
|
||||||
name=db_service.name if db_service is not None else None,
|
name=db_service.name if db_service is not None else None,
|
||||||
)
|
)
|
||||||
|
create_cancel_callback = getattr(task_manager, "create_check_cancelled_callback", None)
|
||||||
|
if callable(create_cancel_callback):
|
||||||
|
set_cancel_callback = getattr(email_service, "set_check_cancelled", None)
|
||||||
|
if callable(set_cancel_callback):
|
||||||
|
set_cancel_callback(create_cancel_callback(task_uuid))
|
||||||
(
|
(
|
||||||
engine,
|
engine,
|
||||||
result,
|
result,
|
||||||
@@ -645,6 +714,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
status_callback=status_callback,
|
status_callback=status_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED:
|
||||||
|
_finalize_task_cancelled(
|
||||||
|
db,
|
||||||
|
task_uuid,
|
||||||
|
result.error_message or "任务已取消",
|
||||||
|
email_service=active_service_type.value,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if result.success:
|
if result.success:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -707,6 +785,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if result.success:
|
if result.success:
|
||||||
|
if cancellation_requested():
|
||||||
|
_finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value)
|
||||||
|
return
|
||||||
|
|
||||||
# 更新代理使用时间
|
# 更新代理使用时间
|
||||||
update_proxy_usage(db, proxy_id)
|
update_proxy_usage(db, proxy_id)
|
||||||
|
|
||||||
@@ -836,6 +918,10 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
except Exception as na_err:
|
except Exception as na_err:
|
||||||
log_callback(f"[NEWAPI] 上传异常: {na_err}")
|
log_callback(f"[NEWAPI] 上传异常: {na_err}")
|
||||||
|
|
||||||
|
if cancellation_requested():
|
||||||
|
_finalize_task_cancelled(db, task_uuid, "任务已取消", email_service=active_service_type.value)
|
||||||
|
return
|
||||||
|
|
||||||
# 更新任务状态
|
# 更新任务状态
|
||||||
crud.update_registration_task(
|
crud.update_registration_task(
|
||||||
db, task_uuid,
|
db, task_uuid,
|
||||||
@@ -857,6 +943,15 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
|
|
||||||
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
|
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
|
||||||
else:
|
else:
|
||||||
|
if cancellation_requested() or result.error_code == ERROR_TASK_CANCELLED:
|
||||||
|
_finalize_task_cancelled(
|
||||||
|
db,
|
||||||
|
task_uuid,
|
||||||
|
result.error_message or "任务已取消",
|
||||||
|
email_service=active_service_type.value,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# 更新任务状态为失败
|
# 更新任务状态为失败
|
||||||
crud.update_registration_task(
|
crud.update_registration_task(
|
||||||
db, task_uuid,
|
db, task_uuid,
|
||||||
@@ -880,15 +975,18 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
crud.update_registration_task(
|
if task_manager.is_cancelled(task_uuid):
|
||||||
db, task_uuid,
|
_finalize_task_cancelled(db, task_uuid)
|
||||||
status="failed",
|
else:
|
||||||
completed_at=datetime.utcnow(),
|
crud.update_registration_task(
|
||||||
error_message=str(e)
|
db, task_uuid,
|
||||||
)
|
status="failed",
|
||||||
|
completed_at=datetime.utcnow(),
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
# 更新 TaskManager 状态
|
# 更新 TaskManager 状态
|
||||||
task_manager.update_status(task_uuid, "failed", error=str(e))
|
task_manager.update_status(task_uuid, "failed", error=str(e))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -904,6 +1002,11 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
task_manager.set_loop(loop)
|
task_manager.set_loop(loop)
|
||||||
|
|
||||||
|
if task_manager.is_cancelled(task_uuid):
|
||||||
|
with get_db() as db:
|
||||||
|
_finalize_task_cancelled(db, task_uuid)
|
||||||
|
return
|
||||||
|
|
||||||
# 初始化 TaskManager 状态
|
# 初始化 TaskManager 状态
|
||||||
task_manager.update_status(task_uuid, "pending", email_service=email_service_type)
|
task_manager.update_status(task_uuid, "pending", email_service=email_service_type)
|
||||||
task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列")
|
task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列")
|
||||||
@@ -1635,12 +1738,7 @@ async def get_batch_status(batch_id: str):
|
|||||||
@router.post("/batch/{batch_id}/cancel")
|
@router.post("/batch/{batch_id}/cancel")
|
||||||
async def cancel_batch(batch_id: str):
|
async def cancel_batch(batch_id: str):
|
||||||
"""取消批量任务"""
|
"""取消批量任务"""
|
||||||
batch = _require_batch_snapshot(batch_id)
|
return _request_batch_cancellation(batch_id)
|
||||||
if batch.get("finished"):
|
|
||||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
|
||||||
|
|
||||||
task_manager.cancel_batch(batch_id)
|
|
||||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tasks", response_model=TaskListResponse)
|
@router.get("/tasks", response_model=TaskListResponse)
|
||||||
@@ -1703,7 +1801,7 @@ async def cancel_task(task_uuid: str):
|
|||||||
if task.status not in ["pending", "running"]:
|
if task.status not in ["pending", "running"]:
|
||||||
raise HTTPException(status_code=400, detail="任务已完成或已取消")
|
raise HTTPException(status_code=400, detail="任务已完成或已取消")
|
||||||
|
|
||||||
task = crud.update_registration_task(db, task_uuid, status="cancelled")
|
_request_task_cancellation(db, task_uuid)
|
||||||
|
|
||||||
return {"success": True, "message": "任务已取消"}
|
return {"success": True, "message": "任务已取消"}
|
||||||
|
|
||||||
@@ -2207,10 +2305,4 @@ async def get_outlook_batch_status(batch_id: str):
|
|||||||
@router.post("/outlook-batch/{batch_id}/cancel")
|
@router.post("/outlook-batch/{batch_id}/cancel")
|
||||||
async def cancel_outlook_batch(batch_id: str):
|
async def cancel_outlook_batch(batch_id: str):
|
||||||
"""取消 Outlook 批量任务"""
|
"""取消 Outlook 批量任务"""
|
||||||
batch = _require_batch_snapshot(batch_id)
|
return _request_batch_cancellation(batch_id)
|
||||||
if batch.get("finished"):
|
|
||||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
|
||||||
|
|
||||||
task_manager.cancel_batch(batch_id)
|
|
||||||
|
|
||||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
|
||||||
|
|||||||
@@ -86,7 +86,14 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
|
|||||||
|
|
||||||
# 处理取消请求
|
# 处理取消请求
|
||||||
elif data.get("type") == "cancel":
|
elif data.get("type") == "cancel":
|
||||||
task_manager.cancel_task(task_uuid)
|
from . import registration as registration_routes
|
||||||
|
|
||||||
|
with get_db() as db:
|
||||||
|
task = crud.get_registration_task(db, task_uuid)
|
||||||
|
if task and task.status in ["pending", "running"]:
|
||||||
|
registration_routes._request_task_cancellation(db, task_uuid)
|
||||||
|
else:
|
||||||
|
task_manager.cancel_task(task_uuid)
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
"type": "status",
|
"type": "status",
|
||||||
"task_uuid": task_uuid,
|
"task_uuid": task_uuid,
|
||||||
@@ -163,7 +170,9 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
|
|||||||
|
|
||||||
# 处理取消请求
|
# 处理取消请求
|
||||||
elif data.get("type") == "cancel":
|
elif data.get("type") == "cancel":
|
||||||
task_manager.cancel_batch(batch_id)
|
from . import registration as registration_routes
|
||||||
|
|
||||||
|
registration_routes._request_batch_cancellation(batch_id)
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
"type": "status",
|
"type": "status",
|
||||||
"batch_id": batch_id,
|
"batch_id": batch_id,
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
const test = require('node:test');
|
const test = require('node:test');
|
||||||
const assert = require('node:assert/strict');
|
const assert = require('node:assert/strict');
|
||||||
const fs = require('node:fs');
|
const fs = require('node:fs');
|
||||||
|
const path = require('node:path');
|
||||||
const vm = require('node:vm');
|
const vm = require('node:vm');
|
||||||
|
|
||||||
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
|
const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js');
|
||||||
|
|
||||||
function createElementStub() {
|
function createElementStub() {
|
||||||
return {
|
return {
|
||||||
|
|||||||
363
tests/test_registration_task_cancellation.py
Normal file
363
tests/test_registration_task_cancellation.py
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from src.core.register import ERROR_TASK_CANCELLED, RegistrationResult
|
||||||
|
from src.database.models import Base, RegistrationTask
|
||||||
|
from src.database.session import DatabaseSessionManager
|
||||||
|
from src.services import EmailServiceType
|
||||||
|
from src.services.base import BaseEmailService, EmailServiceCancelledError
|
||||||
|
from src.web.routes import registration as registration_routes
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTaskManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.cancelled = set()
|
||||||
|
self.status_updates = []
|
||||||
|
self.logs = {}
|
||||||
|
self.batch_status = {}
|
||||||
|
|
||||||
|
def is_cancelled(self, task_uuid):
|
||||||
|
return task_uuid in self.cancelled
|
||||||
|
|
||||||
|
def cancel_task(self, task_uuid):
|
||||||
|
self.cancelled.add(task_uuid)
|
||||||
|
|
||||||
|
def update_status(self, task_uuid, status, **kwargs):
|
||||||
|
self.status_updates.append((task_uuid, status, kwargs))
|
||||||
|
|
||||||
|
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
|
||||||
|
def callback(message):
|
||||||
|
full_message = f"{prefix} {message}" if prefix else message
|
||||||
|
self.logs.setdefault(task_uuid, []).append(full_message)
|
||||||
|
return callback
|
||||||
|
|
||||||
|
def create_check_cancelled_callback(self, task_uuid):
|
||||||
|
return lambda: self.is_cancelled(task_uuid)
|
||||||
|
|
||||||
|
def get_batch_status(self, batch_id):
|
||||||
|
snapshot = self.batch_status.get(batch_id)
|
||||||
|
return dict(snapshot) if snapshot else None
|
||||||
|
|
||||||
|
def cancel_batch(self, batch_id):
|
||||||
|
snapshot = self.batch_status.setdefault(batch_id, {})
|
||||||
|
snapshot["cancelled"] = True
|
||||||
|
snapshot["status"] = "cancelling"
|
||||||
|
|
||||||
|
def update_batch_status(self, batch_id, **kwargs):
|
||||||
|
snapshot = self.batch_status.setdefault(batch_id, {})
|
||||||
|
snapshot.update(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DummySettings:
|
||||||
|
proxy_dynamic_enabled = False
|
||||||
|
proxy_dynamic_api_url = ""
|
||||||
|
email_code_timeout = 10
|
||||||
|
email_code_poll_interval = 1
|
||||||
|
email_code_resend_max_retries = 0
|
||||||
|
email_code_non_openai_sender_resend_max_retries = 0
|
||||||
|
openai_client_id = "client-id"
|
||||||
|
openai_auth_url = "https://auth.example.test"
|
||||||
|
openai_token_url = "https://token.example.test"
|
||||||
|
openai_redirect_uri = "https://callback.example.test"
|
||||||
|
openai_scope = "openid profile email"
|
||||||
|
|
||||||
|
def get_proxy_url(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fake_get_db(manager):
|
||||||
|
@contextmanager
|
||||||
|
def fake_get_db():
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
return fake_get_db
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRegistrationEngine:
|
||||||
|
started_event = None
|
||||||
|
|
||||||
|
def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None):
|
||||||
|
self.email_service = email_service
|
||||||
|
self.phase_history = []
|
||||||
|
self.check_cancelled = None
|
||||||
|
self.callback_logger = callback_logger or (lambda _msg: None)
|
||||||
|
self.task_uuid = task_uuid
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
if self.started_event is not None:
|
||||||
|
self.started_event.set()
|
||||||
|
while True:
|
||||||
|
if callable(self.check_cancelled) and self.check_cancelled():
|
||||||
|
return RegistrationResult(
|
||||||
|
success=False,
|
||||||
|
error_message="任务已取消",
|
||||||
|
error_code=ERROR_TASK_CANCELLED,
|
||||||
|
logs=[],
|
||||||
|
)
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def save_to_database(self, result):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class FakePollingEmailService(BaseEmailService):
|
||||||
|
def __init__(self, started_event=None):
|
||||||
|
super().__init__(EmailServiceType.TEMPMAIL, "fake-polling-email")
|
||||||
|
self.started_event = started_event
|
||||||
|
|
||||||
|
def create_email(self, config=None):
|
||||||
|
return {"email": "poll@example.test", "service_id": "poll-service"}
|
||||||
|
|
||||||
|
def get_verification_code(self, email: str, email_id: str = None, timeout: int = 120, pattern: str = r"(?<!\d)(\d{6})(?!\d)", otp_sent_at=None):
|
||||||
|
if self.started_event is not None:
|
||||||
|
self.started_event.set()
|
||||||
|
while True:
|
||||||
|
self._raise_if_cancelled("任务已取消")
|
||||||
|
self._sleep_with_cancel(0.05)
|
||||||
|
|
||||||
|
def list_emails(self, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete_email(self, email_id: str):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_health(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class FakeEmailPollingRegistrationEngine:
|
||||||
|
def __init__(self, email_service, proxy_url=None, callback_logger=None, status_callback=None, task_uuid=None):
|
||||||
|
self.email_service = email_service
|
||||||
|
self.phase_history = []
|
||||||
|
self.check_cancelled = None
|
||||||
|
self.callback_logger = callback_logger or (lambda _msg: None)
|
||||||
|
self.task_uuid = task_uuid
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
self.email_service.get_verification_code("poll@example.test", timeout=60)
|
||||||
|
except EmailServiceCancelledError as exc:
|
||||||
|
return RegistrationResult(
|
||||||
|
success=False,
|
||||||
|
error_message=str(exc),
|
||||||
|
error_code=ERROR_TASK_CANCELLED,
|
||||||
|
logs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
return RegistrationResult(success=False, error_message="邮箱轮询未被取消", logs=[])
|
||||||
|
|
||||||
|
def save_to_database(self, result):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_task_route_marks_task_manager_and_db_cancelled(monkeypatch):
|
||||||
|
runtime_dir = Path("tests_runtime")
|
||||||
|
runtime_dir.mkdir(exist_ok=True)
|
||||||
|
db_path = runtime_dir / "registration_cancel_route.db"
|
||||||
|
if db_path.exists():
|
||||||
|
db_path.unlink()
|
||||||
|
|
||||||
|
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||||
|
Base.metadata.create_all(bind=manager.engine)
|
||||||
|
|
||||||
|
task_uuid = "task-cancel-route"
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
session.add(RegistrationTask(task_uuid=task_uuid, status="running"))
|
||||||
|
|
||||||
|
fake_task_manager = FakeTaskManager()
|
||||||
|
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||||
|
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||||
|
|
||||||
|
response = asyncio.run(registration_routes.cancel_task(task_uuid))
|
||||||
|
|
||||||
|
assert response == {"success": True, "message": "任务已取消"}
|
||||||
|
assert task_uuid in fake_task_manager.cancelled
|
||||||
|
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||||
|
assert task.status == "cancelled"
|
||||||
|
assert task.error_message == "任务已取消"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_sync_registration_task_stops_after_cancel_request(monkeypatch):
|
||||||
|
runtime_dir = Path("tests_runtime")
|
||||||
|
runtime_dir.mkdir(exist_ok=True)
|
||||||
|
db_path = runtime_dir / "registration_cancel_runtime.db"
|
||||||
|
if db_path.exists():
|
||||||
|
db_path.unlink()
|
||||||
|
|
||||||
|
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||||
|
Base.metadata.create_all(bind=manager.engine)
|
||||||
|
|
||||||
|
task_uuid = "task-cancel-runtime"
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||||
|
|
||||||
|
fake_task_manager = FakeTaskManager()
|
||||||
|
start_event = threading.Event()
|
||||||
|
FakeRegistrationEngine.started_event = start_event
|
||||||
|
|
||||||
|
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||||
|
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||||
|
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||||
|
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
registration_routes,
|
||||||
|
"_build_email_service_candidates",
|
||||||
|
lambda db, service_type, actual_proxy_url, email_service_id, email_service_config: [
|
||||||
|
{
|
||||||
|
"service_type": EmailServiceType.TEMPMAIL,
|
||||||
|
"config": {"proxy_url": actual_proxy_url},
|
||||||
|
"db_service": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
registration_routes.EmailServiceFactory,
|
||||||
|
"create",
|
||||||
|
lambda service_type, config, name=None: SimpleNamespace(
|
||||||
|
service_type=service_type,
|
||||||
|
name=name or service_type.value,
|
||||||
|
config=config,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
worker = threading.Thread(
|
||||||
|
target=registration_routes._run_sync_registration_task,
|
||||||
|
kwargs={
|
||||||
|
"task_uuid": task_uuid,
|
||||||
|
"email_service_type": EmailServiceType.TEMPMAIL.value,
|
||||||
|
"proxy": None,
|
||||||
|
"email_service_config": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
worker.start()
|
||||||
|
assert start_event.wait(timeout=1.0), "registration engine did not start in time"
|
||||||
|
|
||||||
|
response = asyncio.run(registration_routes.cancel_task(task_uuid))
|
||||||
|
assert response == {"success": True, "message": "任务已取消"}
|
||||||
|
|
||||||
|
worker.join(timeout=2.0)
|
||||||
|
assert not worker.is_alive(), "registration worker should stop after cancellation"
|
||||||
|
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||||
|
assert task.status == "cancelled"
|
||||||
|
assert task.error_message == "任务已取消"
|
||||||
|
|
||||||
|
statuses = [status for current_uuid, status, _kwargs in fake_task_manager.status_updates if current_uuid == task_uuid]
|
||||||
|
assert "cancelled" in statuses
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_batch_propagates_to_member_tasks(monkeypatch):
|
||||||
|
runtime_dir = Path("tests_runtime")
|
||||||
|
runtime_dir.mkdir(exist_ok=True)
|
||||||
|
db_path = runtime_dir / "registration_cancel_batch.db"
|
||||||
|
if db_path.exists():
|
||||||
|
db_path.unlink()
|
||||||
|
|
||||||
|
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||||
|
Base.metadata.create_all(bind=manager.engine)
|
||||||
|
|
||||||
|
task_uuids = ["batch-cancel-1", "batch-cancel-2"]
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
session.add_all([
|
||||||
|
RegistrationTask(task_uuid=task_uuids[0], status="running"),
|
||||||
|
RegistrationTask(task_uuid=task_uuids[1], status="pending"),
|
||||||
|
])
|
||||||
|
|
||||||
|
fake_task_manager = FakeTaskManager()
|
||||||
|
fake_task_manager.batch_status["batch-1"] = {
|
||||||
|
"finished": False,
|
||||||
|
"cancelled": False,
|
||||||
|
"task_uuids": task_uuids,
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||||
|
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||||
|
|
||||||
|
response = asyncio.run(registration_routes.cancel_batch("batch-1"))
|
||||||
|
|
||||||
|
assert response["success"] is True
|
||||||
|
assert fake_task_manager.batch_status["batch-1"]["cancelled"] is True
|
||||||
|
assert fake_task_manager.cancelled == set(task_uuids)
|
||||||
|
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
tasks = session.query(RegistrationTask).order_by(RegistrationTask.task_uuid.asc()).all()
|
||||||
|
assert [task.status for task in tasks] == ["cancelled", "cancelled"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_sync_registration_task_stops_while_email_service_polling(monkeypatch):
|
||||||
|
runtime_dir = Path("tests_runtime")
|
||||||
|
runtime_dir.mkdir(exist_ok=True)
|
||||||
|
db_path = runtime_dir / "registration_cancel_email_polling.db"
|
||||||
|
if db_path.exists():
|
||||||
|
db_path.unlink()
|
||||||
|
|
||||||
|
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||||
|
Base.metadata.create_all(bind=manager.engine)
|
||||||
|
|
||||||
|
task_uuid = "task-cancel-email-polling"
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||||
|
|
||||||
|
fake_task_manager = FakeTaskManager()
|
||||||
|
start_event = threading.Event()
|
||||||
|
|
||||||
|
monkeypatch.setattr(registration_routes, "get_db", _build_fake_get_db(manager))
|
||||||
|
monkeypatch.setattr(registration_routes, "task_manager", fake_task_manager)
|
||||||
|
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||||
|
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeEmailPollingRegistrationEngine)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
registration_routes,
|
||||||
|
"_build_email_service_candidates",
|
||||||
|
lambda db, service_type, actual_proxy_url, email_service_id, email_service_config: [
|
||||||
|
{
|
||||||
|
"service_type": EmailServiceType.TEMPMAIL,
|
||||||
|
"config": {"proxy_url": actual_proxy_url},
|
||||||
|
"db_service": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
registration_routes.EmailServiceFactory,
|
||||||
|
"create",
|
||||||
|
lambda service_type, config, name=None: FakePollingEmailService(start_event),
|
||||||
|
)
|
||||||
|
|
||||||
|
worker = threading.Thread(
|
||||||
|
target=registration_routes._run_sync_registration_task,
|
||||||
|
kwargs={
|
||||||
|
"task_uuid": task_uuid,
|
||||||
|
"email_service_type": EmailServiceType.TEMPMAIL.value,
|
||||||
|
"proxy": None,
|
||||||
|
"email_service_config": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
worker.start()
|
||||||
|
assert start_event.wait(timeout=1.0), "email polling did not start in time"
|
||||||
|
|
||||||
|
response = asyncio.run(registration_routes.cancel_task(task_uuid))
|
||||||
|
assert response == {"success": True, "message": "任务已取消"}
|
||||||
|
|
||||||
|
worker.join(timeout=2.0)
|
||||||
|
assert not worker.is_alive(), "registration worker should stop while email service is polling"
|
||||||
|
|
||||||
|
with manager.session_scope() as session:
|
||||||
|
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||||
|
assert task.status == "cancelled"
|
||||||
|
assert task.error_message == "任务已取消"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
const test = require('node:test');
|
const test = require('node:test');
|
||||||
const assert = require('node:assert/strict');
|
const assert = require('node:assert/strict');
|
||||||
const fs = require('node:fs');
|
const fs = require('node:fs');
|
||||||
|
const path = require('node:path');
|
||||||
const vm = require('node:vm');
|
const vm = require('node:vm');
|
||||||
|
|
||||||
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
|
const APP_JS_PATH = path.join(__dirname, '..', 'static', 'js', 'app.js');
|
||||||
|
|
||||||
function createElementStub() {
|
function createElementStub() {
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user