Fix registration OTP anchor and batch task state

This commit is contained in:
Mison
2026-03-24 07:20:38 +08:00
parent b8b1eb72d1
commit 78f2d0accc
9 changed files with 393 additions and 145 deletions

View File

@@ -0,0 +1,29 @@
# Task 5 Validation - 2026-03-23
## Scope
- Task 5
- OTP timeout backoff handling
- Registration controller backoff state persistence
## Commands
1. `./.venv/bin/python -m pytest tests/test_registration_email_service_failover.py tests/test_registration_otp_phase.py`
- exit code: `0`
- result: `4 passed`
- notes: 存在项目既有的 SQLAlchemy / Pydantic / FastAPI deprecation warnings本次任务未改动相关代码路径。
2. `./.venv/bin/ruff check src/services/base.py src/web/routes/registration.py tests/test_registration_email_service_failover.py`
- exit code: `127`
- result: failed
- notes: `.venv/bin/ruff` 不存在。
3. `./.venv/bin/python -m ruff check src/services/base.py src/web/routes/registration.py tests/test_registration_email_service_failover.py`
- exit code: `1`
- result: failed
- notes: 虚拟环境未安装 `ruff` 模块,未完成 lint 校验。
## Summary
- 回归测试通过,覆盖 `OTP_TIMEOUT_SECONDARY` 连续 3 次失败进入 `3600s` 深度冷却。
- Lint 校验因环境缺少 `ruff` 未执行。

View File

@@ -1077,6 +1077,8 @@ class RegistrationEngine:
if not password_ok:
return None, None
self._otp_sent_at = time.time()
code = self._get_verification_code()
if not code:
self._log("登录流程获取验证码失败", "warning")

View File

@@ -5,6 +5,7 @@
from typing import List, Optional, Dict, Any, Union, Iterable, Set
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy import and_, or_, desc, asc, func
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
@@ -778,6 +779,8 @@ def delete_tm_service(db: Session, service_id: int) -> bool:
db.delete(svc)
db.commit()
return True
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
"""更新 EmailService.config 中指定邮箱的 refresh_token"""
service = db.query(EmailService).filter(EmailService.id == service_id).first()
@@ -808,4 +811,5 @@ def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_r
return
service.config = config
flag_modified(service, "config")
db.commit()

View File

@@ -4,6 +4,7 @@
import asyncio
import logging
import threading
import uuid
import random
import re
@@ -32,9 +33,8 @@ router = APIRouter()
# 任务存储(简单的内存存储,生产环境应使用 Redis
running_tasks: dict = {}
# 批量任务存储
batch_tasks: Dict[str, dict] = {}
email_service_circuit_breakers: Dict[int, EmailProviderBackoffState] = {}
_email_service_backoff_lock = threading.Lock()
# ============== Proxy Helper Functions ==============
@@ -323,6 +323,75 @@ def _record_email_service_timeout_backoff(
return _store_email_service_backoff_state(service_id, backoff_state)
def _run_registration_engine_attempt(
task_uuid: str,
email_service,
actual_proxy_url: Optional[str],
log_callback,
db_service,
):
"""执行单次注册引擎尝试,并在同一临界区内维护邮箱服务退避状态。"""
provider_backoff_before_run = EmailProviderBackoffState()
with _email_service_backoff_lock:
if db_service is not None:
provider_backoff_before_run = _get_email_service_backoff_state(db_service.id)
if hasattr(email_service, "apply_provider_backoff_state"):
email_service.apply_provider_backoff_state(provider_backoff_before_run)
engine = RegistrationEngine(
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
task_uuid=task_uuid,
)
try:
result = engine.run()
finally:
close_engine = getattr(engine, "close", None)
if callable(close_engine):
close_engine()
email_prepare_phase = _get_phase_result(
getattr(engine, "phase_history", []),
"email_prepare",
)
if db_service is not None and email_prepare_phase is not None:
_store_email_service_backoff_state(
db_service.id,
getattr(email_prepare_phase, "provider_backoff", None),
)
if (
db_service is not None
and not result.success
and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
):
timeout_backoff = _record_email_service_timeout_backoff(
db_service.id,
email_service,
provider_backoff_before_run,
result.error_code,
result.error_message,
)
else:
timeout_backoff = None
return engine, result, email_prepare_phase, provider_backoff_before_run, timeout_backoff
def _get_batch_snapshot(batch_id: str) -> Optional[dict]:
return task_manager.get_batch_status(batch_id)
def _require_batch_snapshot(batch_id: str) -> dict:
batch = _get_batch_snapshot(batch_id)
if batch is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
return batch
def _build_email_service_candidates(
db,
service_type: EmailServiceType,
@@ -506,35 +575,20 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
candidate_config,
name=db_service.name if db_service is not None else None,
)
provider_backoff_before_run = EmailProviderBackoffState()
if db_service is not None:
provider_backoff_before_run = _get_email_service_backoff_state(db_service.id)
if db_service is not None and hasattr(email_service, "apply_provider_backoff_state"):
email_service.apply_provider_backoff_state(provider_backoff_before_run)
engine = RegistrationEngine(
(
engine,
result,
email_prepare_phase,
_,
timeout_backoff,
) = _run_registration_engine_attempt(
task_uuid=task_uuid,
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
task_uuid=task_uuid
actual_proxy_url=actual_proxy_url,
log_callback=log_callback,
db_service=db_service,
)
try:
result = engine.run()
finally:
close_engine = getattr(engine, "close", None)
if callable(close_engine):
close_engine()
email_prepare_phase = _get_phase_result(
getattr(engine, "phase_history", []),
"email_prepare",
)
if db_service is not None and email_prepare_phase is not None:
_store_email_service_backoff_state(
db_service.id,
getattr(email_prepare_phase, "provider_backoff", None),
)
if result.success:
break
@@ -551,28 +605,17 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
and email_prepare_phase.provider_backoff is not None
)
if not can_failover:
if (
db_service is not None
and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
):
timeout_backoff = _record_email_service_timeout_backoff(
db_service.id,
email_service,
provider_backoff_before_run,
result.error_code,
result.error_message,
if timeout_backoff is not None:
logger.warning(
f"邮箱服务 OTP 超时,已退避 {db_service.name} "
f"{timeout_backoff.delay_seconds} 秒,连续失败 "
f"{timeout_backoff.failures}"
)
log_callback(
f"[系统] 邮箱服务 OTP 超时,退避 "
f"{timeout_backoff.delay_seconds} 秒: {db_service.name} "
f"(连续失败 {timeout_backoff.failures} 次)"
)
if timeout_backoff is not None:
logger.warning(
f"邮箱服务 OTP 超时,已退避 {db_service.name} "
f"{timeout_backoff.delay_seconds} 秒,连续失败 "
f"{timeout_backoff.failures}"
)
log_callback(
f"[系统] 邮箱服务 OTP 超时,退避 "
f"{timeout_backoff.delay_seconds} 秒: {db_service.name} "
f"(连续失败 {timeout_backoff.failures} 次)"
)
break
backoff_state = email_prepare_phase.provider_backoff
@@ -799,30 +842,15 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
def _init_batch_state(batch_id: str, task_uuids: List[str]):
"""初始化批量任务内存状态"""
task_manager.init_batch(batch_id, len(task_uuids))
batch_tasks[batch_id] = {
"total": len(task_uuids),
"completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0,
"logs": [],
"finished": False
}
task_manager.init_batch(batch_id, len(task_uuids), task_uuids=task_uuids)
def _make_batch_helpers(batch_id: str):
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
def add_batch_log(msg: str):
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
return add_batch_log, update_batch_status
@@ -866,9 +894,10 @@ async def run_batch_parallel(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
batch_snapshot = _get_batch_snapshot(batch_id) or {}
new_completed = batch_snapshot.get("completed", 0) + 1
new_success = batch_snapshot.get("success", 0)
new_failed = batch_snapshot.get("failed", 0)
if t.status == "completed":
new_success += 1
add_batch_log(f"{prefix} [成功] 注册成功")
@@ -880,7 +909,11 @@ async def run_batch_parallel(
try:
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
batch_snapshot = _get_batch_snapshot(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, "
f"失败: {batch_snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed")
else:
update_batch_status(finished=True, status="cancelled")
@@ -888,8 +921,6 @@ async def run_batch_parallel(
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_pipeline(
@@ -932,9 +963,10 @@ async def run_batch_pipeline(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
batch_snapshot = _get_batch_snapshot(batch_id) or {}
new_completed = batch_snapshot.get("completed", 0) + 1
new_success = batch_snapshot.get("success", 0)
new_failed = batch_snapshot.get("failed", 0)
if t.status == "completed":
new_success += 1
add_batch_log(f"{pfx} [成功] 注册成功")
@@ -947,7 +979,7 @@ async def run_batch_pipeline(
try:
for i, task_uuid in enumerate(task_uuids):
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
if task_manager.is_batch_cancelled(batch_id):
with get_db() as db:
for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled")
@@ -971,14 +1003,16 @@ async def run_batch_pipeline(
await asyncio.gather(*running_tasks_list, return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
batch_snapshot = _get_batch_snapshot(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, "
f"失败: {batch_snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_registration(
@@ -1157,10 +1191,7 @@ async def start_batch_registration(
@router.get("/batch/{batch_id}")
async def get_batch_status(batch_id: str):
"""获取批量任务状态"""
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
batch = _require_batch_snapshot(batch_id)
return {
"batch_id": batch_id,
"total": batch["total"],
@@ -1177,14 +1208,10 @@ async def get_batch_status(batch_id: str):
@router.post("/batch/{batch_id}/cancel")
async def cancel_batch(batch_id: str):
"""取消批量任务"""
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
batch = _require_batch_snapshot(batch_id)
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}
@@ -1666,18 +1693,12 @@ async def start_outlook_batch_registration(
batch_id = str(uuid.uuid4())
# 初始化批量任务状态
batch_tasks[batch_id] = {
"total": len(actual_service_ids),
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"cancelled": False,
"service_ids": actual_service_ids,
"current_index": 0,
"logs": [],
"finished": False
}
task_manager.init_batch(
batch_id,
len(actual_service_ids),
skipped=skipped_count,
service_ids=actual_service_ids,
)
# 在后台运行批量注册
background_tasks.add_task(
@@ -1710,10 +1731,7 @@ async def start_outlook_batch_registration(
@router.get("/outlook-batch/{batch_id}")
async def get_outlook_batch_status(batch_id: str):
"""获取 Outlook 批量任务状态"""
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
batch = _require_batch_snapshot(batch_id)
return {
"batch_id": batch_id,
"total": batch["total"],
@@ -1724,7 +1742,8 @@ async def get_outlook_batch_status(batch_id: str):
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"logs": batch.get("logs", []),
"service_ids": batch.get("service_ids", []),
"logs": task_manager.get_batch_logs(batch_id),
"progress": f"{batch['completed']}/{batch['total']}"
}
@@ -1732,15 +1751,10 @@ async def get_outlook_batch_status(batch_id: str):
@router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务"""
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
batch = _require_batch_snapshot(batch_id)
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
# 同时更新两个系统的取消状态
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}

View File

@@ -240,18 +240,25 @@ class TaskManager:
# ============== 批量任务管理 ==============
def init_batch(self, batch_id: str, total: int):
def init_batch(self, batch_id: str, total: int, **kwargs):
"""初始化批量任务"""
_batch_status[batch_id] = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False
}
with _get_batch_lock(batch_id):
previous = _batch_status.get(batch_id, {})
status = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": previous.get("skipped", 0),
"cancelled": previous.get("cancelled", False),
"current_index": 0,
"finished": False,
}
status.update(previous)
status.update(kwargs)
status["total"] = total
_batch_status[batch_id] = status
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
def add_batch_log(self, batch_id: str, log_message: str):
@@ -295,11 +302,11 @@ class TaskManager:
def update_batch_status(self, batch_id: str, **kwargs):
"""更新批量任务状态"""
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
with _get_batch_lock(batch_id):
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
# 异步广播状态更新
if self._loop and self._loop.is_running():
@@ -331,7 +338,9 @@ class TaskManager:
def get_batch_status(self, batch_id: str) -> Optional[dict]:
"""获取批量任务状态"""
return _batch_status.get(batch_id)
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
return status.copy() if status is not None else None
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
@@ -345,10 +354,11 @@ class TaskManager:
def cancel_batch(self, batch_id: str):
"""取消批量任务"""
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
logger.info(f"批量任务 {batch_id} 已标记为取消")
with _get_batch_lock(batch_id):
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
logger.info(f"批量任务 {batch_id} 已标记为取消")
def register_batch_websocket(self, batch_id: str, websocket) -> List[str]:
"""注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照"""

View File

@@ -70,3 +70,37 @@ def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_pa
assert updated.refresh_token == ""
assert updated.token_sync_status == "pending"
assert updated.token_sync_updated_at is not None
def test_update_outlook_refresh_token_persists_nested_config_changes(tmp_path):
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
manager.create_tables()
manager.migrate_tables()
with manager.session_scope() as session:
service = crud.create_email_service(
session,
service_type="outlook",
name="outlook-service",
config={
"accounts": [
{"email": "first@example.com", "refresh_token": "old-first"},
{"email": "second@example.com", "refresh_token": "old-second"},
]
},
)
service_id = service.id
crud.update_outlook_refresh_token(
session,
service_id=service_id,
email="second@example.com",
new_refresh_token="new-second",
)
with manager.session_scope() as session:
reloaded = crud.get_email_service_by_id(session, service_id)
assert reloaded is not None
assert reloaded.config["accounts"][0]["refresh_token"] == "old-first"
assert reloaded.config["accounts"][1]["refresh_token"] == "new-second"

View File

@@ -6,24 +6,23 @@ from src.web.routes import registration as registration_routes
from src.web.task_manager import task_manager
def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
def test_init_batch_state_persists_state_in_task_manager():
batch_id = "batch-sync-init"
task_uuids = ["task-1", "task-2", "task-3"]
registration_routes.batch_tasks.pop(batch_id, None)
registration_routes._init_batch_state(batch_id, task_uuids)
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
assert batch_snapshot["finished"] is False
assert manager_snapshot["task_uuids"] == task_uuids
assert manager_snapshot["total"] == 3
assert manager_snapshot["completed"] == 0
assert manager_snapshot["success"] == 0
assert manager_snapshot["failed"] == 0
assert manager_snapshot["finished"] is False
assert manager_snapshot["status"] == "running"
assert task_manager.get_batch_logs(batch_id) == []
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
@@ -61,7 +60,6 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
error_message = None if status == "completed" else f"{task_uuid}-error"
return SimpleNamespace(status=status, error_message=error_message)
registration_routes.batch_tasks.pop(batch_id, None)
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
@@ -78,13 +76,11 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
)
)
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
assert batch_snapshot["finished"] is True
assert manager_snapshot["completed"] == 3
assert manager_snapshot["success"] == 2
assert manager_snapshot["failed"] == 1
assert manager_snapshot["finished"] is True
assert manager_snapshot["status"] == "completed"

View File

@@ -1,5 +1,6 @@
from contextlib import contextmanager
from pathlib import Path
import threading
from types import SimpleNamespace
import src.services.base as base_module
@@ -412,3 +413,128 @@ def test_registration_task_success_clears_email_service_backoff(monkeypatch):
)
assert service_id not in registration_routes.email_service_circuit_breakers
def test_registration_task_backoff_failures_do_not_get_lost_under_concurrency(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_backoff_concurrency.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuids = ["task-backoff-1", "task-backoff-2"]
with manager.session_scope() as session:
for task_uuid in task_uuids:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
session.add(
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
start_lock = threading.Lock()
started = {"count": 0}
peer_started = threading.Event()
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
def run(self):
with start_lock:
started["count"] += 1
if started["count"] == len(task_uuids):
peer_started.set()
peer_started.wait(timeout=0.1)
current_state = self.email_service.provider_backoff_state
next_failures = current_state.failures + 1
delay_seconds = 30 if next_failures == 1 else 60
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=False,
error_message="创建邮箱失败",
error_code="EMAIL_PROVIDER_RATE_LIMITED",
retryable=True,
next_action="switch_provider",
provider_backoff=EmailProviderBackoffState(
failures=next_failures,
delay_seconds=delay_seconds,
opened_until=1000.0 + delay_seconds,
last_error="请求失败: 429",
),
)
]
return RegistrationResult(
success=False,
error_message="创建邮箱失败: 请求失败: 429",
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: BackoffAwareEmailService(
service_type=service_type,
config=config,
name=name,
),
)
registration_routes.email_service_circuit_breakers.clear()
with manager.session_scope() as session:
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
threads = [
threading.Thread(
target=registration_routes._run_sync_registration_task,
kwargs={
"task_uuid": task_uuid,
"email_service_type": EmailServiceType.DUCK_MAIL.value,
"proxy": None,
"email_service_config": None,
},
)
for task_uuid in task_uuids
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
state = registration_routes.email_service_circuit_breakers[service_id]
assert state.failures == 2
assert state.delay_seconds == 60

View File

@@ -69,3 +69,36 @@ def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
assert phase_result.success is False
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY
def test_advance_login_authorization_refreshes_otp_anchor_after_password_submit(monkeypatch):
email_service = FakeEmailService(code=None)
engine = _build_engine(monkeypatch, email_service)
engine.oauth_start = object()
engine._otp_sent_at = 10.0
monkeypatch.setattr(register_module.time, "time", lambda: 456.0)
monkeypatch.setattr(engine, "_init_session", lambda: True)
monkeypatch.setattr(engine, "_start_oauth", lambda: True)
monkeypatch.setattr(engine, "_get_device_id", lambda: True)
monkeypatch.setattr(engine, "_try_reenter_login_flow", lambda: True)
monkeypatch.setattr(
engine,
"_submit_login_password_step_and_get_continue_url",
lambda: (True, "https://continue.example.test"),
)
seen_anchors = []
def fake_get_verification_code():
seen_anchors.append(engine._otp_sent_at)
return None
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
workspace_id, callback_url = engine._advance_login_authorization()
assert workspace_id is None
assert callback_url is None
assert engine._otp_sent_at == 456.0
assert seen_anchors == [456.0]