mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-19 23:30:42 +08:00
Merge branch 'master' into fix/worker-mail-otp-extraction
This commit is contained in:
278
tests/e2e/runtime_functionality_check.py
Normal file
278
tests/e2e/runtime_functionality_check.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
|
||||
|
||||
STALE_ERROR = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
|
||||
|
||||
|
||||
def _write_json(path: Path, payload: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def _load_json(path: Path) -> Dict[str, Any]:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _connect_db(db_path: Path) -> sqlite3.Connection:
|
||||
return sqlite3.connect(db_path, timeout=5)
|
||||
|
||||
|
||||
def _fetchone_dict(conn: sqlite3.Connection, sql: str, params: tuple[Any, ...]) -> Dict[str, Any]:
|
||||
conn.row_factory = sqlite3.Row
|
||||
row = conn.execute(sql, params).fetchone()
|
||||
return dict(row) if row else {}
|
||||
|
||||
|
||||
def _assert(condition: bool, message: str) -> None:
|
||||
if not condition:
|
||||
raise AssertionError(message)
|
||||
|
||||
|
||||
def _health_check(client: httpx.Client, report: Dict[str, Any]) -> None:
|
||||
response = client.get("/api/registration/tasks", params={"page": 1, "page_size": 1})
|
||||
report["health"] = {"status_code": response.status_code, "body": response.json()}
|
||||
_assert(response.status_code == 200, "健康检查失败")
|
||||
|
||||
|
||||
async def _collect_task_websocket(ws_url: str, task_uuid: str) -> Dict[str, Any]:
|
||||
endpoint = f"{ws_url}/api/ws/task/{task_uuid}"
|
||||
messages: List[Dict[str, Any]] = []
|
||||
started_at = time.time()
|
||||
|
||||
async with websockets.connect(endpoint, open_timeout=10, close_timeout=5) as websocket:
|
||||
while time.time() - started_at < 30:
|
||||
raw_message = await asyncio.wait_for(websocket.recv(), timeout=10)
|
||||
payload = json.loads(raw_message)
|
||||
messages.append(payload)
|
||||
if payload.get("type") == "status" and payload.get("status") in {"completed", "failed"}:
|
||||
break
|
||||
|
||||
logs = [message for message in messages if message.get("type") == "log"]
|
||||
statuses = [message for message in messages if message.get("type") == "status"]
|
||||
return {
|
||||
"messages": messages,
|
||||
"log_count": len(logs),
|
||||
"status_count": len(statuses),
|
||||
"live_log_count": sum(1 for message in logs if "timestamp" in message),
|
||||
"final_status": statuses[-1]["status"] if statuses else None,
|
||||
}
|
||||
|
||||
|
||||
def _poll_task_completion(client: httpx.Client, task_uuid: str) -> Dict[str, Any]:
|
||||
deadline = time.time() + 20
|
||||
while time.time() < deadline:
|
||||
response = client.get(f"/api/registration/tasks/{task_uuid}")
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if payload["status"] in {"completed", "failed"}:
|
||||
return payload
|
||||
time.sleep(0.2)
|
||||
raise TimeoutError(f"任务未在预期时间内结束: {task_uuid}")
|
||||
|
||||
|
||||
def _validate_live_database(
|
||||
db_path: Path,
|
||||
task_uuid: str,
|
||||
batch_id: str,
|
||||
checks: Dict[str, Any],
|
||||
report: Dict[str, Any],
|
||||
) -> None:
|
||||
with _connect_db(db_path) as conn:
|
||||
seeded = _fetchone_dict(
|
||||
conn,
|
||||
"SELECT email, access_token, refresh_token, token_sync_status FROM accounts WHERE email = ?",
|
||||
(checks["seeded_account_email"],),
|
||||
)
|
||||
tokenless = _fetchone_dict(
|
||||
conn,
|
||||
"SELECT email, access_token, refresh_token, token_sync_status FROM accounts WHERE email = ?",
|
||||
(checks["tokenless_account_email"],),
|
||||
)
|
||||
partial = _fetchone_dict(
|
||||
conn,
|
||||
"SELECT email, access_token, refresh_token, token_sync_status FROM accounts WHERE email = ?",
|
||||
(checks["partial_account_email"],),
|
||||
)
|
||||
task_row = _fetchone_dict(
|
||||
conn,
|
||||
"SELECT task_uuid, status, logs, result FROM registration_tasks WHERE task_uuid = ?",
|
||||
(task_uuid,),
|
||||
)
|
||||
outlook_row = _fetchone_dict(
|
||||
conn,
|
||||
"SELECT config FROM email_services WHERE id = ?",
|
||||
(checks["outlook_service_id"],),
|
||||
)
|
||||
|
||||
_assert(seeded.get("token_sync_status") == "pending", "seeded 账号 token_sync_status 异常")
|
||||
_assert(tokenless.get("access_token") == "mock-access-token-updated", "tokenless 账号 access_token 未写入")
|
||||
_assert(tokenless.get("token_sync_status") == "pending", "tokenless 账号 token_sync_status 异常")
|
||||
_assert(partial.get("access_token") == "mock-access-token-partial", "partial 账号 access_token 丢失")
|
||||
_assert(partial.get("refresh_token") == "", "partial 账号 refresh_token 未清空")
|
||||
_assert(partial.get("token_sync_status") == "pending", "partial 账号 token_sync_status 异常")
|
||||
_assert(task_row.get("status") == "completed", "模拟任务数据库状态不是 completed")
|
||||
_assert(task_row.get("logs"), "模拟任务日志未落库")
|
||||
|
||||
task_result = json.loads(task_row["result"]) if task_row.get("result") else {}
|
||||
outlook_config = json.loads(outlook_row["config"]) if outlook_row.get("config") else {}
|
||||
second_account = next(
|
||||
account for account in outlook_config.get("accounts", [])
|
||||
if account.get("email") == checks["outlook_account_email"]
|
||||
)
|
||||
|
||||
batch_snapshot = task_result["hardening_checks"]["batch_counter"]["snapshot"]
|
||||
backoff_states = task_result["hardening_checks"]["otp_timeout_backoff"]["states"]
|
||||
|
||||
_assert(second_account["refresh_token"] == "new-second", "Outlook refresh_token 未更新")
|
||||
_assert(batch_snapshot["completed"] == 3, "批量 completed 计数异常")
|
||||
_assert(batch_snapshot["success"] == 2, "批量 success 计数异常")
|
||||
_assert(batch_snapshot["failed"] == 1, "批量 failed 计数异常")
|
||||
_assert(batch_snapshot["status"] == "completed", "批量状态异常")
|
||||
_assert(batch_snapshot["finished"] is True, "批量 finished 标记异常")
|
||||
_assert(backoff_states[-1]["delay_seconds"] == 3600, "OTP 深度冷却未生效")
|
||||
_assert(backoff_states[-1]["failures"] == 3, "OTP 连续失败次数异常")
|
||||
|
||||
report["database"] = {
|
||||
"task_uuid": task_uuid,
|
||||
"batch_id": batch_id,
|
||||
"seeded_account": seeded,
|
||||
"tokenless_account": tokenless,
|
||||
"partial_account": partial,
|
||||
"task_result": task_result,
|
||||
"outlook_second_account": second_account,
|
||||
}
|
||||
|
||||
|
||||
def run_live_mode(base_url: str, ws_url: str, db_path: Path, report_path: Path) -> None:
|
||||
report: Dict[str, Any] = {"mode": "live", "base_url": base_url, "db_path": str(db_path)}
|
||||
with httpx.Client(base_url=base_url, timeout=httpx.Timeout(10, read=30)) as client:
|
||||
_health_check(client, report)
|
||||
|
||||
create_response = client.post(
|
||||
"/api/registration/create",
|
||||
json={
|
||||
"email_service_type": "tempmail",
|
||||
"start_delay_ms": 600,
|
||||
"log_delay_ms": 150,
|
||||
},
|
||||
)
|
||||
create_response.raise_for_status()
|
||||
created = create_response.json()
|
||||
task_uuid = created["task"]["task_uuid"]
|
||||
batch_id = created["batch_id"]
|
||||
checks = created["checks"]
|
||||
report["create"] = created
|
||||
|
||||
ws_report = asyncio.run(_collect_task_websocket(ws_url, task_uuid))
|
||||
report["websocket"] = ws_report
|
||||
_assert(ws_report["final_status"] == "completed", "WebSocket 未收到 completed 状态")
|
||||
_assert(ws_report["log_count"] >= 4, "WebSocket 日志数量不足")
|
||||
_assert(ws_report["live_log_count"] >= 1, "未捕获到实时日志广播")
|
||||
|
||||
task_payload = _poll_task_completion(client, task_uuid)
|
||||
report["task"] = task_payload
|
||||
runtime_checks = {
|
||||
**checks,
|
||||
"outlook_service_id": task_payload["result"]["hardening_checks"]["outlook_refresh"]["service_id"],
|
||||
"backoff_service_id": task_payload["result"]["hardening_checks"]["otp_timeout_backoff"]["service_id"],
|
||||
}
|
||||
|
||||
batch_response = client.get(f"/api/registration/batch/{batch_id}")
|
||||
batch_response.raise_for_status()
|
||||
report["batch_api"] = batch_response.json()
|
||||
_assert(report["batch_api"]["completed"] == 3, "批量状态 API completed 异常")
|
||||
_assert(report["batch_api"]["success"] == 2, "批量状态 API success 异常")
|
||||
_assert(report["batch_api"]["failed"] == 1, "批量状态 API failed 异常")
|
||||
_assert(report["batch_api"]["finished"] is True, "批量状态 API finished 异常")
|
||||
|
||||
_validate_live_database(db_path, task_uuid, batch_id, runtime_checks, report)
|
||||
_write_json(report_path, report)
|
||||
print(json.dumps(report, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def run_prepare_recovery_mode(db_path: Path, state_path: Path) -> None:
|
||||
stale_task_uuid = f"stale-{uuid.uuid4()}"
|
||||
now = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
with _connect_db(db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO registration_tasks (task_uuid, status, logs, created_at, started_at)
|
||||
VALUES (?, 'running', '[00:00:00] stale task', ?, ?)
|
||||
""",
|
||||
(stale_task_uuid, now, now),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
payload = {
|
||||
"stale_task_uuid": stale_task_uuid,
|
||||
"db_path": str(db_path),
|
||||
"prepared_at": now,
|
||||
}
|
||||
_write_json(state_path, payload)
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def run_verify_recovery_mode(base_url: str, db_path: Path, state_path: Path, report_path: Path) -> None:
|
||||
state = _load_json(state_path)
|
||||
report: Dict[str, Any] = {
|
||||
"mode": "verify-recovery",
|
||||
"base_url": base_url,
|
||||
"db_path": str(db_path),
|
||||
"state": state,
|
||||
}
|
||||
|
||||
with httpx.Client(base_url=base_url, timeout=httpx.Timeout(10, read=30)) as client:
|
||||
_health_check(client, report)
|
||||
|
||||
with _connect_db(db_path) as conn:
|
||||
stale_task = _fetchone_dict(
|
||||
conn,
|
||||
"SELECT task_uuid, status, error_message, logs, completed_at FROM registration_tasks WHERE task_uuid = ?",
|
||||
(state["stale_task_uuid"],),
|
||||
)
|
||||
|
||||
_assert(stale_task.get("status") == "failed", "僵尸任务未在重启后标记为 failed")
|
||||
_assert(stale_task.get("error_message") == STALE_ERROR, "僵尸任务 error_message 不匹配")
|
||||
_assert(STALE_ERROR in (stale_task.get("logs") or ""), "僵尸任务日志未追加系统收敛说明")
|
||||
_assert(bool(stale_task.get("completed_at")), "僵尸任务 completed_at 缺失")
|
||||
|
||||
report["recovery"] = stale_task
|
||||
_write_json(report_path, report)
|
||||
print(json.dumps(report, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="真实服务功能可用性验证脚本")
|
||||
parser.add_argument("--mode", choices=["live", "prepare-recovery", "verify-recovery"], required=True)
|
||||
parser.add_argument("--base-url", default="http://127.0.0.1:15555")
|
||||
parser.add_argument("--ws-url", default="ws://127.0.0.1:15555")
|
||||
parser.add_argument("--db-path", required=True)
|
||||
parser.add_argument("--report-path", default="tests_runtime/runtime_functionality_report.json")
|
||||
parser.add_argument("--state-path", default="tests_runtime/runtime_recovery_state.json")
|
||||
args = parser.parse_args()
|
||||
|
||||
db_path = Path(args.db_path).resolve()
|
||||
report_path = Path(args.report_path).resolve()
|
||||
state_path = Path(args.state_path).resolve()
|
||||
|
||||
if args.mode == "live":
|
||||
run_live_mode(args.base_url, args.ws_url, db_path, report_path)
|
||||
return
|
||||
if args.mode == "prepare-recovery":
|
||||
run_prepare_recovery_mode(db_path, state_path)
|
||||
return
|
||||
run_verify_recovery_mode(args.base_url, db_path, state_path, report_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
106
tests/test_account_token_sync_status.py
Normal file
106
tests/test_account_token_sync_status.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from src.database import crud
|
||||
from src.database.session import DatabaseSessionManager
|
||||
|
||||
|
||||
def test_create_account_marks_token_sync_pending_when_tokens_persist(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
account = crud.create_account(
|
||||
session,
|
||||
email="sync@example.com",
|
||||
email_service="tempmail",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
)
|
||||
|
||||
assert account.token_sync_status == "pending"
|
||||
assert account.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_account_marks_token_sync_pending_when_tokens_change(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
account = crud.create_account(
|
||||
session,
|
||||
email="nosync@example.com",
|
||||
email_service="tempmail",
|
||||
)
|
||||
|
||||
assert account.token_sync_status == "not_ready"
|
||||
|
||||
updated = crud.update_account(
|
||||
session,
|
||||
account.id,
|
||||
access_token="new-access-token",
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.token_sync_status == "pending"
|
||||
assert updated.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
account = crud.create_account(
|
||||
session,
|
||||
email="partial-sync@example.com",
|
||||
email_service="tempmail",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
)
|
||||
|
||||
updated = crud.update_account(
|
||||
session,
|
||||
account.id,
|
||||
refresh_token="",
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.access_token == "access-token"
|
||||
assert updated.refresh_token == ""
|
||||
assert updated.token_sync_status == "pending"
|
||||
assert updated.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_outlook_refresh_token_persists_nested_config_changes(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service = crud.create_email_service(
|
||||
session,
|
||||
service_type="outlook",
|
||||
name="outlook-service",
|
||||
config={
|
||||
"accounts": [
|
||||
{"email": "first@example.com", "refresh_token": "old-first"},
|
||||
{"email": "second@example.com", "refresh_token": "old-second"},
|
||||
]
|
||||
},
|
||||
)
|
||||
service_id = service.id
|
||||
|
||||
crud.update_outlook_refresh_token(
|
||||
session,
|
||||
service_id=service_id,
|
||||
email="second@example.com",
|
||||
new_refresh_token="new-second",
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
reloaded = crud.get_email_service_by_id(session, service_id)
|
||||
|
||||
assert reloaded is not None
|
||||
assert reloaded.config["accounts"][0]["refresh_token"] == "old-first"
|
||||
assert reloaded.config["accounts"][1]["refresh_token"] == "new-second"
|
||||
86
tests/test_batch_task_manager.py
Normal file
86
tests/test_batch_task_manager.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.web.routes import registration as registration_routes
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
def test_init_batch_state_persists_state_in_task_manager():
|
||||
batch_id = "batch-sync-init"
|
||||
task_uuids = ["task-1", "task-2", "task-3"]
|
||||
|
||||
registration_routes._init_batch_state(batch_id, task_uuids)
|
||||
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert manager_snapshot["task_uuids"] == task_uuids
|
||||
assert manager_snapshot["total"] == 3
|
||||
assert manager_snapshot["completed"] == 0
|
||||
assert manager_snapshot["success"] == 0
|
||||
assert manager_snapshot["failed"] == 0
|
||||
assert manager_snapshot["finished"] is False
|
||||
assert manager_snapshot["status"] == "running"
|
||||
assert task_manager.get_batch_logs(batch_id) == []
|
||||
|
||||
|
||||
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
batch_id = "batch-sync-parallel"
|
||||
task_uuids = ["task-ok-1", "task-fail-1", "task-ok-2"]
|
||||
task_statuses = {
|
||||
"task-ok-1": "completed",
|
||||
"task-fail-1": "failed",
|
||||
"task-ok-2": "completed",
|
||||
}
|
||||
|
||||
async def fake_run_registration_task(
|
||||
task_uuid,
|
||||
email_service_type,
|
||||
proxy,
|
||||
email_service_config,
|
||||
email_service_id,
|
||||
log_prefix="",
|
||||
batch_id="",
|
||||
auto_upload_cpa=False,
|
||||
cpa_service_ids=None,
|
||||
auto_upload_sub2api=False,
|
||||
sub2api_service_ids=None,
|
||||
auto_upload_tm=False,
|
||||
tm_service_ids=None,
|
||||
):
|
||||
assert task_uuid in task_statuses
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
yield object()
|
||||
|
||||
def fake_get_registration_task(db, task_uuid):
|
||||
status = task_statuses[task_uuid]
|
||||
error_message = None if status == "completed" else f"{task_uuid}-error"
|
||||
return SimpleNamespace(status=status, error_message=error_message)
|
||||
|
||||
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
|
||||
|
||||
asyncio.run(
|
||||
registration_routes.run_batch_parallel(
|
||||
batch_id=batch_id,
|
||||
task_uuids=task_uuids,
|
||||
email_service_type="tempmail",
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
email_service_id=None,
|
||||
concurrency=2,
|
||||
)
|
||||
)
|
||||
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert manager_snapshot["completed"] == 3
|
||||
assert manager_snapshot["success"] == 2
|
||||
assert manager_snapshot["failed"] == 1
|
||||
assert manager_snapshot["finished"] is True
|
||||
assert manager_snapshot["status"] == "completed"
|
||||
133
tests/test_batch_websocket_fallback.cjs
Normal file
133
tests/test_batch_websocket_fallback.cjs
Normal file
@@ -0,0 +1,133 @@
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert/strict');
|
||||
const fs = require('node:fs');
|
||||
const vm = require('node:vm');
|
||||
|
||||
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
|
||||
|
||||
function createElementStub() {
|
||||
return {
|
||||
style: {},
|
||||
dataset: {},
|
||||
value: '',
|
||||
checked: false,
|
||||
disabled: false,
|
||||
innerHTML: '',
|
||||
textContent: '',
|
||||
className: '',
|
||||
appendChild() {},
|
||||
addEventListener() {},
|
||||
removeEventListener() {},
|
||||
querySelector() {
|
||||
return createElementStub();
|
||||
},
|
||||
querySelectorAll() {
|
||||
return [];
|
||||
},
|
||||
closest() {
|
||||
return null;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function createSandbox() {
|
||||
const sandbox = {
|
||||
console,
|
||||
setTimeout,
|
||||
clearTimeout,
|
||||
setInterval: () => 1,
|
||||
clearInterval: () => {},
|
||||
Event: class Event {
|
||||
constructor(type) {
|
||||
this.type = type;
|
||||
}
|
||||
},
|
||||
document: {
|
||||
getElementById() {
|
||||
return createElementStub();
|
||||
},
|
||||
createElement() {
|
||||
return createElementStub();
|
||||
},
|
||||
addEventListener() {},
|
||||
querySelector() {
|
||||
return createElementStub();
|
||||
},
|
||||
querySelectorAll() {
|
||||
return [];
|
||||
},
|
||||
},
|
||||
sessionStorage: {
|
||||
getItem() {
|
||||
return null;
|
||||
},
|
||||
setItem() {},
|
||||
removeItem() {},
|
||||
},
|
||||
toast: {
|
||||
info() {},
|
||||
success() {},
|
||||
warning() {},
|
||||
error() {},
|
||||
},
|
||||
api: {
|
||||
get() {
|
||||
throw new Error('api.get should not be called in this test');
|
||||
},
|
||||
post() {
|
||||
throw new Error('api.post should not be called in this test');
|
||||
},
|
||||
},
|
||||
window: null,
|
||||
WebSocket: null,
|
||||
};
|
||||
|
||||
sandbox.window = sandbox;
|
||||
sandbox.window.location = { protocol: 'http:', host: '127.0.0.1:8003' };
|
||||
|
||||
vm.createContext(sandbox);
|
||||
vm.runInContext(fs.readFileSync(APP_JS_PATH, 'utf8'), sandbox, { filename: 'app.js' });
|
||||
|
||||
return sandbox;
|
||||
}
|
||||
|
||||
async function runFallback(mode) {
|
||||
const sandbox = createSandbox();
|
||||
|
||||
vm.runInContext(
|
||||
`
|
||||
var __calls = [];
|
||||
currentBatch = { batch_id: 'test-batch' };
|
||||
isOutlookBatchMode = ${mode === 'outlook' ? 'true' : 'false'};
|
||||
batchCompleted = false;
|
||||
batchFinalStatus = null;
|
||||
startOutlookBatchPolling = function(batchId) { __calls.push(['outlook', batchId]); };
|
||||
startBatchPolling = function(batchId) { __calls.push(['batch', batchId]); };
|
||||
WebSocket = function(url) {
|
||||
this.url = url;
|
||||
this.readyState = 0;
|
||||
setTimeout(() => {
|
||||
if (this.onerror) {
|
||||
this.onerror({ type: 'error' });
|
||||
}
|
||||
}, 0);
|
||||
};
|
||||
WebSocket.OPEN = 1;
|
||||
connectBatchWebSocket('test-batch');
|
||||
`,
|
||||
sandbox,
|
||||
);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 20));
|
||||
return JSON.parse(vm.runInContext('JSON.stringify(__calls)', sandbox));
|
||||
}
|
||||
|
||||
test('normal batch websocket fallback uses standard batch polling', async () => {
|
||||
const calls = await runFallback('batch');
|
||||
assert.deepEqual(calls, [['batch', 'test-batch']]);
|
||||
});
|
||||
|
||||
test('outlook batch websocket fallback uses outlook polling', async () => {
|
||||
const calls = await runFallback('outlook');
|
||||
assert.deepEqual(calls, [['outlook', 'test-batch']]);
|
||||
});
|
||||
103
tests/test_email_service_backoff.py
Normal file
103
tests/test_email_service_backoff.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from src.services.base import (
|
||||
BaseEmailService,
|
||||
EmailProviderBackoffState,
|
||||
EmailServiceType,
|
||||
OTPTimeoutEmailServiceError,
|
||||
RateLimitedEmailServiceError,
|
||||
apply_adaptive_backoff,
|
||||
calculate_adaptive_backoff_delay,
|
||||
)
|
||||
|
||||
|
||||
class DummyEmailService(BaseEmailService):
|
||||
def __init__(self):
|
||||
super().__init__(EmailServiceType.DUCK_MAIL, "dummy")
|
||||
|
||||
def create_email(self, config=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email,
|
||||
email_id=None,
|
||||
timeout=120,
|
||||
pattern=r"(?<!\d)(\d{6})(?!\d)",
|
||||
otp_sent_at=None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def list_emails(self, **kwargs):
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def test_calculate_adaptive_backoff_delay_uses_failure_count_progression():
|
||||
assert calculate_adaptive_backoff_delay(0) == 30
|
||||
assert calculate_adaptive_backoff_delay(1) == 30
|
||||
assert calculate_adaptive_backoff_delay(2) == 60
|
||||
assert calculate_adaptive_backoff_delay(3) == 120
|
||||
|
||||
|
||||
def test_apply_adaptive_backoff_tracks_timeout_failures_to_one_hour():
|
||||
state = EmailProviderBackoffState()
|
||||
|
||||
first = apply_adaptive_backoff(
|
||||
state,
|
||||
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
|
||||
now=1000.0,
|
||||
)
|
||||
second = apply_adaptive_backoff(
|
||||
first,
|
||||
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
|
||||
now=1031.0,
|
||||
)
|
||||
third = apply_adaptive_backoff(
|
||||
second,
|
||||
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
|
||||
now=1092.0,
|
||||
)
|
||||
|
||||
assert first.failures == 1
|
||||
assert first.delay_seconds == 30
|
||||
assert first.opened_until == 1030.0
|
||||
|
||||
assert second.failures == 2
|
||||
assert second.delay_seconds == 60
|
||||
assert second.opened_until == 1091.0
|
||||
|
||||
assert third.failures == 3
|
||||
assert third.delay_seconds == 3600
|
||||
assert third.opened_until == 4692.0
|
||||
|
||||
|
||||
def test_apply_adaptive_backoff_keeps_normal_rate_limit_on_exponential_curve():
|
||||
state = EmailProviderBackoffState(failures=2, delay_seconds=60, opened_until=1060.0)
|
||||
|
||||
next_state = apply_adaptive_backoff(
|
||||
state,
|
||||
RateLimitedEmailServiceError("请求失败: 429", retry_after=7),
|
||||
now=1100.0,
|
||||
)
|
||||
|
||||
assert next_state.failures == 3
|
||||
assert next_state.delay_seconds == 120
|
||||
assert next_state.opened_until == 1220.0
|
||||
assert next_state.retry_after == 7
|
||||
|
||||
|
||||
def test_update_status_resets_provider_backoff_after_success():
|
||||
service = DummyEmailService()
|
||||
|
||||
service.update_status(False, RateLimitedEmailServiceError("请求失败: 429"))
|
||||
|
||||
assert service.provider_backoff_state.failures == 1
|
||||
assert service.provider_backoff_state.delay_seconds == 30
|
||||
|
||||
service.update_status(True)
|
||||
|
||||
assert service.provider_backoff_state == EmailProviderBackoffState()
|
||||
82
tests/test_register_protocol_baseline.py
Normal file
82
tests/test_register_protocol_baseline.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.core.register as register_module
|
||||
from src.config.constants import OPENAI_PAGE_TYPES
|
||||
from src.core.register import RegistrationEngine
|
||||
from src.services import EmailServiceType
|
||||
|
||||
|
||||
class DummySettings:
|
||||
openai_client_id = "client-id"
|
||||
openai_auth_url = "https://auth.example.test"
|
||||
openai_token_url = "https://token.example.test"
|
||||
openai_redirect_uri = "https://callback.example.test"
|
||||
openai_scope = "openid profile email"
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code=200, payload=None, text=""):
|
||||
self.status_code = status_code
|
||||
self._payload = payload or {}
|
||||
self.text = text
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.calls = []
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.calls.append({
|
||||
"url": url,
|
||||
**kwargs,
|
||||
})
|
||||
return self.response
|
||||
|
||||
|
||||
def _build_engine(monkeypatch):
|
||||
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
|
||||
email_service = SimpleNamespace(service_type=EmailServiceType.DUCK_MAIL)
|
||||
return RegistrationEngine(email_service=email_service)
|
||||
|
||||
|
||||
def test_submit_signup_form_uses_stable_protocol_body(monkeypatch):
|
||||
engine = _build_engine(monkeypatch)
|
||||
session = FakeSession(FakeResponse(
|
||||
status_code=200,
|
||||
payload={"page": {"type": OPENAI_PAGE_TYPES["PASSWORD_REGISTRATION"]}},
|
||||
))
|
||||
engine.session = session
|
||||
engine.email = "tester@example.com"
|
||||
|
||||
result = engine._submit_signup_form("did-1", None)
|
||||
|
||||
assert result.success is True
|
||||
assert result.is_existing_account is False
|
||||
assert (
|
||||
session.calls[0]["data"]
|
||||
== '{"username":{"value":"tester@example.com","kind":"email"},"screen_hint":"signup"}'
|
||||
)
|
||||
|
||||
|
||||
def test_register_password_uses_stable_protocol_body(monkeypatch):
|
||||
engine = _build_engine(monkeypatch)
|
||||
session = FakeSession(FakeResponse(status_code=200))
|
||||
engine.session = session
|
||||
engine.email = "tester@example.com"
|
||||
monkeypatch.setattr(engine, "_generate_password", lambda length=0: "Pass12345")
|
||||
|
||||
success, password = engine._register_password()
|
||||
|
||||
assert success is True
|
||||
assert password == "Pass12345"
|
||||
assert session.calls[0]["data"] == json.dumps(
|
||||
{
|
||||
"password": "Pass12345",
|
||||
"username": "tester@example.com",
|
||||
}
|
||||
)
|
||||
540
tests/test_registration_email_service_failover.py
Normal file
540
tests/test_registration_email_service_failover.py
Normal file
@@ -0,0 +1,540 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.services.base as base_module
|
||||
from src.core.register import (
|
||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
PhaseResult,
|
||||
RegistrationResult,
|
||||
)
|
||||
from src.database.models import Base, EmailService, RegistrationTask
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.services import EmailServiceType
|
||||
from src.services.base import BaseEmailService, EmailProviderBackoffState
|
||||
from src.web.routes import registration as registration_routes
|
||||
|
||||
|
||||
class DummyTaskManager:
|
||||
def __init__(self):
|
||||
self.status_updates = []
|
||||
self.logs = {}
|
||||
|
||||
def is_cancelled(self, task_uuid):
|
||||
return False
|
||||
|
||||
def update_status(self, task_uuid, status, email=None, error=None, **kwargs):
|
||||
self.status_updates.append((task_uuid, status, email, error, kwargs))
|
||||
|
||||
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
|
||||
def callback(message):
|
||||
self.logs.setdefault(task_uuid, []).append(message)
|
||||
return callback
|
||||
|
||||
|
||||
class BackoffAwareEmailService(BaseEmailService):
|
||||
def __init__(self, service_type, config=None, name=None):
|
||||
super().__init__(service_type=service_type, name=name)
|
||||
self.config = config or {}
|
||||
|
||||
def create_email(self, config=None):
|
||||
return {"email": "tester@example.com", "service_id": "svc-1"}
|
||||
|
||||
def get_verification_code(self, **kwargs):
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs):
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
return True
|
||||
|
||||
def check_health(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def test_registration_task_fails_over_after_rate_limit(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_failover.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-rate-limit-failover"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add_all([
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
),
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-secondary",
|
||||
config={
|
||||
"base_url": "https://mail-2.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=1,
|
||||
),
|
||||
])
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
attempts = []
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
attempts.append(self.email_service.name)
|
||||
if self.email_service.name == "duck-primary":
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=False,
|
||||
error_message="创建邮箱失败",
|
||||
error_code="EMAIL_PROVIDER_RATE_LIMITED",
|
||||
retryable=True,
|
||||
next_action="switch_provider",
|
||||
provider_backoff=EmailProviderBackoffState(
|
||||
failures=1,
|
||||
delay_seconds=30,
|
||||
opened_until=1030.0,
|
||||
retry_after=7,
|
||||
last_error="请求失败: 429",
|
||||
),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="创建邮箱失败: 请求失败: 429",
|
||||
logs=[],
|
||||
)
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=True,
|
||||
provider_backoff=EmailProviderBackoffState(),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
email="tester@example.com",
|
||||
password="Pass12345",
|
||||
account_id="acct-1",
|
||||
workspace_id="ws-1",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
id_token="id-token",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: SimpleNamespace(
|
||||
service_type=service_type,
|
||||
name=name or service_type.value,
|
||||
config=config,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
registration_routes._run_sync_registration_task(
|
||||
task_uuid=task_uuid,
|
||||
email_service_type=EmailServiceType.DUCK_MAIL.value,
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
services = session.query(EmailService).order_by(EmailService.priority.asc()).all()
|
||||
task_status = task.status
|
||||
task_email_service_id = task.email_service_id
|
||||
primary_service_id = services[0].id
|
||||
secondary_service_id = services[1].id
|
||||
|
||||
assert attempts == ["duck-primary", "duck-secondary"]
|
||||
assert task_status == "completed"
|
||||
assert task_email_service_id == secondary_service_id
|
||||
assert registration_routes.email_service_circuit_breakers[primary_service_id].failures == 1
|
||||
assert registration_routes.email_service_circuit_breakers[primary_service_id].delay_seconds == 30
|
||||
|
||||
|
||||
def test_registration_task_enters_deep_cooldown_after_three_otp_timeouts(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_otp_timeout_backoff.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuids = [
|
||||
"task-otp-timeout-1",
|
||||
"task-otp-timeout-2",
|
||||
"task-otp-timeout-3",
|
||||
]
|
||||
with manager.session_scope() as session:
|
||||
session.add_all([RegistrationTask(task_uuid=task_uuid, status="pending") for task_uuid in task_uuids])
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
current_time = {"value": 1000.0}
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=True,
|
||||
provider_backoff=EmailProviderBackoffState(),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="等待验证码超时",
|
||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
|
||||
monkeypatch.setattr(base_module.time, "time", lambda: current_time["value"])
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
expected_delays = [30, 60, 3600]
|
||||
for attempt_index, task_uuid in enumerate(task_uuids, start=1):
|
||||
registration_routes._run_sync_registration_task(
|
||||
task_uuid=task_uuid,
|
||||
email_service_type=EmailServiceType.DUCK_MAIL.value,
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
assert task.status == "failed"
|
||||
assert task.error_message == "等待验证码超时"
|
||||
|
||||
state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert state.failures == attempt_index
|
||||
assert state.delay_seconds == expected_delays[attempt_index - 1]
|
||||
assert state.opened_until == current_time["value"] + expected_delays[attempt_index - 1]
|
||||
|
||||
if attempt_index < len(task_uuids):
|
||||
current_time["value"] = state.opened_until + 1
|
||||
|
||||
final_state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert final_state.delay_seconds == 3600
|
||||
assert final_state.failures == 3
|
||||
|
||||
|
||||
def test_registration_task_success_clears_email_service_backoff(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_success_clears_backoff.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-success-clears-backoff"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=True,
|
||||
provider_backoff=EmailProviderBackoffState(),
|
||||
)
|
||||
]
|
||||
|
||||
def run(self):
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
email="tester@example.com",
|
||||
password="Pass12345",
|
||||
account_id="acct-1",
|
||||
workspace_id="ws-1",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
id_token="id-token",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
registration_routes.email_service_circuit_breakers[service_id] = EmailProviderBackoffState(
|
||||
failures=2,
|
||||
delay_seconds=60,
|
||||
opened_until=9999.0,
|
||||
last_error="等待验证码超时",
|
||||
)
|
||||
|
||||
registration_routes._run_sync_registration_task(
|
||||
task_uuid=task_uuid,
|
||||
email_service_type=EmailServiceType.DUCK_MAIL.value,
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
)
|
||||
|
||||
assert service_id not in registration_routes.email_service_circuit_breakers
|
||||
|
||||
|
||||
def test_registration_task_backoff_failures_do_not_get_lost_under_concurrency(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_backoff_concurrency.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuids = ["task-backoff-1", "task-backoff-2"]
|
||||
with manager.session_scope() as session:
|
||||
for task_uuid in task_uuids:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
start_lock = threading.Lock()
|
||||
started = {"count": 0}
|
||||
peer_started = threading.Event()
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
with start_lock:
|
||||
started["count"] += 1
|
||||
if started["count"] == len(task_uuids):
|
||||
peer_started.set()
|
||||
peer_started.wait(timeout=0.1)
|
||||
|
||||
current_state = self.email_service.provider_backoff_state
|
||||
next_failures = current_state.failures + 1
|
||||
delay_seconds = 30 if next_failures == 1 else 60
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=False,
|
||||
error_message="创建邮箱失败",
|
||||
error_code="EMAIL_PROVIDER_RATE_LIMITED",
|
||||
retryable=True,
|
||||
next_action="switch_provider",
|
||||
provider_backoff=EmailProviderBackoffState(
|
||||
failures=next_failures,
|
||||
delay_seconds=delay_seconds,
|
||||
opened_until=1000.0 + delay_seconds,
|
||||
last_error="请求失败: 429",
|
||||
),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="创建邮箱失败: 请求失败: 429",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=registration_routes._run_sync_registration_task,
|
||||
kwargs={
|
||||
"task_uuid": task_uuid,
|
||||
"email_service_type": EmailServiceType.DUCK_MAIL.value,
|
||||
"proxy": None,
|
||||
"email_service_config": None,
|
||||
},
|
||||
)
|
||||
for task_uuid in task_uuids
|
||||
]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert state.failures == 2
|
||||
assert state.delay_seconds == 60
|
||||
174
tests/test_registration_otp_phase.py
Normal file
174
tests/test_registration_otp_phase.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import src.core.register as register_module
|
||||
from src.core.register import (
|
||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
PhaseContext,
|
||||
RegistrationEngine,
|
||||
)
|
||||
from src.services import EmailServiceType
|
||||
|
||||
|
||||
class DummySettings:
|
||||
openai_client_id = "client-id"
|
||||
openai_auth_url = "https://auth.example.test"
|
||||
openai_token_url = "https://token.example.test"
|
||||
openai_redirect_uri = "https://callback.example.test"
|
||||
openai_scope = "openid profile email"
|
||||
|
||||
|
||||
class FakeEmailService:
|
||||
def __init__(self, code):
|
||||
self.service_type = EmailServiceType.TEMPMAIL
|
||||
self.code = code
|
||||
self.calls = []
|
||||
|
||||
def get_verification_code(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return self.code
|
||||
|
||||
|
||||
class FakeCookies:
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
def get(self, name):
|
||||
return self.values.get(name)
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, cookies=None):
|
||||
self.cookies = FakeCookies(cookies or {})
|
||||
self.get_calls = []
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
self.get_calls.append((args, kwargs))
|
||||
raise AssertionError("unexpected network call")
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, *, url="", text="", json_payload=None):
|
||||
self.url = url
|
||||
self.text = text
|
||||
self._json_payload = json_payload
|
||||
|
||||
def json(self):
|
||||
if isinstance(self._json_payload, Exception):
|
||||
raise self._json_payload
|
||||
return self._json_payload
|
||||
|
||||
|
||||
def _build_engine(monkeypatch, email_service):
|
||||
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
|
||||
return RegistrationEngine(email_service=email_service)
|
||||
|
||||
|
||||
def test_phase_otp_secondary_uses_remaining_budget_from_start_timestamp(monkeypatch):
|
||||
email_service = FakeEmailService(code="654321")
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.email = "tester@example.com"
|
||||
engine.email_info = {"service_id": "svc-1"}
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
|
||||
|
||||
code, phase_result = engine._phase_otp_secondary(
|
||||
PhaseContext(otp_sent_at=77.0),
|
||||
started_at=100.0,
|
||||
)
|
||||
|
||||
assert code == "654321"
|
||||
assert phase_result.success is True
|
||||
assert email_service.calls[0]["timeout"] == 100
|
||||
assert email_service.calls[0]["otp_sent_at"] == 77.0
|
||||
assert email_service.calls[0]["email"] == "tester@example.com"
|
||||
assert email_service.calls[0]["email_id"] == "svc-1"
|
||||
|
||||
|
||||
def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.email = "tester@example.com"
|
||||
engine.email_info = {"service_id": "svc-1"}
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
|
||||
|
||||
code, phase_result = engine._phase_otp_secondary(
|
||||
PhaseContext(otp_sent_at=80.0),
|
||||
started_at=100.0,
|
||||
)
|
||||
|
||||
assert code is None
|
||||
assert phase_result.success is False
|
||||
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
|
||||
|
||||
def test_advance_login_authorization_sets_otp_anchor_before_password_submit(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.oauth_start = object()
|
||||
engine._otp_sent_at = 10.0
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 456.0)
|
||||
monkeypatch.setattr(engine, "_init_session", lambda: True)
|
||||
monkeypatch.setattr(engine, "_start_oauth", lambda: True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: True)
|
||||
monkeypatch.setattr(engine, "_try_reenter_login_flow", lambda: True)
|
||||
|
||||
seen_anchors = []
|
||||
|
||||
def fake_submit_login_password_step():
|
||||
seen_anchors.append(engine._otp_sent_at)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(engine, "_submit_login_password_step", fake_submit_login_password_step)
|
||||
|
||||
def fake_get_verification_code():
|
||||
seen_anchors.append(engine._otp_sent_at)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
|
||||
|
||||
workspace_id, callback_url = engine._advance_login_authorization()
|
||||
|
||||
assert workspace_id is None
|
||||
assert callback_url is None
|
||||
assert engine._otp_sent_at == 456.0
|
||||
assert seen_anchors == [456.0, 456.0]
|
||||
|
||||
|
||||
def test_get_device_id_reuses_existing_cookie_without_extra_request(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.oauth_start = type("OAuthStart", (), {"auth_url": "https://auth.example.test/authorize"})()
|
||||
engine.session = FakeSession(cookies={"oai-did": "did-cached"})
|
||||
|
||||
assert engine._get_device_id() == "did-cached"
|
||||
assert engine.session.get_calls == []
|
||||
|
||||
|
||||
def test_extract_workspace_id_from_response_payload(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
response = FakeResponse(
|
||||
url="https://auth.example.test/consent?workspace_id=ws-url",
|
||||
json_payload={
|
||||
"page": {
|
||||
"workspace": {
|
||||
"id": "ws-json",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert engine._extract_workspace_id_from_response(response=response) == "ws-json"
|
||||
|
||||
|
||||
def test_extract_workspace_id_from_response_text_when_hidden_input_missing(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
response = FakeResponse(
|
||||
url="https://auth.example.test/consent",
|
||||
text='<script>window.__NEXT_DATA__={"activeWorkspaceId":"ws-script"}</script>',
|
||||
json_payload=ValueError("not json"),
|
||||
)
|
||||
|
||||
assert engine._extract_workspace_id_from_response(response=response) == "ws-script"
|
||||
109
tests/test_registration_proxy_failover.py
Normal file
109
tests/test_registration_proxy_failover.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.database import crud
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.web.routes import registration
|
||||
from src.core.register import RegistrationResult
|
||||
|
||||
|
||||
def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch, tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
primary_proxy = crud.create_proxy(
|
||||
session,
|
||||
name="primary",
|
||||
type="http",
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
)
|
||||
crud.update_proxy(session, primary_proxy.id, is_default=True)
|
||||
backup_proxy = crud.create_proxy(
|
||||
session,
|
||||
name="backup",
|
||||
type="http",
|
||||
host="127.0.0.1",
|
||||
port=8002,
|
||||
)
|
||||
email_service = crud.create_email_service(
|
||||
session,
|
||||
service_type="tempmail",
|
||||
name="tempmail-db",
|
||||
config={"base_url": "https://mail.example/api"},
|
||||
)
|
||||
crud.create_registration_task(session, task_uuid="task-proxy-failover")
|
||||
primary_proxy_id = primary_proxy.id
|
||||
backup_proxy_id = backup_proxy.id
|
||||
email_service_id = email_service.id
|
||||
|
||||
monkeypatch.setattr(registration, "get_db", manager.session_scope)
|
||||
monkeypatch.setattr(
|
||||
registration,
|
||||
"EmailServiceFactory",
|
||||
SimpleNamespace(
|
||||
create=lambda service_type, config, name=None: SimpleNamespace(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name or service_type.value,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
attempted_proxies = []
|
||||
saved_results = []
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.proxy_url = proxy_url
|
||||
|
||||
def run(self):
|
||||
attempted_proxies.append(self.proxy_url)
|
||||
if self.proxy_url.endswith(":8001"):
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
email="proxy@example.com",
|
||||
error_message="OpenAI 请求失败: curl: (35) TLS handshake failed",
|
||||
)
|
||||
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
email="proxy@example.com",
|
||||
access_token="access-token",
|
||||
workspace_id="ws-123",
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
saved_results.append(result.email)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine)
|
||||
registration.email_service_circuit_breakers.clear()
|
||||
|
||||
registration._run_sync_registration_task(
|
||||
task_uuid="task-proxy-failover",
|
||||
email_service_type="tempmail",
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
email_service_id=email_service_id,
|
||||
)
|
||||
|
||||
assert attempted_proxies == [
|
||||
"http://127.0.0.1:8001",
|
||||
"http://127.0.0.1:8002",
|
||||
]
|
||||
assert saved_results == ["proxy@example.com"]
|
||||
|
||||
with manager.session_scope() as session:
|
||||
disabled_primary = crud.get_proxy_by_id(session, primary_proxy_id)
|
||||
active_backup = crud.get_proxy_by_id(session, backup_proxy_id)
|
||||
task = crud.get_registration_task_by_uuid(session, "task-proxy-failover")
|
||||
|
||||
assert disabled_primary is not None
|
||||
assert disabled_primary.enabled is False
|
||||
assert active_backup is not None
|
||||
assert active_backup.enabled is True
|
||||
assert task is not None
|
||||
assert task.status == "completed"
|
||||
assert task.proxy == "http://127.0.0.1:8002"
|
||||
150
tests/test_single_task_websocket_status.cjs
Normal file
150
tests/test_single_task_websocket_status.cjs
Normal file
@@ -0,0 +1,150 @@
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert/strict');
|
||||
const fs = require('node:fs');
|
||||
const vm = require('node:vm');
|
||||
|
||||
const APP_JS_PATH = '/Users/zhoukailian/.config/superpowers/worktrees/codex-manager/repro-batch-monitor/static/js/app.js';
|
||||
|
||||
function createElementStub() {
|
||||
return {
|
||||
style: {},
|
||||
dataset: {},
|
||||
value: '',
|
||||
checked: false,
|
||||
disabled: false,
|
||||
innerHTML: '',
|
||||
textContent: '',
|
||||
className: '',
|
||||
appendChild() {},
|
||||
addEventListener() {},
|
||||
removeEventListener() {},
|
||||
querySelector() {
|
||||
return createElementStub();
|
||||
},
|
||||
querySelectorAll() {
|
||||
return [];
|
||||
},
|
||||
closest() {
|
||||
return null;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function createSandbox() {
|
||||
const elements = new Map();
|
||||
|
||||
const sandbox = {
|
||||
console,
|
||||
setTimeout,
|
||||
clearTimeout,
|
||||
setInterval: () => 1,
|
||||
clearInterval: () => {},
|
||||
Event: class Event {
|
||||
constructor(type) {
|
||||
this.type = type;
|
||||
}
|
||||
},
|
||||
document: {
|
||||
getElementById(id) {
|
||||
if (!elements.has(id)) {
|
||||
elements.set(id, createElementStub());
|
||||
}
|
||||
return elements.get(id);
|
||||
},
|
||||
createElement() {
|
||||
return createElementStub();
|
||||
},
|
||||
addEventListener() {},
|
||||
querySelector() {
|
||||
return createElementStub();
|
||||
},
|
||||
querySelectorAll() {
|
||||
return [];
|
||||
},
|
||||
},
|
||||
sessionStorage: {
|
||||
getItem() {
|
||||
return null;
|
||||
},
|
||||
setItem() {},
|
||||
removeItem() {},
|
||||
},
|
||||
toast: {
|
||||
info() {},
|
||||
success() {},
|
||||
warning() {},
|
||||
error() {},
|
||||
},
|
||||
api: {
|
||||
get() {
|
||||
throw new Error('api.get should not be called in this test');
|
||||
},
|
||||
post() {
|
||||
throw new Error('api.post should not be called in this test');
|
||||
},
|
||||
},
|
||||
loadRecentAccounts() {},
|
||||
getServiceTypeText(type) {
|
||||
return {
|
||||
tempmail: '临时邮箱',
|
||||
outlook: 'Outlook',
|
||||
}[type] || type;
|
||||
},
|
||||
window: null,
|
||||
WebSocket: null,
|
||||
};
|
||||
|
||||
sandbox.window = sandbox;
|
||||
sandbox.window.location = { protocol: 'http:', host: '127.0.0.1:8005' };
|
||||
|
||||
vm.createContext(sandbox);
|
||||
vm.runInContext(fs.readFileSync(APP_JS_PATH, 'utf8'), sandbox, { filename: 'app.js' });
|
||||
|
||||
return { sandbox, elements };
|
||||
}
|
||||
|
||||
test('single task websocket completion updates task info and resets buttons', () => {
|
||||
const { sandbox, elements } = createSandbox();
|
||||
|
||||
vm.runInContext(
|
||||
`
|
||||
var __lastWs = null;
|
||||
startLogPolling = function() {
|
||||
throw new Error('startLogPolling should not be called for completed status');
|
||||
};
|
||||
loadRecentAccounts = function() {};
|
||||
currentTask = { task_uuid: 'task-1' };
|
||||
taskCompleted = false;
|
||||
taskFinalStatus = null;
|
||||
elements.startBtn.disabled = true;
|
||||
elements.cancelBtn.disabled = false;
|
||||
elements.taskStatusRow.style.display = 'grid';
|
||||
WebSocket = function(url) {
|
||||
this.url = url;
|
||||
this.readyState = 0;
|
||||
__lastWs = this;
|
||||
};
|
||||
WebSocket.OPEN = 1;
|
||||
WebSocket.CLOSED = 3;
|
||||
WebSocket.prototype.close = function() {
|
||||
this.readyState = WebSocket.CLOSED;
|
||||
};
|
||||
connectWebSocket('task-1');
|
||||
__lastWs.onmessage({
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
status: 'completed',
|
||||
email: 'demo@example.com',
|
||||
email_service: 'tempmail',
|
||||
}),
|
||||
});
|
||||
`,
|
||||
sandbox,
|
||||
);
|
||||
|
||||
assert.equal(elements.get('start-btn').disabled, false);
|
||||
assert.equal(elements.get('cancel-btn').disabled, true);
|
||||
assert.equal(elements.get('task-status').textContent, '已完成');
|
||||
assert.equal(elements.get('task-email').textContent, 'demo@example.com');
|
||||
assert.equal(elements.get('task-service').textContent, '临时邮箱');
|
||||
});
|
||||
@@ -15,6 +15,7 @@ def test_static_asset_version_is_non_empty_string():
|
||||
def test_email_services_template_uses_versioned_static_assets():
|
||||
template = Path("templates/email_services.html").read_text(encoding="utf-8")
|
||||
|
||||
assert '/static/favicon.svg?v={{ static_version }}' in template
|
||||
assert '/static/css/style.css?v={{ static_version }}' in template
|
||||
assert '/static/js/utils.js?v={{ static_version }}' in template
|
||||
assert '/static/js/email_services.js?v={{ static_version }}' in template
|
||||
@@ -23,6 +24,7 @@ def test_email_services_template_uses_versioned_static_assets():
|
||||
def test_index_template_uses_versioned_static_assets():
|
||||
template = Path("templates/index.html").read_text(encoding="utf-8")
|
||||
|
||||
assert '/static/favicon.svg?v={{ static_version }}' in template
|
||||
assert '/static/css/style.css?v={{ static_version }}' in template
|
||||
assert '/static/js/utils.js?v={{ static_version }}' in template
|
||||
assert '/static/js/app.js?v={{ static_version }}' in template
|
||||
|
||||
72
tests/test_task_manager_status_broadcast.py
Normal file
72
tests/test_task_manager_status_broadcast.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
|
||||
from src.web.routes.registration import _create_task_status_callback
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
class FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def send_json(self, payload):
|
||||
self.messages.append(payload)
|
||||
|
||||
|
||||
def test_update_status_broadcasts_to_registered_websocket():
|
||||
async def run_test():
|
||||
task_uuid = "test-status-broadcast"
|
||||
websocket = FakeWebSocket()
|
||||
|
||||
task_manager.set_loop(asyncio.get_running_loop())
|
||||
task_manager.register_websocket(task_uuid, websocket)
|
||||
|
||||
try:
|
||||
task_manager.update_status(
|
||||
task_uuid,
|
||||
"completed",
|
||||
email="demo@example.com",
|
||||
email_service="tempmail",
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert websocket.messages, "expected a status message to be broadcast"
|
||||
assert websocket.messages[-1]["type"] == "status"
|
||||
assert websocket.messages[-1]["status"] == "completed"
|
||||
assert websocket.messages[-1]["email"] == "demo@example.com"
|
||||
assert websocket.messages[-1]["email_service"] == "tempmail"
|
||||
finally:
|
||||
task_manager.unregister_websocket(task_uuid, websocket)
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
def test_task_status_callback_broadcasts_phase_fields():
|
||||
async def run_test():
|
||||
task_uuid = "test-status-phase"
|
||||
websocket = FakeWebSocket()
|
||||
|
||||
task_manager.set_loop(asyncio.get_running_loop())
|
||||
task_manager.register_websocket(task_uuid, websocket)
|
||||
|
||||
try:
|
||||
callback = _create_task_status_callback(task_uuid, "tempmail")
|
||||
callback({
|
||||
"phase": "redirect_chain",
|
||||
"phase_detail": "跟随重定向 1/6",
|
||||
"step_index": 14,
|
||||
})
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert websocket.messages, "expected a status message to be broadcast"
|
||||
assert websocket.messages[-1]["type"] == "status"
|
||||
assert websocket.messages[-1]["status"] == "running"
|
||||
assert websocket.messages[-1]["email_service"] == "tempmail"
|
||||
assert websocket.messages[-1]["phase"] == "redirect_chain"
|
||||
assert websocket.messages[-1]["phase_detail"] == "跟随重定向 1/6"
|
||||
assert websocket.messages[-1]["step_index"] == 14
|
||||
finally:
|
||||
task_manager.unregister_websocket(task_uuid, websocket)
|
||||
|
||||
asyncio.run(run_test())
|
||||
143
tests/test_task_recovery.py
Normal file
143
tests/test_task_recovery.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from contextlib import contextmanager
|
||||
import asyncio
|
||||
|
||||
from fastapi import WebSocketDisconnect
|
||||
|
||||
from src.database import crud
|
||||
from src.database.models import Base, RegistrationTask
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.web.routes import websocket as websocket_routes
|
||||
from src.web.task_manager import TaskManager
|
||||
|
||||
|
||||
def test_fail_incomplete_registration_tasks_marks_pending_and_running_failed(tmp_path):
|
||||
db_path = tmp_path / "recovery.db"
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
session.add_all([
|
||||
RegistrationTask(task_uuid="task-pending", status="pending"),
|
||||
RegistrationTask(task_uuid="task-running", status="running", logs="[01:00:00] still running"),
|
||||
RegistrationTask(task_uuid="task-done", status="completed"),
|
||||
])
|
||||
|
||||
with manager.session_scope() as session:
|
||||
cleaned = crud.fail_incomplete_registration_tasks(
|
||||
session,
|
||||
"服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
|
||||
)
|
||||
|
||||
assert cleaned == ["task-pending", "task-running"]
|
||||
|
||||
with manager.session_scope() as session:
|
||||
pending_task = crud.get_registration_task_by_uuid(session, "task-pending")
|
||||
running_task = crud.get_registration_task_by_uuid(session, "task-running")
|
||||
done_task = crud.get_registration_task_by_uuid(session, "task-done")
|
||||
|
||||
assert pending_task.status == "failed"
|
||||
assert running_task.status == "failed"
|
||||
assert pending_task.error_message == "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
|
||||
assert running_task.completed_at is not None
|
||||
assert "[系统] 服务启动时检测到未完成的历史任务,已标记失败,请重新发起。" in running_task.logs
|
||||
assert done_task.status == "completed"
|
||||
|
||||
|
||||
def test_restore_task_snapshot_loads_status_and_logs_from_database(monkeypatch, tmp_path):
|
||||
db_path = tmp_path / "websocket.db"
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
session.add(
|
||||
RegistrationTask(
|
||||
task_uuid="task-websocket",
|
||||
status="failed",
|
||||
logs="[01:00:00] step 1\n[01:00:01] step 2",
|
||||
result={"email": "tester@example.com"},
|
||||
error_message="boom"
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(websocket_routes, "get_db", fake_get_db)
|
||||
|
||||
status, logs = websocket_routes._restore_task_snapshot("task-websocket")
|
||||
|
||||
assert status == {
|
||||
"status": "failed",
|
||||
"email": "tester@example.com",
|
||||
"error": "boom",
|
||||
}
|
||||
assert logs == ["[01:00:00] step 1", "[01:00:01] step 2"]
|
||||
|
||||
|
||||
def test_sync_task_state_prefers_longer_persisted_log_history():
|
||||
manager = TaskManager()
|
||||
task_uuid = "task-sync"
|
||||
|
||||
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["a", "b"])
|
||||
manager.sync_task_state(task_uuid, logs=["a"])
|
||||
|
||||
assert manager.get_status(task_uuid) == {"status": "running"}
|
||||
assert manager.get_logs(task_uuid) == ["a", "b"]
|
||||
|
||||
|
||||
def test_register_websocket_returns_snapshot_and_keeps_live_cursor():
|
||||
manager = TaskManager()
|
||||
task_uuid = "task-live"
|
||||
websocket = object()
|
||||
|
||||
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["log-1", "log-2"])
|
||||
|
||||
history_logs = manager.register_websocket(task_uuid, websocket)
|
||||
|
||||
assert history_logs == ["log-1", "log-2"]
|
||||
assert manager.get_unsent_logs(task_uuid, websocket) == []
|
||||
|
||||
manager.add_log(task_uuid, "log-3")
|
||||
|
||||
assert manager.get_unsent_logs(task_uuid, websocket) == ["log-3"]
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.accepted = False
|
||||
|
||||
async def accept(self):
|
||||
self.accepted = True
|
||||
|
||||
async def send_json(self, payload):
|
||||
self.messages.append(payload)
|
||||
|
||||
async def receive_json(self):
|
||||
raise WebSocketDisconnect()
|
||||
|
||||
|
||||
def test_batch_websocket_replays_history_logs_from_registration_snapshot(monkeypatch):
|
||||
manager = TaskManager()
|
||||
batch_id = "batch-history"
|
||||
websocket = _FakeWebSocket()
|
||||
|
||||
manager.init_batch(batch_id, total=2)
|
||||
manager.add_batch_log(batch_id, "[01:00:00] first")
|
||||
manager.add_batch_log(batch_id, "[01:00:01] second")
|
||||
|
||||
monkeypatch.setattr(websocket_routes, "task_manager", manager)
|
||||
|
||||
asyncio.run(websocket_routes.batch_websocket(websocket, batch_id))
|
||||
|
||||
assert websocket.accepted is True
|
||||
assert websocket.messages[0]["type"] == "status"
|
||||
assert [msg["message"] for msg in websocket.messages[1:]] == [
|
||||
"[01:00:00] first",
|
||||
"[01:00:01] second",
|
||||
]
|
||||
120
tests/test_tempmail_service.py
Normal file
120
tests/test_tempmail_service.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import src.services.tempmail as tempmail_module
|
||||
from src.services.tempmail import TempmailService
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code=200, payload=None):
|
||||
self.status_code = status_code
|
||||
self._payload = payload or {}
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeHTTPClient:
|
||||
def __init__(self, responses):
|
||||
self.responses = list(responses)
|
||||
self.calls = []
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
self.calls.append({"url": url, "kwargs": kwargs})
|
||||
if not self.responses:
|
||||
raise AssertionError(f"未准备响应: GET {url}")
|
||||
return self.responses.pop(0)
|
||||
|
||||
|
||||
def test_get_verification_code_ignores_messages_older_than_tolerance_window(monkeypatch):
|
||||
service = TempmailService({
|
||||
"base_url": "https://api.tempmail.test/v2",
|
||||
"timeout": 1,
|
||||
"max_retries": 1,
|
||||
})
|
||||
service._email_cache["tester@example.com"] = {
|
||||
"email": "tester@example.com",
|
||||
"token": "token-1",
|
||||
}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
status_code=200,
|
||||
payload={
|
||||
"emails": [
|
||||
{
|
||||
"id": "old-mail",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Old verification code",
|
||||
"body": "111111",
|
||||
"received_at": 1998,
|
||||
},
|
||||
{
|
||||
"id": "new-mail",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "New verification code",
|
||||
"body": "654321",
|
||||
"received_at": 2001,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
])
|
||||
monkeypatch.setattr(tempmail_module.time, "sleep", lambda _: None)
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=2000,
|
||||
)
|
||||
|
||||
assert code == "654321"
|
||||
assert service.http_client.calls == [
|
||||
{
|
||||
"url": "https://api.tempmail.test/v2/inbox",
|
||||
"kwargs": {
|
||||
"params": {"token": "token-1"},
|
||||
"headers": {"Accept": "application/json"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_get_verification_code_allows_two_second_anchor_tolerance(monkeypatch):
|
||||
service = TempmailService({
|
||||
"base_url": "https://api.tempmail.test/v2",
|
||||
"timeout": 1,
|
||||
"max_retries": 1,
|
||||
})
|
||||
service._email_cache["tester@example.com"] = {
|
||||
"email": "tester@example.com",
|
||||
"token": "token-1",
|
||||
}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
status_code=200,
|
||||
payload={
|
||||
"emails": [
|
||||
{
|
||||
"id": "too-old-mail",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Too old verification code",
|
||||
"body": "111111",
|
||||
"received_at": 1998,
|
||||
},
|
||||
{
|
||||
"id": "tolerated-mail",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Tolerated verification code",
|
||||
"body": "654321",
|
||||
"received_at": 1999,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
])
|
||||
monkeypatch.setattr(tempmail_module.time, "sleep", lambda _: None)
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=2000,
|
||||
)
|
||||
|
||||
assert code == "654321"
|
||||
142
tests/test_tempmail_timestamp_filter.py
Normal file
142
tests/test_tempmail_timestamp_filter.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.services.tempmail import TempmailService
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, payload, status_code=200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeHTTPClient:
|
||||
def __init__(self, responses):
|
||||
self.responses = list(responses)
|
||||
self.calls = []
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
self.calls.append({"url": url, "kwargs": kwargs})
|
||||
if not self.responses:
|
||||
raise AssertionError(f"未准备响应: GET {url}")
|
||||
return self.responses.pop(0)
|
||||
|
||||
|
||||
def _to_timestamp(value: str) -> float:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00")).astimezone(timezone.utc).timestamp()
|
||||
|
||||
|
||||
def test_get_verification_code_ignores_messages_received_before_otp_sent_at():
|
||||
service = TempmailService({"base_url": "https://api.tempmail.test"})
|
||||
service._email_cache["tester@example.com"] = {"token": "token-1"}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
{
|
||||
"emails": [
|
||||
{
|
||||
"id": "old-mail",
|
||||
"received_at": "2026-03-23T10:00:00Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Old code",
|
||||
"body": "111111",
|
||||
},
|
||||
{
|
||||
"id": "new-mail",
|
||||
"received_at": "2026-03-23T10:00:05Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "New code",
|
||||
"body": "222222",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
])
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=_to_timestamp("2026-03-23T10:00:02Z"),
|
||||
)
|
||||
|
||||
assert code == "222222"
|
||||
|
||||
|
||||
def test_get_verification_code_uses_date_field_when_received_at_is_missing():
|
||||
service = TempmailService({"base_url": "https://api.tempmail.test"})
|
||||
service._email_cache["tester@example.com"] = {"token": "token-1"}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
{
|
||||
"emails": [
|
||||
{
|
||||
"id": "legacy-mail",
|
||||
"date": "2026-03-23T10:00:06Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Legacy code",
|
||||
"body": "333333",
|
||||
},
|
||||
{
|
||||
"id": "received-mail",
|
||||
"received_at": "2026-03-23T10:00:07Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Received code",
|
||||
"body": "444444",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
])
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=_to_timestamp("2026-03-23T10:00:05Z"),
|
||||
)
|
||||
|
||||
assert code == "333333"
|
||||
|
||||
|
||||
def test_get_verification_code_accepts_tempmail_date_field_as_timestamp():
|
||||
service = TempmailService({"base_url": "https://api.tempmail.test"})
|
||||
service._email_cache["tester@example.com"] = {"token": "token-1"}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
{
|
||||
"emails": [
|
||||
{
|
||||
"id": "old-mail",
|
||||
"date": "2026-03-23T10:00:02Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Old code",
|
||||
"body": "111111",
|
||||
},
|
||||
{
|
||||
"id": "new-mail",
|
||||
"date": "2026-03-23T10:00:08Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "New code",
|
||||
"body": "222222",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
])
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=_to_timestamp("2026-03-23T10:00:05Z"),
|
||||
)
|
||||
|
||||
assert code == "222222"
|
||||
|
||||
|
||||
def test_parse_message_time_normalizes_timezone_offset():
|
||||
service = TempmailService({"base_url": "https://api.tempmail.test"})
|
||||
|
||||
utc_timestamp = service._parse_message_time("2026-03-23T10:00:07Z")
|
||||
offset_timestamp = service._parse_message_time("2026-03-23T18:00:07+08:00")
|
||||
|
||||
assert utc_timestamp == offset_timestamp
|
||||
Reference in New Issue
Block a user