mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-06-26 01:31:47 +08:00
Fix registration OTP anchor and batch task state
This commit is contained in:
29
docs/reviews/TASK-5-VALIDATION-2026-03-23.md
Normal file
29
docs/reviews/TASK-5-VALIDATION-2026-03-23.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# Task 5 Validation - 2026-03-23
|
||||
|
||||
## Scope
|
||||
|
||||
- Task 5
|
||||
- OTP timeout backoff handling
|
||||
- Registration controller backoff state persistence
|
||||
|
||||
## Commands
|
||||
|
||||
1. `./.venv/bin/python -m pytest tests/test_registration_email_service_failover.py tests/test_registration_otp_phase.py`
|
||||
- exit code: `0`
|
||||
- result: `4 passed`
|
||||
- notes: 存在项目既有的 SQLAlchemy / Pydantic / FastAPI deprecation warnings,本次任务未改动相关代码路径。
|
||||
|
||||
2. `./.venv/bin/ruff check src/services/base.py src/web/routes/registration.py tests/test_registration_email_service_failover.py`
|
||||
- exit code: `127`
|
||||
- result: failed
|
||||
- notes: `.venv/bin/ruff` 不存在。
|
||||
|
||||
3. `./.venv/bin/python -m ruff check src/services/base.py src/web/routes/registration.py tests/test_registration_email_service_failover.py`
|
||||
- exit code: `1`
|
||||
- result: failed
|
||||
- notes: 虚拟环境未安装 `ruff` 模块,未完成 lint 校验。
|
||||
|
||||
## Summary
|
||||
|
||||
- 回归测试通过,覆盖 `OTP_TIMEOUT_SECONDARY` 连续 3 次失败进入 `3600s` 深度冷却。
|
||||
- Lint 校验因环境缺少 `ruff` 未执行。
|
||||
@@ -1077,6 +1077,8 @@ class RegistrationEngine:
|
||||
if not password_ok:
|
||||
return None, None
|
||||
|
||||
self._otp_sent_at = time.time()
|
||||
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
self._log("登录流程获取验证码失败", "warning")
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from typing import List, Optional, Dict, Any, Union, Iterable, Set
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlalchemy import and_, or_, desc, asc, func
|
||||
|
||||
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
|
||||
@@ -778,6 +779,8 @@ def delete_tm_service(db: Session, service_id: int) -> bool:
|
||||
db.delete(svc)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
|
||||
"""更新 EmailService.config 中指定邮箱的 refresh_token"""
|
||||
service = db.query(EmailService).filter(EmailService.id == service_id).first()
|
||||
@@ -808,4 +811,5 @@ def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_r
|
||||
return
|
||||
|
||||
service.config = config
|
||||
flag_modified(service, "config")
|
||||
db.commit()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import random
|
||||
import re
|
||||
@@ -32,9 +33,8 @@ router = APIRouter()
|
||||
|
||||
# 任务存储(简单的内存存储,生产环境应使用 Redis)
|
||||
running_tasks: dict = {}
|
||||
# 批量任务存储
|
||||
batch_tasks: Dict[str, dict] = {}
|
||||
email_service_circuit_breakers: Dict[int, EmailProviderBackoffState] = {}
|
||||
_email_service_backoff_lock = threading.Lock()
|
||||
|
||||
|
||||
# ============== Proxy Helper Functions ==============
|
||||
@@ -323,6 +323,75 @@ def _record_email_service_timeout_backoff(
|
||||
return _store_email_service_backoff_state(service_id, backoff_state)
|
||||
|
||||
|
||||
def _run_registration_engine_attempt(
|
||||
task_uuid: str,
|
||||
email_service,
|
||||
actual_proxy_url: Optional[str],
|
||||
log_callback,
|
||||
db_service,
|
||||
):
|
||||
"""执行单次注册引擎尝试,并在同一临界区内维护邮箱服务退避状态。"""
|
||||
provider_backoff_before_run = EmailProviderBackoffState()
|
||||
|
||||
with _email_service_backoff_lock:
|
||||
if db_service is not None:
|
||||
provider_backoff_before_run = _get_email_service_backoff_state(db_service.id)
|
||||
if hasattr(email_service, "apply_provider_backoff_state"):
|
||||
email_service.apply_provider_backoff_state(provider_backoff_before_run)
|
||||
|
||||
engine = RegistrationEngine(
|
||||
email_service=email_service,
|
||||
proxy_url=actual_proxy_url,
|
||||
callback_logger=log_callback,
|
||||
task_uuid=task_uuid,
|
||||
)
|
||||
|
||||
try:
|
||||
result = engine.run()
|
||||
finally:
|
||||
close_engine = getattr(engine, "close", None)
|
||||
if callable(close_engine):
|
||||
close_engine()
|
||||
|
||||
email_prepare_phase = _get_phase_result(
|
||||
getattr(engine, "phase_history", []),
|
||||
"email_prepare",
|
||||
)
|
||||
if db_service is not None and email_prepare_phase is not None:
|
||||
_store_email_service_backoff_state(
|
||||
db_service.id,
|
||||
getattr(email_prepare_phase, "provider_backoff", None),
|
||||
)
|
||||
|
||||
if (
|
||||
db_service is not None
|
||||
and not result.success
|
||||
and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
):
|
||||
timeout_backoff = _record_email_service_timeout_backoff(
|
||||
db_service.id,
|
||||
email_service,
|
||||
provider_backoff_before_run,
|
||||
result.error_code,
|
||||
result.error_message,
|
||||
)
|
||||
else:
|
||||
timeout_backoff = None
|
||||
|
||||
return engine, result, email_prepare_phase, provider_backoff_before_run, timeout_backoff
|
||||
|
||||
|
||||
def _get_batch_snapshot(batch_id: str) -> Optional[dict]:
|
||||
return task_manager.get_batch_status(batch_id)
|
||||
|
||||
|
||||
def _require_batch_snapshot(batch_id: str) -> dict:
|
||||
batch = _get_batch_snapshot(batch_id)
|
||||
if batch is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
return batch
|
||||
|
||||
|
||||
def _build_email_service_candidates(
|
||||
db,
|
||||
service_type: EmailServiceType,
|
||||
@@ -506,35 +575,20 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
candidate_config,
|
||||
name=db_service.name if db_service is not None else None,
|
||||
)
|
||||
provider_backoff_before_run = EmailProviderBackoffState()
|
||||
if db_service is not None:
|
||||
provider_backoff_before_run = _get_email_service_backoff_state(db_service.id)
|
||||
if db_service is not None and hasattr(email_service, "apply_provider_backoff_state"):
|
||||
email_service.apply_provider_backoff_state(provider_backoff_before_run)
|
||||
engine = RegistrationEngine(
|
||||
(
|
||||
engine,
|
||||
result,
|
||||
email_prepare_phase,
|
||||
_,
|
||||
timeout_backoff,
|
||||
) = _run_registration_engine_attempt(
|
||||
task_uuid=task_uuid,
|
||||
email_service=email_service,
|
||||
proxy_url=actual_proxy_url,
|
||||
callback_logger=log_callback,
|
||||
task_uuid=task_uuid
|
||||
actual_proxy_url=actual_proxy_url,
|
||||
log_callback=log_callback,
|
||||
db_service=db_service,
|
||||
)
|
||||
|
||||
try:
|
||||
result = engine.run()
|
||||
finally:
|
||||
close_engine = getattr(engine, "close", None)
|
||||
if callable(close_engine):
|
||||
close_engine()
|
||||
|
||||
email_prepare_phase = _get_phase_result(
|
||||
getattr(engine, "phase_history", []),
|
||||
"email_prepare",
|
||||
)
|
||||
if db_service is not None and email_prepare_phase is not None:
|
||||
_store_email_service_backoff_state(
|
||||
db_service.id,
|
||||
getattr(email_prepare_phase, "provider_backoff", None),
|
||||
)
|
||||
|
||||
if result.success:
|
||||
break
|
||||
|
||||
@@ -551,28 +605,17 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
and email_prepare_phase.provider_backoff is not None
|
||||
)
|
||||
if not can_failover:
|
||||
if (
|
||||
db_service is not None
|
||||
and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
):
|
||||
timeout_backoff = _record_email_service_timeout_backoff(
|
||||
db_service.id,
|
||||
email_service,
|
||||
provider_backoff_before_run,
|
||||
result.error_code,
|
||||
result.error_message,
|
||||
if timeout_backoff is not None:
|
||||
logger.warning(
|
||||
f"邮箱服务 OTP 超时,已退避 {db_service.name} "
|
||||
f"{timeout_backoff.delay_seconds} 秒,连续失败 "
|
||||
f"{timeout_backoff.failures} 次"
|
||||
)
|
||||
log_callback(
|
||||
f"[系统] 邮箱服务 OTP 超时,退避 "
|
||||
f"{timeout_backoff.delay_seconds} 秒: {db_service.name} "
|
||||
f"(连续失败 {timeout_backoff.failures} 次)"
|
||||
)
|
||||
if timeout_backoff is not None:
|
||||
logger.warning(
|
||||
f"邮箱服务 OTP 超时,已退避 {db_service.name} "
|
||||
f"{timeout_backoff.delay_seconds} 秒,连续失败 "
|
||||
f"{timeout_backoff.failures} 次"
|
||||
)
|
||||
log_callback(
|
||||
f"[系统] 邮箱服务 OTP 超时,退避 "
|
||||
f"{timeout_backoff.delay_seconds} 秒: {db_service.name} "
|
||||
f"(连续失败 {timeout_backoff.failures} 次)"
|
||||
)
|
||||
break
|
||||
|
||||
backoff_state = email_prepare_phase.provider_backoff
|
||||
@@ -799,30 +842,15 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
|
||||
def _init_batch_state(batch_id: str, task_uuids: List[str]):
|
||||
"""初始化批量任务内存状态"""
|
||||
task_manager.init_batch(batch_id, len(task_uuids))
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(task_uuids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"cancelled": False,
|
||||
"task_uuids": task_uuids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
task_manager.init_batch(batch_id, len(task_uuids), task_uuids=task_uuids)
|
||||
|
||||
|
||||
def _make_batch_helpers(batch_id: str):
|
||||
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
|
||||
def add_batch_log(msg: str):
|
||||
batch_tasks[batch_id]["logs"].append(msg)
|
||||
task_manager.add_batch_log(batch_id, msg)
|
||||
|
||||
def update_batch_status(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if key in batch_tasks[batch_id]:
|
||||
batch_tasks[batch_id][key] = value
|
||||
task_manager.update_batch_status(batch_id, **kwargs)
|
||||
|
||||
return add_batch_log, update_batch_status
|
||||
@@ -866,9 +894,10 @@ async def run_batch_parallel(
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
batch_snapshot = _get_batch_snapshot(batch_id) or {}
|
||||
new_completed = batch_snapshot.get("completed", 0) + 1
|
||||
new_success = batch_snapshot.get("success", 0)
|
||||
new_failed = batch_snapshot.get("failed", 0)
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{prefix} [成功] 注册成功")
|
||||
@@ -880,7 +909,11 @@ async def run_batch_parallel(
|
||||
try:
|
||||
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
batch_snapshot = _get_batch_snapshot(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, "
|
||||
f"失败: {batch_snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
else:
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
@@ -888,8 +921,6 @@ async def run_batch_parallel(
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_pipeline(
|
||||
@@ -932,9 +963,10 @@ async def run_batch_pipeline(
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
batch_snapshot = _get_batch_snapshot(batch_id) or {}
|
||||
new_completed = batch_snapshot.get("completed", 0) + 1
|
||||
new_success = batch_snapshot.get("success", 0)
|
||||
new_failed = batch_snapshot.get("failed", 0)
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{pfx} [成功] 注册成功")
|
||||
@@ -947,7 +979,7 @@ async def run_batch_pipeline(
|
||||
|
||||
try:
|
||||
for i, task_uuid in enumerate(task_uuids):
|
||||
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
|
||||
if task_manager.is_batch_cancelled(batch_id):
|
||||
with get_db() as db:
|
||||
for remaining_uuid in task_uuids[i:]:
|
||||
crud.update_registration_task(db, remaining_uuid, status="cancelled")
|
||||
@@ -971,14 +1003,16 @@ async def run_batch_pipeline(
|
||||
await asyncio.gather(*running_tasks_list, return_exceptions=True)
|
||||
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
batch_snapshot = _get_batch_snapshot(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {batch_snapshot.get('success', 0)}, "
|
||||
f"失败: {batch_snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_registration(
|
||||
@@ -1157,10 +1191,7 @@ async def start_batch_registration(
|
||||
@router.get("/batch/{batch_id}")
|
||||
async def get_batch_status(batch_id: str):
|
||||
"""获取批量任务状态"""
|
||||
if batch_id not in batch_tasks:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": batch["total"],
|
||||
@@ -1177,14 +1208,10 @@ async def get_batch_status(batch_id: str):
|
||||
@router.post("/batch/{batch_id}/cancel")
|
||||
async def cancel_batch(batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id not in batch_tasks:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
if batch.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
batch["cancelled"] = True
|
||||
task_manager.cancel_batch(batch_id)
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
|
||||
@@ -1666,18 +1693,12 @@ async def start_outlook_batch_registration(
|
||||
batch_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化批量任务状态
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(actual_service_ids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"cancelled": False,
|
||||
"service_ids": actual_service_ids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
task_manager.init_batch(
|
||||
batch_id,
|
||||
len(actual_service_ids),
|
||||
skipped=skipped_count,
|
||||
service_ids=actual_service_ids,
|
||||
)
|
||||
|
||||
# 在后台运行批量注册
|
||||
background_tasks.add_task(
|
||||
@@ -1710,10 +1731,7 @@ async def start_outlook_batch_registration(
|
||||
@router.get("/outlook-batch/{batch_id}")
|
||||
async def get_outlook_batch_status(batch_id: str):
|
||||
"""获取 Outlook 批量任务状态"""
|
||||
if batch_id not in batch_tasks:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": batch["total"],
|
||||
@@ -1724,7 +1742,8 @@ async def get_outlook_batch_status(batch_id: str):
|
||||
"current_index": batch["current_index"],
|
||||
"cancelled": batch["cancelled"],
|
||||
"finished": batch.get("finished", False),
|
||||
"logs": batch.get("logs", []),
|
||||
"service_ids": batch.get("service_ids", []),
|
||||
"logs": task_manager.get_batch_logs(batch_id),
|
||||
"progress": f"{batch['completed']}/{batch['total']}"
|
||||
}
|
||||
|
||||
@@ -1732,15 +1751,10 @@ async def get_outlook_batch_status(batch_id: str):
|
||||
@router.post("/outlook-batch/{batch_id}/cancel")
|
||||
async def cancel_outlook_batch(batch_id: str):
|
||||
"""取消 Outlook 批量任务"""
|
||||
if batch_id not in batch_tasks:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
batch = _require_batch_snapshot(batch_id)
|
||||
if batch.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
# 同时更新两个系统的取消状态
|
||||
batch["cancelled"] = True
|
||||
task_manager.cancel_batch(batch_id)
|
||||
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
|
||||
@@ -240,18 +240,25 @@ class TaskManager:
|
||||
|
||||
# ============== 批量任务管理 ==============
|
||||
|
||||
def init_batch(self, batch_id: str, total: int):
|
||||
def init_batch(self, batch_id: str, total: int, **kwargs):
|
||||
"""初始化批量任务"""
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False
|
||||
}
|
||||
with _get_batch_lock(batch_id):
|
||||
previous = _batch_status.get(batch_id, {})
|
||||
status = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": previous.get("skipped", 0),
|
||||
"cancelled": previous.get("cancelled", False),
|
||||
"current_index": 0,
|
||||
"finished": False,
|
||||
}
|
||||
status.update(previous)
|
||||
status.update(kwargs)
|
||||
status["total"] = total
|
||||
_batch_status[batch_id] = status
|
||||
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
|
||||
|
||||
def add_batch_log(self, batch_id: str, log_message: str):
|
||||
@@ -295,11 +302,11 @@ class TaskManager:
|
||||
|
||||
def update_batch_status(self, batch_id: str, **kwargs):
|
||||
"""更新批量任务状态"""
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
|
||||
# 异步广播状态更新
|
||||
if self._loop and self._loop.is_running():
|
||||
@@ -331,7 +338,9 @@ class TaskManager:
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> Optional[dict]:
|
||||
"""获取批量任务状态"""
|
||||
return _batch_status.get(batch_id)
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id)
|
||||
return status.copy() if status is not None else None
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
@@ -345,10 +354,11 @@ class TaskManager:
|
||||
|
||||
def cancel_batch(self, batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
|
||||
def register_batch_websocket(self, batch_id: str, websocket) -> List[str]:
|
||||
"""注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照"""
|
||||
|
||||
@@ -70,3 +70,37 @@ def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_pa
|
||||
assert updated.refresh_token == ""
|
||||
assert updated.token_sync_status == "pending"
|
||||
assert updated.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_outlook_refresh_token_persists_nested_config_changes(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service = crud.create_email_service(
|
||||
session,
|
||||
service_type="outlook",
|
||||
name="outlook-service",
|
||||
config={
|
||||
"accounts": [
|
||||
{"email": "first@example.com", "refresh_token": "old-first"},
|
||||
{"email": "second@example.com", "refresh_token": "old-second"},
|
||||
]
|
||||
},
|
||||
)
|
||||
service_id = service.id
|
||||
|
||||
crud.update_outlook_refresh_token(
|
||||
session,
|
||||
service_id=service_id,
|
||||
email="second@example.com",
|
||||
new_refresh_token="new-second",
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
reloaded = crud.get_email_service_by_id(session, service_id)
|
||||
|
||||
assert reloaded is not None
|
||||
assert reloaded.config["accounts"][0]["refresh_token"] == "old-first"
|
||||
assert reloaded.config["accounts"][1]["refresh_token"] == "new-second"
|
||||
|
||||
@@ -6,24 +6,23 @@ from src.web.routes import registration as registration_routes
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
|
||||
def test_init_batch_state_persists_state_in_task_manager():
|
||||
batch_id = "batch-sync-init"
|
||||
task_uuids = ["task-1", "task-2", "task-3"]
|
||||
|
||||
registration_routes.batch_tasks.pop(batch_id, None)
|
||||
registration_routes._init_batch_state(batch_id, task_uuids)
|
||||
|
||||
batch_snapshot = registration_routes.batch_tasks[batch_id]
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
|
||||
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
|
||||
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
|
||||
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
|
||||
assert batch_snapshot["finished"] is False
|
||||
assert manager_snapshot["task_uuids"] == task_uuids
|
||||
assert manager_snapshot["total"] == 3
|
||||
assert manager_snapshot["completed"] == 0
|
||||
assert manager_snapshot["success"] == 0
|
||||
assert manager_snapshot["failed"] == 0
|
||||
assert manager_snapshot["finished"] is False
|
||||
assert manager_snapshot["status"] == "running"
|
||||
assert task_manager.get_batch_logs(batch_id) == []
|
||||
|
||||
|
||||
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
@@ -61,7 +60,6 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
error_message = None if status == "completed" else f"{task_uuid}-error"
|
||||
return SimpleNamespace(status=status, error_message=error_message)
|
||||
|
||||
registration_routes.batch_tasks.pop(batch_id, None)
|
||||
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
|
||||
@@ -78,13 +76,11 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
)
|
||||
)
|
||||
|
||||
batch_snapshot = registration_routes.batch_tasks[batch_id]
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
|
||||
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
|
||||
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
|
||||
assert batch_snapshot["finished"] is True
|
||||
assert manager_snapshot["completed"] == 3
|
||||
assert manager_snapshot["success"] == 2
|
||||
assert manager_snapshot["failed"] == 1
|
||||
assert manager_snapshot["finished"] is True
|
||||
assert manager_snapshot["status"] == "completed"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.services.base as base_module
|
||||
@@ -412,3 +413,128 @@ def test_registration_task_success_clears_email_service_backoff(monkeypatch):
|
||||
)
|
||||
|
||||
assert service_id not in registration_routes.email_service_circuit_breakers
|
||||
|
||||
|
||||
def test_registration_task_backoff_failures_do_not_get_lost_under_concurrency(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_backoff_concurrency.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuids = ["task-backoff-1", "task-backoff-2"]
|
||||
with manager.session_scope() as session:
|
||||
for task_uuid in task_uuids:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
start_lock = threading.Lock()
|
||||
started = {"count": 0}
|
||||
peer_started = threading.Event()
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
with start_lock:
|
||||
started["count"] += 1
|
||||
if started["count"] == len(task_uuids):
|
||||
peer_started.set()
|
||||
peer_started.wait(timeout=0.1)
|
||||
|
||||
current_state = self.email_service.provider_backoff_state
|
||||
next_failures = current_state.failures + 1
|
||||
delay_seconds = 30 if next_failures == 1 else 60
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=False,
|
||||
error_message="创建邮箱失败",
|
||||
error_code="EMAIL_PROVIDER_RATE_LIMITED",
|
||||
retryable=True,
|
||||
next_action="switch_provider",
|
||||
provider_backoff=EmailProviderBackoffState(
|
||||
failures=next_failures,
|
||||
delay_seconds=delay_seconds,
|
||||
opened_until=1000.0 + delay_seconds,
|
||||
last_error="请求失败: 429",
|
||||
),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="创建邮箱失败: 请求失败: 429",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=registration_routes._run_sync_registration_task,
|
||||
kwargs={
|
||||
"task_uuid": task_uuid,
|
||||
"email_service_type": EmailServiceType.DUCK_MAIL.value,
|
||||
"proxy": None,
|
||||
"email_service_config": None,
|
||||
},
|
||||
)
|
||||
for task_uuid in task_uuids
|
||||
]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert state.failures == 2
|
||||
assert state.delay_seconds == 60
|
||||
|
||||
@@ -69,3 +69,36 @@ def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
|
||||
assert phase_result.success is False
|
||||
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
|
||||
|
||||
def test_advance_login_authorization_refreshes_otp_anchor_after_password_submit(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.oauth_start = object()
|
||||
engine._otp_sent_at = 10.0
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 456.0)
|
||||
monkeypatch.setattr(engine, "_init_session", lambda: True)
|
||||
monkeypatch.setattr(engine, "_start_oauth", lambda: True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: True)
|
||||
monkeypatch.setattr(engine, "_try_reenter_login_flow", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_password_step_and_get_continue_url",
|
||||
lambda: (True, "https://continue.example.test"),
|
||||
)
|
||||
|
||||
seen_anchors = []
|
||||
|
||||
def fake_get_verification_code():
|
||||
seen_anchors.append(engine._otp_sent_at)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
|
||||
|
||||
workspace_id, callback_url = engine._advance_login_authorization()
|
||||
|
||||
assert workspace_id is None
|
||||
assert callback_url is None
|
||||
assert engine._otp_sent_at == 456.0
|
||||
assert seen_anchors == [456.0]
|
||||
|
||||
Reference in New Issue
Block a user