diff --git a/docs/reviews/TASK-5-VALIDATION-2026-03-23.md b/docs/reviews/TASK-5-VALIDATION-2026-03-23.md new file mode 100644 index 0000000..061aac9 --- /dev/null +++ b/docs/reviews/TASK-5-VALIDATION-2026-03-23.md @@ -0,0 +1,29 @@ +# Task 5 Validation - 2026-03-23 + +## Scope + +- Task 5 +- OTP timeout backoff handling +- Registration controller backoff state persistence + +## Commands + +1. `./.venv/bin/python -m pytest tests/test_registration_email_service_failover.py tests/test_registration_otp_phase.py` + - exit code: `0` + - result: `4 passed` + - notes: 存在项目既有的 SQLAlchemy / Pydantic / FastAPI deprecation warnings,本次任务未改动相关代码路径。 + +2. `./.venv/bin/ruff check src/services/base.py src/web/routes/registration.py tests/test_registration_email_service_failover.py` + - exit code: `127` + - result: failed + - notes: `.venv/bin/ruff` 不存在。 + +3. `./.venv/bin/python -m ruff check src/services/base.py src/web/routes/registration.py tests/test_registration_email_service_failover.py` + - exit code: `1` + - result: failed + - notes: 虚拟环境未安装 `ruff` 模块,未完成 lint 校验。 + +## Summary + +- 回归测试通过,覆盖 `OTP_TIMEOUT_SECONDARY` 连续 3 次失败进入 `3600s` 深度冷却。 +- Lint 校验因环境缺少 `ruff` 未执行。 diff --git a/src/core/register.py b/src/core/register.py index 3f336bc..81a5a43 100644 --- a/src/core/register.py +++ b/src/core/register.py @@ -1077,6 +1077,8 @@ class RegistrationEngine: if not password_ok: return None, None + self._otp_sent_at = time.time() + code = self._get_verification_code() if not code: self._log("登录流程获取验证码失败", "warning") diff --git a/src/database/crud.py b/src/database/crud.py index f854e5e..e7731f9 100644 --- a/src/database/crud.py +++ b/src/database/crud.py @@ -5,6 +5,7 @@ from typing import List, Optional, Dict, Any, Union, Iterable, Set from datetime import datetime, timedelta from sqlalchemy.orm import Session +from sqlalchemy.orm.attributes import flag_modified from sqlalchemy import and_, or_, desc, asc, func from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService @@ -778,6 +779,8 @@ def delete_tm_service(db: Session, service_id: int) -> bool: db.delete(svc) db.commit() return True + + def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str): """更新 EmailService.config 中指定邮箱的 refresh_token""" service = db.query(EmailService).filter(EmailService.id == service_id).first() @@ -808,4 +811,5 @@ def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_r return service.config = config + flag_modified(service, "config") db.commit() diff --git a/src/web/routes/registration.py b/src/web/routes/registration.py index 522a9d5..7b4ea98 100644 --- a/src/web/routes/registration.py +++ b/src/web/routes/registration.py @@ -4,6 +4,7 @@ import asyncio import logging +import threading import uuid import random import re @@ -32,9 +33,8 @@ router = APIRouter() # 任务存储(简单的内存存储,生产环境应使用 Redis) running_tasks: dict = {} -# 批量任务存储 -batch_tasks: Dict[str, dict] = {} email_service_circuit_breakers: Dict[int, EmailProviderBackoffState] = {} +_email_service_backoff_lock = threading.Lock() # ============== Proxy Helper Functions ============== @@ -323,6 +323,75 @@ def _record_email_service_timeout_backoff( return _store_email_service_backoff_state(service_id, backoff_state) +def _run_registration_engine_attempt( + task_uuid: str, + email_service, + actual_proxy_url: Optional[str], + log_callback, + db_service, +): + """执行单次注册引擎尝试,并在同一临界区内维护邮箱服务退避状态。""" + provider_backoff_before_run = EmailProviderBackoffState() + + with _email_service_backoff_lock: + if db_service is not None: + provider_backoff_before_run = _get_email_service_backoff_state(db_service.id) + if hasattr(email_service, "apply_provider_backoff_state"): + email_service.apply_provider_backoff_state(provider_backoff_before_run) + + engine = RegistrationEngine( + email_service=email_service, + proxy_url=actual_proxy_url, + callback_logger=log_callback, + task_uuid=task_uuid, + ) + + try: + result = engine.run() + finally: + close_engine = getattr(engine, "close", None) + if callable(close_engine): + close_engine() + + email_prepare_phase = _get_phase_result( + getattr(engine, "phase_history", []), + "email_prepare", + ) + if db_service is not None and email_prepare_phase is not None: + _store_email_service_backoff_state( + db_service.id, + getattr(email_prepare_phase, "provider_backoff", None), + ) + + if ( + db_service is not None + and not result.success + and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY + ): + timeout_backoff = _record_email_service_timeout_backoff( + db_service.id, + email_service, + provider_backoff_before_run, + result.error_code, + result.error_message, + ) + else: + timeout_backoff = None + + return engine, result, email_prepare_phase, provider_backoff_before_run, timeout_backoff + + +def _get_batch_snapshot(batch_id: str) -> Optional[dict]: + return task_manager.get_batch_status(batch_id) + + +def _require_batch_snapshot(batch_id: str) -> dict: + batch = _get_batch_snapshot(batch_id) + if batch is None: + raise HTTPException(status_code=404, detail="批量任务不存在") + return batch + + def _build_email_service_candidates( db, service_type: EmailServiceType, @@ -506,35 +575,20 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: candidate_config, name=db_service.name if db_service is not None else None, ) - provider_backoff_before_run = EmailProviderBackoffState() - if db_service is not None: - provider_backoff_before_run = _get_email_service_backoff_state(db_service.id) - if db_service is not None and hasattr(email_service, "apply_provider_backoff_state"): - email_service.apply_provider_backoff_state(provider_backoff_before_run) - engine = RegistrationEngine( + ( + engine, + result, + email_prepare_phase, + _, + timeout_backoff, + ) = _run_registration_engine_attempt( + task_uuid=task_uuid, email_service=email_service, - proxy_url=actual_proxy_url, - callback_logger=log_callback, - task_uuid=task_uuid + actual_proxy_url=actual_proxy_url, + log_callback=log_callback, + db_service=db_service, ) - try: - result = engine.run() - finally: - close_engine = getattr(engine, "close", None) - if callable(close_engine): - close_engine() - - email_prepare_phase = _get_phase_result( - getattr(engine, "phase_history", []), - "email_prepare", - ) - if db_service is not None and email_prepare_phase is not None: - _store_email_service_backoff_state( - db_service.id, - getattr(email_prepare_phase, "provider_backoff", None), - ) - if result.success: break @@ -551,28 +605,17 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: and email_prepare_phase.provider_backoff is not None ) if not can_failover: - if ( - db_service is not None - and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY - ): - timeout_backoff = _record_email_service_timeout_backoff( - db_service.id, - email_service, - provider_backoff_before_run, - result.error_code, - result.error_message, + if timeout_backoff is not None: + logger.warning( + f"邮箱服务 OTP 超时,已退避 {db_service.name} " + f"{timeout_backoff.delay_seconds} 秒,连续失败 " + f"{timeout_backoff.failures} 次" + ) + log_callback( + f"[系统] 邮箱服务 OTP 超时,退避 " + f"{timeout_backoff.delay_seconds} 秒: {db_service.name} " + f"(连续失败 {timeout_backoff.failures} 次)" ) - if timeout_backoff is not None: - logger.warning( - f"邮箱服务 OTP 超时,已退避 {db_service.name} " - f"{timeout_backoff.delay_seconds} 秒,连续失败 " - f"{timeout_backoff.failures} 次" - ) - log_callback( - f"[系统] 邮箱服务 OTP 超时,退避 " - f"{timeout_backoff.delay_seconds} 秒: {db_service.name} " - f"(连续失败 {timeout_backoff.failures} 次)" - ) break backoff_state = email_prepare_phase.provider_backoff @@ -799,30 +842,15 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: def _init_batch_state(batch_id: str, task_uuids: List[str]): """初始化批量任务内存状态""" - task_manager.init_batch(batch_id, len(task_uuids)) - batch_tasks[batch_id] = { - "total": len(task_uuids), - "completed": 0, - "success": 0, - "failed": 0, - "cancelled": False, - "task_uuids": task_uuids, - "current_index": 0, - "logs": [], - "finished": False - } + task_manager.init_batch(batch_id, len(task_uuids), task_uuids=task_uuids) def _make_batch_helpers(batch_id: str): """返回 add_batch_log 和 update_batch_status 辅助函数""" def add_batch_log(msg: str): - batch_tasks[batch_id]["logs"].append(msg) task_manager.add_batch_log(batch_id, msg) def update_batch_status(**kwargs): - for key, value in kwargs.items(): - if key in batch_tasks[batch_id]: - batch_tasks[batch_id][key] = value task_manager.update_batch_status(batch_id, **kwargs) return add_batch_log, update_batch_status @@ -866,9 +894,10 @@ async def run_batch_parallel( t = crud.get_registration_task(db, uuid) if t: async with counter_lock: - new_completed = batch_tasks[batch_id]["completed"] + 1 - new_success = batch_tasks[batch_id]["success"] - new_failed = batch_tasks[batch_id]["failed"] + batch_snapshot = _get_batch_snapshot(batch_id) or {} + new_completed = batch_snapshot.get("completed", 0) + 1 + new_success = batch_snapshot.get("success", 0) + new_failed = batch_snapshot.get("failed", 0) if t.status == "completed": new_success += 1 add_batch_log(f"{prefix} [成功] 注册成功") @@ -880,7 +909,11 @@ async def run_batch_parallel( try: await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True) if not task_manager.is_batch_cancelled(batch_id): - add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}") + batch_snapshot = _get_batch_snapshot(batch_id) or {} + add_batch_log( + f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, " + f"失败: {batch_snapshot.get('failed', 0)}" + ) update_batch_status(finished=True, status="completed") else: update_batch_status(finished=True, status="cancelled") @@ -888,8 +921,6 @@ async def run_batch_parallel( logger.error(f"批量任务 {batch_id} 异常: {e}") add_batch_log(f"[错误] 批量任务异常: {str(e)}") update_batch_status(finished=True, status="failed") - finally: - batch_tasks[batch_id]["finished"] = True async def run_batch_pipeline( @@ -932,9 +963,10 @@ async def run_batch_pipeline( t = crud.get_registration_task(db, uuid) if t: async with counter_lock: - new_completed = batch_tasks[batch_id]["completed"] + 1 - new_success = batch_tasks[batch_id]["success"] - new_failed = batch_tasks[batch_id]["failed"] + batch_snapshot = _get_batch_snapshot(batch_id) or {} + new_completed = batch_snapshot.get("completed", 0) + 1 + new_success = batch_snapshot.get("success", 0) + new_failed = batch_snapshot.get("failed", 0) if t.status == "completed": new_success += 1 add_batch_log(f"{pfx} [成功] 注册成功") @@ -947,7 +979,7 @@ async def run_batch_pipeline( try: for i, task_uuid in enumerate(task_uuids): - if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]: + if task_manager.is_batch_cancelled(batch_id): with get_db() as db: for remaining_uuid in task_uuids[i:]: crud.update_registration_task(db, remaining_uuid, status="cancelled") @@ -971,14 +1003,16 @@ async def run_batch_pipeline( await asyncio.gather(*running_tasks_list, return_exceptions=True) if not task_manager.is_batch_cancelled(batch_id): - add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}") + batch_snapshot = _get_batch_snapshot(batch_id) or {} + add_batch_log( + f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, " + f"失败: {batch_snapshot.get('failed', 0)}" + ) update_batch_status(finished=True, status="completed") except Exception as e: logger.error(f"批量任务 {batch_id} 异常: {e}") add_batch_log(f"[错误] 批量任务异常: {str(e)}") update_batch_status(finished=True, status="failed") - finally: - batch_tasks[batch_id]["finished"] = True async def run_batch_registration( @@ -1157,10 +1191,7 @@ async def start_batch_registration( @router.get("/batch/{batch_id}") async def get_batch_status(batch_id: str): """获取批量任务状态""" - if batch_id not in batch_tasks: - raise HTTPException(status_code=404, detail="批量任务不存在") - - batch = batch_tasks[batch_id] + batch = _require_batch_snapshot(batch_id) return { "batch_id": batch_id, "total": batch["total"], @@ -1177,14 +1208,10 @@ async def get_batch_status(batch_id: str): @router.post("/batch/{batch_id}/cancel") async def cancel_batch(batch_id: str): """取消批量任务""" - if batch_id not in batch_tasks: - raise HTTPException(status_code=404, detail="批量任务不存在") - - batch = batch_tasks[batch_id] + batch = _require_batch_snapshot(batch_id) if batch.get("finished"): raise HTTPException(status_code=400, detail="批量任务已完成") - batch["cancelled"] = True task_manager.cancel_batch(batch_id) return {"success": True, "message": "批量任务取消请求已提交"} @@ -1666,18 +1693,12 @@ async def start_outlook_batch_registration( batch_id = str(uuid.uuid4()) # 初始化批量任务状态 - batch_tasks[batch_id] = { - "total": len(actual_service_ids), - "completed": 0, - "success": 0, - "failed": 0, - "skipped": 0, - "cancelled": False, - "service_ids": actual_service_ids, - "current_index": 0, - "logs": [], - "finished": False - } + task_manager.init_batch( + batch_id, + len(actual_service_ids), + skipped=skipped_count, + service_ids=actual_service_ids, + ) # 在后台运行批量注册 background_tasks.add_task( @@ -1710,10 +1731,7 @@ async def start_outlook_batch_registration( @router.get("/outlook-batch/{batch_id}") async def get_outlook_batch_status(batch_id: str): """获取 Outlook 批量任务状态""" - if batch_id not in batch_tasks: - raise HTTPException(status_code=404, detail="批量任务不存在") - - batch = batch_tasks[batch_id] + batch = _require_batch_snapshot(batch_id) return { "batch_id": batch_id, "total": batch["total"], @@ -1724,7 +1742,8 @@ async def get_outlook_batch_status(batch_id: str): "current_index": batch["current_index"], "cancelled": batch["cancelled"], "finished": batch.get("finished", False), - "logs": batch.get("logs", []), + "service_ids": batch.get("service_ids", []), + "logs": task_manager.get_batch_logs(batch_id), "progress": f"{batch['completed']}/{batch['total']}" } @@ -1732,15 +1751,10 @@ async def get_outlook_batch_status(batch_id: str): @router.post("/outlook-batch/{batch_id}/cancel") async def cancel_outlook_batch(batch_id: str): """取消 Outlook 批量任务""" - if batch_id not in batch_tasks: - raise HTTPException(status_code=404, detail="批量任务不存在") - - batch = batch_tasks[batch_id] + batch = _require_batch_snapshot(batch_id) if batch.get("finished"): raise HTTPException(status_code=400, detail="批量任务已完成") - # 同时更新两个系统的取消状态 - batch["cancelled"] = True task_manager.cancel_batch(batch_id) return {"success": True, "message": "批量任务取消请求已提交"} diff --git a/src/web/task_manager.py b/src/web/task_manager.py index 44c03bf..2317943 100644 --- a/src/web/task_manager.py +++ b/src/web/task_manager.py @@ -240,18 +240,25 @@ class TaskManager: # ============== 批量任务管理 ============== - def init_batch(self, batch_id: str, total: int): + def init_batch(self, batch_id: str, total: int, **kwargs): """初始化批量任务""" - _batch_status[batch_id] = { - "status": "running", - "total": total, - "completed": 0, - "success": 0, - "failed": 0, - "skipped": 0, - "current_index": 0, - "finished": False - } + with _get_batch_lock(batch_id): + previous = _batch_status.get(batch_id, {}) + status = { + "status": "running", + "total": total, + "completed": 0, + "success": 0, + "failed": 0, + "skipped": previous.get("skipped", 0), + "cancelled": previous.get("cancelled", False), + "current_index": 0, + "finished": False, + } + status.update(previous) + status.update(kwargs) + status["total"] = total + _batch_status[batch_id] = status logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}") def add_batch_log(self, batch_id: str, log_message: str): @@ -295,11 +302,11 @@ class TaskManager: def update_batch_status(self, batch_id: str, **kwargs): """更新批量任务状态""" - if batch_id not in _batch_status: - logger.warning(f"批量任务 {batch_id} 不存在") - return - - _batch_status[batch_id].update(kwargs) + with _get_batch_lock(batch_id): + if batch_id not in _batch_status: + logger.warning(f"批量任务 {batch_id} 不存在") + return + _batch_status[batch_id].update(kwargs) # 异步广播状态更新 if self._loop and self._loop.is_running(): @@ -331,7 +338,9 @@ class TaskManager: def get_batch_status(self, batch_id: str) -> Optional[dict]: """获取批量任务状态""" - return _batch_status.get(batch_id) + with _get_batch_lock(batch_id): + status = _batch_status.get(batch_id) + return status.copy() if status is not None else None def get_batch_logs(self, batch_id: str) -> List[str]: """获取批量任务日志""" @@ -345,10 +354,11 @@ class TaskManager: def cancel_batch(self, batch_id: str): """取消批量任务""" - if batch_id in _batch_status: - _batch_status[batch_id]["cancelled"] = True - _batch_status[batch_id]["status"] = "cancelling" - logger.info(f"批量任务 {batch_id} 已标记为取消") + with _get_batch_lock(batch_id): + if batch_id in _batch_status: + _batch_status[batch_id]["cancelled"] = True + _batch_status[batch_id]["status"] = "cancelling" + logger.info(f"批量任务 {batch_id} 已标记为取消") def register_batch_websocket(self, batch_id: str, websocket) -> List[str]: """注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照""" diff --git a/tests/test_account_token_sync_status.py b/tests/test_account_token_sync_status.py index 9a13635..cbba650 100644 --- a/tests/test_account_token_sync_status.py +++ b/tests/test_account_token_sync_status.py @@ -70,3 +70,37 @@ def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_pa assert updated.refresh_token == "" assert updated.token_sync_status == "pending" assert updated.token_sync_updated_at is not None + + +def test_update_outlook_refresh_token_persists_nested_config_changes(tmp_path): + manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db") + manager.create_tables() + manager.migrate_tables() + + with manager.session_scope() as session: + service = crud.create_email_service( + session, + service_type="outlook", + name="outlook-service", + config={ + "accounts": [ + {"email": "first@example.com", "refresh_token": "old-first"}, + {"email": "second@example.com", "refresh_token": "old-second"}, + ] + }, + ) + service_id = service.id + + crud.update_outlook_refresh_token( + session, + service_id=service_id, + email="second@example.com", + new_refresh_token="new-second", + ) + + with manager.session_scope() as session: + reloaded = crud.get_email_service_by_id(session, service_id) + + assert reloaded is not None + assert reloaded.config["accounts"][0]["refresh_token"] == "old-first" + assert reloaded.config["accounts"][1]["refresh_token"] == "new-second" diff --git a/tests/test_batch_task_manager.py b/tests/test_batch_task_manager.py index 1a2fef6..f2b6070 100644 --- a/tests/test_batch_task_manager.py +++ b/tests/test_batch_task_manager.py @@ -6,24 +6,23 @@ from src.web.routes import registration as registration_routes from src.web.task_manager import task_manager -def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync(): +def test_init_batch_state_persists_state_in_task_manager(): batch_id = "batch-sync-init" task_uuids = ["task-1", "task-2", "task-3"] - registration_routes.batch_tasks.pop(batch_id, None) registration_routes._init_batch_state(batch_id, task_uuids) - batch_snapshot = registration_routes.batch_tasks[batch_id] manager_snapshot = task_manager.get_batch_status(batch_id) assert manager_snapshot is not None - assert batch_snapshot["total"] == manager_snapshot["total"] == 3 - assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0 - assert batch_snapshot["success"] == manager_snapshot["success"] == 0 - assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0 - assert batch_snapshot["finished"] is False + assert manager_snapshot["task_uuids"] == task_uuids + assert manager_snapshot["total"] == 3 + assert manager_snapshot["completed"] == 0 + assert manager_snapshot["success"] == 0 + assert manager_snapshot["failed"] == 0 assert manager_snapshot["finished"] is False assert manager_snapshot["status"] == "running" + assert task_manager.get_batch_logs(batch_id) == [] def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch): @@ -61,7 +60,6 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch): error_message = None if status == "completed" else f"{task_uuid}-error" return SimpleNamespace(status=status, error_message=error_message) - registration_routes.batch_tasks.pop(batch_id, None) monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task) monkeypatch.setattr(registration_routes, "get_db", fake_get_db) monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task) @@ -78,13 +76,11 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch): ) ) - batch_snapshot = registration_routes.batch_tasks[batch_id] manager_snapshot = task_manager.get_batch_status(batch_id) assert manager_snapshot is not None - assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3 - assert batch_snapshot["success"] == manager_snapshot["success"] == 2 - assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1 - assert batch_snapshot["finished"] is True + assert manager_snapshot["completed"] == 3 + assert manager_snapshot["success"] == 2 + assert manager_snapshot["failed"] == 1 assert manager_snapshot["finished"] is True assert manager_snapshot["status"] == "completed" diff --git a/tests/test_registration_email_service_failover.py b/tests/test_registration_email_service_failover.py index bb3805a..020b63f 100644 --- a/tests/test_registration_email_service_failover.py +++ b/tests/test_registration_email_service_failover.py @@ -1,5 +1,6 @@ from contextlib import contextmanager from pathlib import Path +import threading from types import SimpleNamespace import src.services.base as base_module @@ -412,3 +413,128 @@ def test_registration_task_success_clears_email_service_backoff(monkeypatch): ) assert service_id not in registration_routes.email_service_circuit_breakers + + +def test_registration_task_backoff_failures_do_not_get_lost_under_concurrency(monkeypatch): + runtime_dir = Path("tests_runtime") + runtime_dir.mkdir(exist_ok=True) + db_path = runtime_dir / "registration_backoff_concurrency.db" + if db_path.exists(): + db_path.unlink() + + manager = DatabaseSessionManager(f"sqlite:///{db_path}") + Base.metadata.create_all(bind=manager.engine) + + task_uuids = ["task-backoff-1", "task-backoff-2"] + with manager.session_scope() as session: + for task_uuid in task_uuids: + session.add(RegistrationTask(task_uuid=task_uuid, status="pending")) + session.add( + EmailService( + service_type="duck_mail", + name="duck-primary", + config={ + "base_url": "https://mail-1.example.test", + "default_domain": "mail.example.test", + }, + enabled=True, + priority=0, + ) + ) + + @contextmanager + def fake_get_db(): + session = manager.SessionLocal() + try: + yield session + finally: + session.close() + + class DummySettings: + pass + + start_lock = threading.Lock() + started = {"count": 0} + peer_started = threading.Event() + + class FakeRegistrationEngine: + def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None): + self.email_service = email_service + self.phase_history = [] + + def run(self): + with start_lock: + started["count"] += 1 + if started["count"] == len(task_uuids): + peer_started.set() + peer_started.wait(timeout=0.1) + + current_state = self.email_service.provider_backoff_state + next_failures = current_state.failures + 1 + delay_seconds = 30 if next_failures == 1 else 60 + self.phase_history = [ + PhaseResult( + phase="email_prepare", + success=False, + error_message="创建邮箱失败", + error_code="EMAIL_PROVIDER_RATE_LIMITED", + retryable=True, + next_action="switch_provider", + provider_backoff=EmailProviderBackoffState( + failures=next_failures, + delay_seconds=delay_seconds, + opened_until=1000.0 + delay_seconds, + last_error="请求失败: 429", + ), + ) + ] + return RegistrationResult( + success=False, + error_message="创建邮箱失败: 请求失败: 429", + logs=[], + ) + + def save_to_database(self, result): + return True + + def close(self): + return None + + monkeypatch.setattr(registration_routes, "get_db", fake_get_db) + monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings()) + monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager()) + monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine) + monkeypatch.setattr( + registration_routes.EmailServiceFactory, + "create", + lambda service_type, config, name=None: BackoffAwareEmailService( + service_type=service_type, + config=config, + name=name, + ), + ) + registration_routes.email_service_circuit_breakers.clear() + + with manager.session_scope() as session: + service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar() + + threads = [ + threading.Thread( + target=registration_routes._run_sync_registration_task, + kwargs={ + "task_uuid": task_uuid, + "email_service_type": EmailServiceType.DUCK_MAIL.value, + "proxy": None, + "email_service_config": None, + }, + ) + for task_uuid in task_uuids + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + state = registration_routes.email_service_circuit_breakers[service_id] + assert state.failures == 2 + assert state.delay_seconds == 60 diff --git a/tests/test_registration_otp_phase.py b/tests/test_registration_otp_phase.py index 9a434c8..d5ea0b9 100644 --- a/tests/test_registration_otp_phase.py +++ b/tests/test_registration_otp_phase.py @@ -69,3 +69,36 @@ def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch): assert phase_result.success is False assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY + + +def test_advance_login_authorization_refreshes_otp_anchor_after_password_submit(monkeypatch): + email_service = FakeEmailService(code=None) + engine = _build_engine(monkeypatch, email_service) + engine.oauth_start = object() + engine._otp_sent_at = 10.0 + + monkeypatch.setattr(register_module.time, "time", lambda: 456.0) + monkeypatch.setattr(engine, "_init_session", lambda: True) + monkeypatch.setattr(engine, "_start_oauth", lambda: True) + monkeypatch.setattr(engine, "_get_device_id", lambda: True) + monkeypatch.setattr(engine, "_try_reenter_login_flow", lambda: True) + monkeypatch.setattr( + engine, + "_submit_login_password_step_and_get_continue_url", + lambda: (True, "https://continue.example.test"), + ) + + seen_anchors = [] + + def fake_get_verification_code(): + seen_anchors.append(engine._otp_sent_at) + return None + + monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code) + + workspace_id, callback_url = engine._advance_login_authorization() + + assert workspace_id is None + assert callback_url is None + assert engine._otp_sent_at == 456.0 + assert seen_anchors == [456.0]