mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-06-29 19:21:34 +08:00
fix: restore protocol baseline, resolve 403/400 registration errors, and fully remove deprecated playwright dependency
This commit is contained in:
@@ -1,21 +1,90 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.web.routes import registration as registration_routes
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
def test_record_batch_task_result_is_atomic_under_threads():
|
||||
batch_id = "batch-atomic-test"
|
||||
task_manager.init_batch(batch_id, 100)
|
||||
def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
|
||||
batch_id = "batch-sync-init"
|
||||
task_uuids = ["task-1", "task-2", "task-3"]
|
||||
|
||||
statuses = ["completed"] * 60 + ["failed"] * 40
|
||||
registration_routes.batch_tasks.pop(batch_id, None)
|
||||
registration_routes._init_batch_state(batch_id, task_uuids)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=16) as executor:
|
||||
list(executor.map(lambda status: task_manager.record_batch_task_result(batch_id, status), statuses))
|
||||
batch_snapshot = registration_routes.batch_tasks[batch_id]
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
snapshot = task_manager.get_batch_status(batch_id)
|
||||
assert manager_snapshot is not None
|
||||
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
|
||||
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
|
||||
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
|
||||
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
|
||||
assert batch_snapshot["finished"] is False
|
||||
assert manager_snapshot["finished"] is False
|
||||
assert manager_snapshot["status"] == "running"
|
||||
|
||||
assert snapshot is not None
|
||||
assert snapshot["completed"] == 100
|
||||
assert snapshot["success"] == 60
|
||||
assert snapshot["failed"] == 40
|
||||
assert snapshot["skipped"] == 0
|
||||
|
||||
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
batch_id = "batch-sync-parallel"
|
||||
task_uuids = ["task-ok-1", "task-fail-1", "task-ok-2"]
|
||||
task_statuses = {
|
||||
"task-ok-1": "completed",
|
||||
"task-fail-1": "failed",
|
||||
"task-ok-2": "completed",
|
||||
}
|
||||
|
||||
async def fake_run_registration_task(
|
||||
task_uuid,
|
||||
email_service_type,
|
||||
proxy,
|
||||
email_service_config,
|
||||
email_service_id,
|
||||
log_prefix="",
|
||||
batch_id="",
|
||||
auto_upload_cpa=False,
|
||||
cpa_service_ids=None,
|
||||
auto_upload_sub2api=False,
|
||||
sub2api_service_ids=None,
|
||||
auto_upload_tm=False,
|
||||
tm_service_ids=None,
|
||||
):
|
||||
assert task_uuid in task_statuses
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
yield object()
|
||||
|
||||
def fake_get_registration_task(db, task_uuid):
|
||||
status = task_statuses[task_uuid]
|
||||
error_message = None if status == "completed" else f"{task_uuid}-error"
|
||||
return SimpleNamespace(status=status, error_message=error_message)
|
||||
|
||||
registration_routes.batch_tasks.pop(batch_id, None)
|
||||
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
|
||||
|
||||
asyncio.run(
|
||||
registration_routes.run_batch_parallel(
|
||||
batch_id=batch_id,
|
||||
task_uuids=task_uuids,
|
||||
email_service_type="tempmail",
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
email_service_id=None,
|
||||
concurrency=2,
|
||||
)
|
||||
)
|
||||
|
||||
batch_snapshot = registration_routes.batch_tasks[batch_id]
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
|
||||
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
|
||||
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
|
||||
assert batch_snapshot["finished"] is True
|
||||
assert manager_snapshot["finished"] is True
|
||||
assert manager_snapshot["status"] == "completed"
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2026/3/21 14:48
|
||||
from src.core.utils import base64_payload_decode, base64_decode
|
||||
if __name__ == '__main__':
|
||||
print(base64_payload_decode("eyJzZXNzaW9uX2lkIjoiYXV0aHNlc3NfcUE5eFByY3RaZmtHWXJnSlJGdUpxRXBPIiwiY291bnRyeV9jb2RlX2hpbnQiOiJVUyIsImF1dGhfc2Vzc2lvbl9sb2dnaW5nX2lkIjoiMTk0ZDg5OGQtM2Q0ZC00MzU5LWI1NTQtYmJjMjc1YTJlYjU1IiwicHJvbW8iOiIiLCJzaWdudXBfc291cmNlIjoiIiwib3BlbmFpX2NsaWVudF9pZCI6ImFwcF9FTW9hbUVFWjczZjBDa1hhWHA3aHJhbm4iLCJhcHBfbmFtZV9lbnVtIjoib2FpY2xpIiwiYWFzX2VuYWJsZWQiOmZhbHNlLCJvcmlnaW5hbF9zY3JlZW5faGludCI6ImxvZ2luIiwicGFzc3dvcmRsZXNzX2Rpc2FibGVkIjpmYWxzZSwicGFzc3dvcmRsZXNzX290cF9mcm9tX3Bhc3N3b3JkX3JlZGlyZWN0IjpmYWxzZSwiZW1haWxfdmVyaWZpY2F0aW9uX21vZGUiOiJwYXNzd29yZGxlc3NfbG9naW4iLCJlbWFpbCI6Imxob2xsYW5kNTcwQGdzb2xleWZveWxlLm9yZy51ayIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJuYW1lIjoiZXJyIiwid29ya3NwYWNlcyI6W3siaWQiOiI4NjhmZGNmYi1kNjI3LTRhZTItYTQ4Mi1jMTQxMjA1MGZhYTYiLCJuYW1lIjpudWxsLCJraW5kIjoicGVyc29uYWwiLCJwcm9maWxlX3BpY3R1cmVfYWx0X3RleHQiOiJlcnIifV19"))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
61
tests/test_email_service_backoff.py
Normal file
61
tests/test_email_service_backoff.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from src.services.base import (
|
||||
EmailProviderBackoffState,
|
||||
OTPTimeoutEmailServiceError,
|
||||
RateLimitedEmailServiceError,
|
||||
apply_adaptive_backoff,
|
||||
calculate_adaptive_backoff_delay,
|
||||
)
|
||||
|
||||
|
||||
def test_calculate_adaptive_backoff_delay_uses_failure_count_progression():
|
||||
assert calculate_adaptive_backoff_delay(0) == 30
|
||||
assert calculate_adaptive_backoff_delay(1) == 30
|
||||
assert calculate_adaptive_backoff_delay(2) == 60
|
||||
assert calculate_adaptive_backoff_delay(3) == 120
|
||||
|
||||
|
||||
def test_apply_adaptive_backoff_tracks_timeout_failures_to_one_hour():
|
||||
state = EmailProviderBackoffState()
|
||||
|
||||
first = apply_adaptive_backoff(
|
||||
state,
|
||||
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
|
||||
now=1000.0,
|
||||
)
|
||||
second = apply_adaptive_backoff(
|
||||
first,
|
||||
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
|
||||
now=1031.0,
|
||||
)
|
||||
third = apply_adaptive_backoff(
|
||||
second,
|
||||
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
|
||||
now=1092.0,
|
||||
)
|
||||
|
||||
assert first.failures == 1
|
||||
assert first.delay_seconds == 30
|
||||
assert first.opened_until == 1030.0
|
||||
|
||||
assert second.failures == 2
|
||||
assert second.delay_seconds == 60
|
||||
assert second.opened_until == 1091.0
|
||||
|
||||
assert third.failures == 3
|
||||
assert third.delay_seconds == 3600
|
||||
assert third.opened_until == 4692.0
|
||||
|
||||
|
||||
def test_apply_adaptive_backoff_keeps_normal_rate_limit_on_exponential_curve():
|
||||
state = EmailProviderBackoffState(failures=2, delay_seconds=60, opened_until=1060.0)
|
||||
|
||||
next_state = apply_adaptive_backoff(
|
||||
state,
|
||||
RateLimitedEmailServiceError("请求失败: 429", retry_after=7),
|
||||
now=1100.0,
|
||||
)
|
||||
|
||||
assert next_state.failures == 3
|
||||
assert next_state.delay_seconds == 120
|
||||
assert next_state.opened_until == 1220.0
|
||||
assert next_state.retry_after == 7
|
||||
@@ -1,57 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.core.login import LoginEngine
|
||||
|
||||
|
||||
def _build_auth_cookie(workspace_id: str) -> str:
|
||||
payload = base64.urlsafe_b64encode(
|
||||
json.dumps({"workspaces": [{"id": workspace_id}]}).encode("utf-8")
|
||||
).decode("ascii").rstrip("=")
|
||||
return f"{payload}.signature"
|
||||
|
||||
|
||||
def test_get_workspace_id_retries_with_exponential_backoff(monkeypatch):
|
||||
engine = LoginEngine.__new__(LoginEngine)
|
||||
engine.logs = []
|
||||
engine._log = lambda message, level="info": engine.logs.append((level, message))
|
||||
|
||||
auth_cookie = _build_auth_cookie("ws-123")
|
||||
cookies = SimpleNamespace()
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_get(name):
|
||||
assert name == "oai-client-auth-session"
|
||||
calls["count"] += 1
|
||||
if calls["count"] < 4:
|
||||
return None
|
||||
return auth_cookie
|
||||
|
||||
cookies.get = fake_get
|
||||
engine.session = SimpleNamespace(cookies=cookies)
|
||||
|
||||
sleeps = []
|
||||
monkeypatch.setattr("src.core.login.time.sleep", lambda seconds: sleeps.append(seconds))
|
||||
|
||||
workspace_id = engine._get_workspace_id()
|
||||
|
||||
assert workspace_id == "ws-123"
|
||||
assert calls["count"] == 4
|
||||
assert sleeps == [1, 2, 4]
|
||||
|
||||
|
||||
def test_run_always_closes_resources_on_early_return():
|
||||
engine = LoginEngine.__new__(LoginEngine)
|
||||
engine.logs = []
|
||||
engine._log = lambda message, level="info": None
|
||||
engine.close_called = False
|
||||
engine.close = lambda: setattr(engine, "close_called", True)
|
||||
|
||||
engine._check_ip_location = lambda: (False, "blocked")
|
||||
|
||||
result = engine.run()
|
||||
|
||||
assert result.success is False
|
||||
assert result.error_message == "IP 地理位置不支持: blocked"
|
||||
assert engine.close_called is True
|
||||
@@ -1,439 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES
|
||||
from src.core import register
|
||||
from src.services.base import BaseEmailService
|
||||
|
||||
|
||||
class DummyEmailService(BaseEmailService):
|
||||
def __init__(self):
|
||||
super().__init__(EmailServiceType.TEMPMAIL, name="dummy")
|
||||
|
||||
def create_email(self, config=None):
|
||||
return {"email": "tester@example.com", "service_id": "svc-1"}
|
||||
|
||||
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
|
||||
return "123456"
|
||||
|
||||
def list_emails(self, **kwargs):
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id):
|
||||
return True
|
||||
|
||||
def check_health(self):
|
||||
return True
|
||||
|
||||
def refresh_session(self):
|
||||
return None
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code=200, payload=None, text=""):
|
||||
self.status_code = status_code
|
||||
self._payload = payload if payload is not None else {}
|
||||
self.text = text
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class BrokenJSONResponse(FakeResponse):
|
||||
def json(self):
|
||||
raise ValueError("bad json")
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, post_handler=None, get_handler=None, cookies=None):
|
||||
self.post_handler = post_handler
|
||||
self.get_handler = get_handler
|
||||
self.cookies = cookies or {}
|
||||
self.post_calls = []
|
||||
self.get_calls = []
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append({"url": url, "kwargs": kwargs})
|
||||
if self.post_handler is None:
|
||||
raise AssertionError("unexpected post call")
|
||||
return self.post_handler(url, **kwargs)
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
self.get_calls.append({"url": url, "kwargs": kwargs})
|
||||
if self.get_handler is None:
|
||||
raise AssertionError("unexpected get call")
|
||||
return self.get_handler(url, **kwargs)
|
||||
|
||||
|
||||
class DummyHTTPClient:
|
||||
def __init__(self, proxy_url=None):
|
||||
self.proxy_url = proxy_url
|
||||
self.session = FakeSession()
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
raise AssertionError("unexpected http client post")
|
||||
|
||||
|
||||
class DummyOAuthManager:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def start_oauth(self):
|
||||
return SimpleNamespace(
|
||||
auth_url="https://auth.example/start",
|
||||
state="state-1",
|
||||
code_verifier="verifier-1",
|
||||
redirect_uri="http://localhost/callback",
|
||||
)
|
||||
|
||||
|
||||
def make_engine(monkeypatch, email_service=None):
|
||||
monkeypatch.setattr(
|
||||
register,
|
||||
"get_settings",
|
||||
lambda: SimpleNamespace(
|
||||
openai_client_id="client-id",
|
||||
openai_auth_url="https://auth.example/authorize",
|
||||
openai_token_url="https://auth.example/token",
|
||||
openai_redirect_uri="http://localhost/callback",
|
||||
openai_scope="openid email profile offline_access",
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(register, "OpenAIHTTPClient", DummyHTTPClient)
|
||||
monkeypatch.setattr(register, "OAuthManager", DummyOAuthManager)
|
||||
|
||||
engine = register.RegistrationEngine(email_service or DummyEmailService())
|
||||
engine.email = "tester@example.com"
|
||||
engine.email_info = {"email": "tester@example.com", "service_id": "svc-1"}
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"page_type",
|
||||
[
|
||||
"login_password",
|
||||
OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"],
|
||||
"consent_required",
|
||||
"some_other_page",
|
||||
],
|
||||
)
|
||||
def test_submit_login_form_accepts_any_http_200_page(monkeypatch, page_type):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: FakeResponse(
|
||||
status_code=200,
|
||||
payload={"page": {"type": page_type}},
|
||||
)
|
||||
)
|
||||
|
||||
result = engine._submit_login_form("did-1", "sen-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.page_type == page_type
|
||||
assert result.error_message == ""
|
||||
|
||||
|
||||
def test_submit_login_form_accepts_http_200_even_when_json_is_invalid(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
|
||||
)
|
||||
|
||||
result = engine._submit_login_form("did-1", "sen-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.page_type == ""
|
||||
assert result.response_data == {}
|
||||
assert result.error_message == ""
|
||||
|
||||
|
||||
def test_send_passwordless_otp_posts_empty_body(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: FakeResponse(status_code=200)
|
||||
)
|
||||
|
||||
success = engine._send_passwordless_otp()
|
||||
|
||||
assert success is True
|
||||
assert len(engine.session.post_calls) == 1
|
||||
call = engine.session.post_calls[0]
|
||||
assert call["url"] == OPENAI_API_ENDPOINTS["send_passwordless_otp"]
|
||||
assert call["kwargs"]["data"] == ""
|
||||
assert engine._otp_sent_at is not None
|
||||
|
||||
|
||||
def test_send_passwordless_otp_does_not_update_timestamp_on_failure(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine._otp_sent_at = 1234.5
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: FakeResponse(status_code=500, text="server error")
|
||||
)
|
||||
|
||||
success = engine._send_passwordless_otp()
|
||||
|
||||
assert success is False
|
||||
assert engine._otp_sent_at == 1234.5
|
||||
|
||||
|
||||
def test_get_verification_code_passes_explicit_otp_timestamp(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class RecordingEmailService(DummyEmailService):
|
||||
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
|
||||
captured["email"] = email
|
||||
captured["email_id"] = email_id
|
||||
captured["timeout"] = timeout
|
||||
captured["pattern"] = pattern
|
||||
captured["otp_sent_at"] = otp_sent_at
|
||||
return "654321"
|
||||
|
||||
engine = make_engine(monkeypatch, email_service=RecordingEmailService())
|
||||
|
||||
code = engine._get_verification_code(otp_sent_at=1234.5)
|
||||
|
||||
assert code == "654321"
|
||||
assert captured["email"] == "tester@example.com"
|
||||
assert captured["email_id"] == "svc-1"
|
||||
assert captured["timeout"] == 120
|
||||
assert captured["otp_sent_at"] == 1234.5
|
||||
|
||||
|
||||
def test_validate_verification_code_accepts_http_200_even_when_json_is_invalid(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
engine.session = FakeSession(
|
||||
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
|
||||
)
|
||||
|
||||
result = engine._validate_verification_code("123456")
|
||||
|
||||
assert result.success is True
|
||||
assert result.continue_url == ""
|
||||
assert result.response_data == {}
|
||||
|
||||
|
||||
def test_run_closes_http_client_on_early_failure(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
tracking_client = DummyHTTPClient()
|
||||
engine.http_client = tracking_client
|
||||
|
||||
monkeypatch.setattr(engine, "_check_ip_location", lambda: (False, None))
|
||||
|
||||
result = engine.run()
|
||||
|
||||
assert result.success is False
|
||||
assert tracking_client.closed is True
|
||||
assert engine.session is None
|
||||
|
||||
|
||||
def test_fallback_to_login_flow_forces_otp_and_continue_url(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
steps = []
|
||||
captured = {}
|
||||
|
||||
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
|
||||
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_form",
|
||||
lambda did, sen: steps.append("submit_login_form")
|
||||
or register.SignupFormResult(success=True, page_type="login_password"),
|
||||
)
|
||||
|
||||
def fake_send_passwordless_otp():
|
||||
steps.append("send_passwordless_otp")
|
||||
engine._otp_sent_at = 4567.89
|
||||
return True
|
||||
|
||||
def fake_get_verification_code(otp_sent_at=None):
|
||||
steps.append("get_verification_code")
|
||||
captured["otp_sent_at"] = otp_sent_at
|
||||
return "123456"
|
||||
|
||||
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
|
||||
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_validate_verification_code",
|
||||
lambda code: steps.append("validate_verification_code")
|
||||
or register.OTPValidationResult(success=True, continue_url="https://auth.example/continue"),
|
||||
)
|
||||
|
||||
def fake_try_upgrade(continue_url, stage):
|
||||
steps.append("get_continue_url_and_parse_workspace")
|
||||
captured["continue_url"] = continue_url
|
||||
captured["stage"] = stage
|
||||
return "ws-123"
|
||||
|
||||
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fake_try_upgrade)
|
||||
|
||||
workspace_id = engine._fallback_to_login_flow()
|
||||
|
||||
assert workspace_id == "ws-123"
|
||||
assert steps == [
|
||||
"reset_session",
|
||||
"get_device_id",
|
||||
"check_sentinel",
|
||||
"submit_login_form",
|
||||
"send_passwordless_otp",
|
||||
"get_verification_code",
|
||||
"validate_verification_code",
|
||||
"get_continue_url_and_parse_workspace",
|
||||
]
|
||||
assert captured["continue_url"] == "https://auth.example/continue"
|
||||
assert captured["stage"] == "降级登录 Continue URL"
|
||||
assert captured["otp_sent_at"] == 4567.89
|
||||
|
||||
|
||||
def test_fallback_to_login_flow_requires_continue_url(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: "did-1")
|
||||
monkeypatch.setattr(engine, "_check_sentinel", lambda did: "sen-1")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_form",
|
||||
lambda did, sen: register.SignupFormResult(success=True, page_type="login_password"),
|
||||
)
|
||||
|
||||
def fake_send_passwordless_otp():
|
||||
engine._otp_sent_at = 9876.5
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
|
||||
monkeypatch.setattr(engine, "_get_verification_code", lambda otp_sent_at=None: "123456")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_validate_verification_code",
|
||||
lambda code: register.OTPValidationResult(success=True, continue_url=""),
|
||||
)
|
||||
|
||||
def fail_if_called(*args, **kwargs):
|
||||
raise AssertionError("continue_url 缺失时不应尝试升级 Cookie")
|
||||
|
||||
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
|
||||
|
||||
assert engine._fallback_to_login_flow() is None
|
||||
|
||||
|
||||
def test_fallback_to_login_flow_accepts_workspace_without_continue_url(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
steps = []
|
||||
|
||||
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
|
||||
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_form",
|
||||
lambda did, sen: steps.append("submit_login_form")
|
||||
or register.SignupFormResult(success=True, page_type="login_password"),
|
||||
)
|
||||
|
||||
def fake_send_passwordless_otp():
|
||||
steps.append("send_passwordless_otp")
|
||||
engine._otp_sent_at = 4567.89
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_get_verification_code",
|
||||
lambda otp_sent_at=None: steps.append("get_verification_code") or "123456",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_validate_verification_code",
|
||||
lambda code: steps.append("validate_verification_code")
|
||||
or register.OTPValidationResult(success=True, continue_url=""),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_get_workspace_id",
|
||||
lambda log_missing=True: steps.append("get_workspace_id") or "ws-cookie",
|
||||
)
|
||||
|
||||
def fail_if_called(*args, **kwargs):
|
||||
raise AssertionError("已有 workspace 时不应继续访问 continue_url")
|
||||
|
||||
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
|
||||
|
||||
workspace_id = engine._fallback_to_login_flow()
|
||||
|
||||
assert workspace_id == "ws-cookie"
|
||||
assert steps == [
|
||||
"reset_session",
|
||||
"get_device_id",
|
||||
"check_sentinel",
|
||||
"submit_login_form",
|
||||
"send_passwordless_otp",
|
||||
"get_verification_code",
|
||||
"validate_verification_code",
|
||||
"get_workspace_id",
|
||||
]
|
||||
|
||||
|
||||
def test_get_verification_code_uses_provider_timeout_and_refreshes_once(monkeypatch):
|
||||
captured = {"calls": [], "refresh_count": 0}
|
||||
|
||||
class RefreshableOutlookService(DummyEmailService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.service_type = EmailServiceType.OUTLOOK
|
||||
|
||||
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
|
||||
captured["calls"].append(
|
||||
{
|
||||
"timeout": timeout,
|
||||
"otp_sent_at": otp_sent_at,
|
||||
"email": email,
|
||||
"email_id": email_id,
|
||||
}
|
||||
)
|
||||
if len(captured["calls"]) == 1:
|
||||
return None
|
||||
return "987654"
|
||||
|
||||
def refresh_session(self):
|
||||
captured["refresh_count"] += 1
|
||||
|
||||
engine = make_engine(monkeypatch, email_service=RefreshableOutlookService())
|
||||
|
||||
code = engine._get_verification_code(otp_sent_at=2468.0)
|
||||
|
||||
assert code == "987654"
|
||||
assert captured["refresh_count"] == 1
|
||||
assert len(captured["calls"]) == 2
|
||||
assert captured["calls"][0]["timeout"] == 180
|
||||
assert captured["calls"][1]["timeout"] == 180
|
||||
assert captured["calls"][0]["otp_sent_at"] == 2468.0
|
||||
|
||||
|
||||
def test_try_upgrade_cookie_with_continue_url_retries_with_second_probe(monkeypatch):
|
||||
engine = make_engine(monkeypatch)
|
||||
sleep_calls = []
|
||||
workspace_results = iter([None, None, None, None, None, "ws-delayed"])
|
||||
engine.session = FakeSession(
|
||||
get_handler=lambda url, **kwargs: FakeResponse(status_code=302),
|
||||
cookies={},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(engine, "_log_cookie_state", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(engine, "_get_workspace_id", lambda log_missing=False: next(workspace_results))
|
||||
monkeypatch.setattr(register.time, "sleep", lambda seconds: sleep_calls.append(seconds))
|
||||
|
||||
workspace_id = engine._try_upgrade_cookie_with_continue_url(
|
||||
"https://auth.example/continue",
|
||||
"降级登录 Continue URL",
|
||||
)
|
||||
|
||||
assert workspace_id == "ws-delayed"
|
||||
assert len(engine.session.get_calls) == 3
|
||||
assert sleep_calls == [1.0, 2.0, 4.0]
|
||||
82
tests/test_register_protocol_baseline.py
Normal file
82
tests/test_register_protocol_baseline.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.core.register as register_module
|
||||
from src.config.constants import OPENAI_PAGE_TYPES
|
||||
from src.core.register import RegistrationEngine
|
||||
from src.services import EmailServiceType
|
||||
|
||||
|
||||
class DummySettings:
|
||||
openai_client_id = "client-id"
|
||||
openai_auth_url = "https://auth.example.test"
|
||||
openai_token_url = "https://token.example.test"
|
||||
openai_redirect_uri = "https://callback.example.test"
|
||||
openai_scope = "openid profile email"
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code=200, payload=None, text=""):
|
||||
self.status_code = status_code
|
||||
self._payload = payload or {}
|
||||
self.text = text
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.calls = []
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.calls.append({
|
||||
"url": url,
|
||||
**kwargs,
|
||||
})
|
||||
return self.response
|
||||
|
||||
|
||||
def _build_engine(monkeypatch):
|
||||
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
|
||||
email_service = SimpleNamespace(service_type=EmailServiceType.DUCK_MAIL)
|
||||
return RegistrationEngine(email_service=email_service)
|
||||
|
||||
|
||||
def test_submit_signup_form_uses_stable_protocol_body(monkeypatch):
|
||||
engine = _build_engine(monkeypatch)
|
||||
session = FakeSession(FakeResponse(
|
||||
status_code=200,
|
||||
payload={"page": {"type": OPENAI_PAGE_TYPES["PASSWORD_REGISTRATION"]}},
|
||||
))
|
||||
engine.session = session
|
||||
engine.email = "tester@example.com"
|
||||
|
||||
result = engine._submit_signup_form("did-1", None)
|
||||
|
||||
assert result.success is True
|
||||
assert result.is_existing_account is False
|
||||
assert (
|
||||
session.calls[0]["data"]
|
||||
== '{"username":{"value":"tester@example.com","kind":"email"},"screen_hint":"signup"}'
|
||||
)
|
||||
|
||||
|
||||
def test_register_password_uses_stable_protocol_body(monkeypatch):
|
||||
engine = _build_engine(monkeypatch)
|
||||
session = FakeSession(FakeResponse(status_code=200))
|
||||
engine.session = session
|
||||
engine.email = "tester@example.com"
|
||||
monkeypatch.setattr(engine, "_generate_password", lambda length=0: "Pass12345")
|
||||
|
||||
success, password = engine._register_password()
|
||||
|
||||
assert success is True
|
||||
assert password == "Pass12345"
|
||||
assert session.calls[0]["data"] == json.dumps(
|
||||
{
|
||||
"password": "Pass12345",
|
||||
"username": "tester@example.com",
|
||||
}
|
||||
)
|
||||
414
tests/test_registration_email_service_failover.py
Normal file
414
tests/test_registration_email_service_failover.py
Normal file
@@ -0,0 +1,414 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.services.base as base_module
|
||||
from src.core.register import (
|
||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
PhaseResult,
|
||||
RegistrationResult,
|
||||
)
|
||||
from src.database.models import Base, EmailService, RegistrationTask
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.services import EmailServiceType
|
||||
from src.services.base import BaseEmailService, EmailProviderBackoffState
|
||||
from src.web.routes import registration as registration_routes
|
||||
|
||||
|
||||
class DummyTaskManager:
|
||||
def __init__(self):
|
||||
self.status_updates = []
|
||||
self.logs = {}
|
||||
|
||||
def is_cancelled(self, task_uuid):
|
||||
return False
|
||||
|
||||
def update_status(self, task_uuid, status, email=None, error=None, **kwargs):
|
||||
self.status_updates.append((task_uuid, status, email, error, kwargs))
|
||||
|
||||
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
|
||||
def callback(message):
|
||||
self.logs.setdefault(task_uuid, []).append(message)
|
||||
return callback
|
||||
|
||||
|
||||
class BackoffAwareEmailService(BaseEmailService):
|
||||
def __init__(self, service_type, config=None, name=None):
|
||||
super().__init__(service_type=service_type, name=name)
|
||||
self.config = config or {}
|
||||
|
||||
def create_email(self, config=None):
|
||||
return {"email": "tester@example.com", "service_id": "svc-1"}
|
||||
|
||||
def get_verification_code(self, **kwargs):
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs):
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
return True
|
||||
|
||||
def check_health(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def test_registration_task_fails_over_after_rate_limit(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_failover.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-rate-limit-failover"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add_all([
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
),
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-secondary",
|
||||
config={
|
||||
"base_url": "https://mail-2.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=1,
|
||||
),
|
||||
])
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
attempts = []
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
attempts.append(self.email_service.name)
|
||||
if self.email_service.name == "duck-primary":
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=False,
|
||||
error_message="创建邮箱失败",
|
||||
error_code="EMAIL_PROVIDER_RATE_LIMITED",
|
||||
retryable=True,
|
||||
next_action="switch_provider",
|
||||
provider_backoff=EmailProviderBackoffState(
|
||||
failures=1,
|
||||
delay_seconds=30,
|
||||
opened_until=1030.0,
|
||||
retry_after=7,
|
||||
last_error="请求失败: 429",
|
||||
),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="创建邮箱失败: 请求失败: 429",
|
||||
logs=[],
|
||||
)
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=True,
|
||||
provider_backoff=EmailProviderBackoffState(),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
email="tester@example.com",
|
||||
password="Pass12345",
|
||||
account_id="acct-1",
|
||||
workspace_id="ws-1",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
id_token="id-token",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: SimpleNamespace(
|
||||
service_type=service_type,
|
||||
name=name or service_type.value,
|
||||
config=config,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
registration_routes._run_sync_registration_task(
|
||||
task_uuid=task_uuid,
|
||||
email_service_type=EmailServiceType.DUCK_MAIL.value,
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
services = session.query(EmailService).order_by(EmailService.priority.asc()).all()
|
||||
task_status = task.status
|
||||
task_email_service_id = task.email_service_id
|
||||
primary_service_id = services[0].id
|
||||
secondary_service_id = services[1].id
|
||||
|
||||
assert attempts == ["duck-primary", "duck-secondary"]
|
||||
assert task_status == "completed"
|
||||
assert task_email_service_id == secondary_service_id
|
||||
assert registration_routes.email_service_circuit_breakers[primary_service_id].failures == 1
|
||||
assert registration_routes.email_service_circuit_breakers[primary_service_id].delay_seconds == 30
|
||||
|
||||
|
||||
def test_registration_task_enters_deep_cooldown_after_three_otp_timeouts(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_otp_timeout_backoff.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuids = [
|
||||
"task-otp-timeout-1",
|
||||
"task-otp-timeout-2",
|
||||
"task-otp-timeout-3",
|
||||
]
|
||||
with manager.session_scope() as session:
|
||||
session.add_all([RegistrationTask(task_uuid=task_uuid, status="pending") for task_uuid in task_uuids])
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
current_time = {"value": 1000.0}
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=True,
|
||||
provider_backoff=EmailProviderBackoffState(),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="等待验证码超时",
|
||||
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
|
||||
monkeypatch.setattr(base_module.time, "time", lambda: current_time["value"])
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
expected_delays = [30, 60, 3600]
|
||||
for attempt_index, task_uuid in enumerate(task_uuids, start=1):
|
||||
registration_routes._run_sync_registration_task(
|
||||
task_uuid=task_uuid,
|
||||
email_service_type=EmailServiceType.DUCK_MAIL.value,
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
assert task.status == "failed"
|
||||
assert task.error_message == "等待验证码超时"
|
||||
|
||||
state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert state.failures == attempt_index
|
||||
assert state.delay_seconds == expected_delays[attempt_index - 1]
|
||||
assert state.opened_until == current_time["value"] + expected_delays[attempt_index - 1]
|
||||
|
||||
if attempt_index < len(task_uuids):
|
||||
current_time["value"] = state.opened_until + 1
|
||||
|
||||
final_state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert final_state.delay_seconds == 3600
|
||||
assert final_state.failures == 3
|
||||
|
||||
|
||||
def test_registration_task_success_clears_email_service_backoff(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_success_clears_backoff.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuid = "task-success-clears-backoff"
|
||||
with manager.session_scope() as session:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=True,
|
||||
provider_backoff=EmailProviderBackoffState(),
|
||||
)
|
||||
]
|
||||
|
||||
def run(self):
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
email="tester@example.com",
|
||||
password="Pass12345",
|
||||
account_id="acct-1",
|
||||
workspace_id="ws-1",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
id_token="id-token",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
registration_routes.email_service_circuit_breakers[service_id] = EmailProviderBackoffState(
|
||||
failures=2,
|
||||
delay_seconds=60,
|
||||
opened_until=9999.0,
|
||||
last_error="等待验证码超时",
|
||||
)
|
||||
|
||||
registration_routes._run_sync_registration_task(
|
||||
task_uuid=task_uuid,
|
||||
email_service_type=EmailServiceType.DUCK_MAIL.value,
|
||||
proxy=None,
|
||||
email_service_config=None,
|
||||
)
|
||||
|
||||
assert service_id not in registration_routes.email_service_circuit_breakers
|
||||
71
tests/test_registration_otp_phase.py
Normal file
71
tests/test_registration_otp_phase.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import src.core.register as register_module
|
||||
from src.core.register import (
|
||||
ERROR_OTP_TIMEOUT_SECONDARY,
|
||||
PhaseContext,
|
||||
RegistrationEngine,
|
||||
)
|
||||
from src.services import EmailServiceType
|
||||
|
||||
|
||||
class DummySettings:
|
||||
openai_client_id = "client-id"
|
||||
openai_auth_url = "https://auth.example.test"
|
||||
openai_token_url = "https://token.example.test"
|
||||
openai_redirect_uri = "https://callback.example.test"
|
||||
openai_scope = "openid profile email"
|
||||
|
||||
|
||||
class FakeEmailService:
|
||||
def __init__(self, code):
|
||||
self.service_type = EmailServiceType.TEMPMAIL
|
||||
self.code = code
|
||||
self.calls = []
|
||||
|
||||
def get_verification_code(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return self.code
|
||||
|
||||
|
||||
def _build_engine(monkeypatch, email_service):
|
||||
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
|
||||
return RegistrationEngine(email_service=email_service)
|
||||
|
||||
|
||||
def test_phase_otp_secondary_uses_remaining_budget_from_start_timestamp(monkeypatch):
|
||||
email_service = FakeEmailService(code="654321")
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.email = "tester@example.com"
|
||||
engine.email_info = {"service_id": "svc-1"}
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
|
||||
|
||||
code, phase_result = engine._phase_otp_secondary(
|
||||
PhaseContext(otp_sent_at=77.0),
|
||||
started_at=100.0,
|
||||
)
|
||||
|
||||
assert code == "654321"
|
||||
assert phase_result.success is True
|
||||
assert email_service.calls[0]["timeout"] == 100
|
||||
assert email_service.calls[0]["otp_sent_at"] == 77.0
|
||||
assert email_service.calls[0]["email"] == "tester@example.com"
|
||||
assert email_service.calls[0]["email_id"] == "svc-1"
|
||||
|
||||
|
||||
def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.email = "tester@example.com"
|
||||
engine.email_info = {"service_id": "svc-1"}
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
|
||||
|
||||
code, phase_result = engine._phase_otp_secondary(
|
||||
PhaseContext(otp_sent_at=80.0),
|
||||
started_at=100.0,
|
||||
)
|
||||
|
||||
assert code is None
|
||||
assert phase_result.success is False
|
||||
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
@@ -42,7 +42,13 @@ def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch,
|
||||
monkeypatch.setattr(
|
||||
registration,
|
||||
"EmailServiceFactory",
|
||||
SimpleNamespace(create=lambda service_type, config: SimpleNamespace(service_type=service_type, config=config)),
|
||||
SimpleNamespace(
|
||||
create=lambda service_type, config, name=None: SimpleNamespace(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name or service_type.value,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
attempted_proxies = []
|
||||
@@ -73,6 +79,7 @@ def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch,
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine)
|
||||
registration.email_service_circuit_breakers.clear()
|
||||
|
||||
registration._run_sync_registration_task(
|
||||
task_uuid="task-proxy-failover",
|
||||
|
||||
@@ -15,6 +15,7 @@ def test_static_asset_version_is_non_empty_string():
|
||||
def test_email_services_template_uses_versioned_static_assets():
|
||||
template = Path("templates/email_services.html").read_text(encoding="utf-8")
|
||||
|
||||
assert '/static/favicon.svg?v={{ static_version }}' in template
|
||||
assert '/static/css/style.css?v={{ static_version }}' in template
|
||||
assert '/static/js/utils.js?v={{ static_version }}' in template
|
||||
assert '/static/js/email_services.js?v={{ static_version }}' in template
|
||||
@@ -23,6 +24,7 @@ def test_email_services_template_uses_versioned_static_assets():
|
||||
def test_index_template_uses_versioned_static_assets():
|
||||
template = Path("templates/index.html").read_text(encoding="utf-8")
|
||||
|
||||
assert '/static/favicon.svg?v={{ static_version }}' in template
|
||||
assert '/static/css/style.css?v={{ static_version }}' in template
|
||||
assert '/static/js/utils.js?v={{ static_version }}' in template
|
||||
assert '/static/js/app.js?v={{ static_version }}' in template
|
||||
|
||||
143
tests/test_task_recovery.py
Normal file
143
tests/test_task_recovery.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from contextlib import contextmanager
|
||||
import asyncio
|
||||
|
||||
from fastapi import WebSocketDisconnect
|
||||
|
||||
from src.database import crud
|
||||
from src.database.models import Base, RegistrationTask
|
||||
from src.database.session import DatabaseSessionManager
|
||||
from src.web.routes import websocket as websocket_routes
|
||||
from src.web.task_manager import TaskManager
|
||||
|
||||
|
||||
def test_fail_incomplete_registration_tasks_marks_pending_and_running_failed(tmp_path):
|
||||
db_path = tmp_path / "recovery.db"
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
session.add_all([
|
||||
RegistrationTask(task_uuid="task-pending", status="pending"),
|
||||
RegistrationTask(task_uuid="task-running", status="running", logs="[01:00:00] still running"),
|
||||
RegistrationTask(task_uuid="task-done", status="completed"),
|
||||
])
|
||||
|
||||
with manager.session_scope() as session:
|
||||
cleaned = crud.fail_incomplete_registration_tasks(
|
||||
session,
|
||||
"服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
|
||||
)
|
||||
|
||||
assert cleaned == ["task-pending", "task-running"]
|
||||
|
||||
with manager.session_scope() as session:
|
||||
pending_task = crud.get_registration_task_by_uuid(session, "task-pending")
|
||||
running_task = crud.get_registration_task_by_uuid(session, "task-running")
|
||||
done_task = crud.get_registration_task_by_uuid(session, "task-done")
|
||||
|
||||
assert pending_task.status == "failed"
|
||||
assert running_task.status == "failed"
|
||||
assert pending_task.error_message == "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
|
||||
assert running_task.completed_at is not None
|
||||
assert "[系统] 服务启动时检测到未完成的历史任务,已标记失败,请重新发起。" in running_task.logs
|
||||
assert done_task.status == "completed"
|
||||
|
||||
|
||||
def test_restore_task_snapshot_loads_status_and_logs_from_database(monkeypatch, tmp_path):
|
||||
db_path = tmp_path / "websocket.db"
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
session.add(
|
||||
RegistrationTask(
|
||||
task_uuid="task-websocket",
|
||||
status="failed",
|
||||
logs="[01:00:00] step 1\n[01:00:01] step 2",
|
||||
result={"email": "tester@example.com"},
|
||||
error_message="boom"
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(websocket_routes, "get_db", fake_get_db)
|
||||
|
||||
status, logs = websocket_routes._restore_task_snapshot("task-websocket")
|
||||
|
||||
assert status == {
|
||||
"status": "failed",
|
||||
"email": "tester@example.com",
|
||||
"error": "boom",
|
||||
}
|
||||
assert logs == ["[01:00:00] step 1", "[01:00:01] step 2"]
|
||||
|
||||
|
||||
def test_sync_task_state_prefers_longer_persisted_log_history():
|
||||
manager = TaskManager()
|
||||
task_uuid = "task-sync"
|
||||
|
||||
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["a", "b"])
|
||||
manager.sync_task_state(task_uuid, logs=["a"])
|
||||
|
||||
assert manager.get_status(task_uuid) == {"status": "running"}
|
||||
assert manager.get_logs(task_uuid) == ["a", "b"]
|
||||
|
||||
|
||||
def test_register_websocket_returns_snapshot_and_keeps_live_cursor():
|
||||
manager = TaskManager()
|
||||
task_uuid = "task-live"
|
||||
websocket = object()
|
||||
|
||||
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["log-1", "log-2"])
|
||||
|
||||
history_logs = manager.register_websocket(task_uuid, websocket)
|
||||
|
||||
assert history_logs == ["log-1", "log-2"]
|
||||
assert manager.get_unsent_logs(task_uuid, websocket) == []
|
||||
|
||||
manager.add_log(task_uuid, "log-3")
|
||||
|
||||
assert manager.get_unsent_logs(task_uuid, websocket) == ["log-3"]
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.accepted = False
|
||||
|
||||
async def accept(self):
|
||||
self.accepted = True
|
||||
|
||||
async def send_json(self, payload):
|
||||
self.messages.append(payload)
|
||||
|
||||
async def receive_json(self):
|
||||
raise WebSocketDisconnect()
|
||||
|
||||
|
||||
def test_batch_websocket_replays_history_logs_from_registration_snapshot(monkeypatch):
|
||||
manager = TaskManager()
|
||||
batch_id = "batch-history"
|
||||
websocket = _FakeWebSocket()
|
||||
|
||||
manager.init_batch(batch_id, total=2)
|
||||
manager.add_batch_log(batch_id, "[01:00:00] first")
|
||||
manager.add_batch_log(batch_id, "[01:00:01] second")
|
||||
|
||||
monkeypatch.setattr(websocket_routes, "task_manager", manager)
|
||||
|
||||
asyncio.run(websocket_routes.batch_websocket(websocket, batch_id))
|
||||
|
||||
assert websocket.accepted is True
|
||||
assert websocket.messages[0]["type"] == "status"
|
||||
assert [msg["message"] for msg in websocket.messages[1:]] == [
|
||||
"[01:00:00] first",
|
||||
"[01:00:01] second",
|
||||
]
|
||||
98
tests/test_tempmail_timestamp_filter.py
Normal file
98
tests/test_tempmail_timestamp_filter.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.services.tempmail import TempmailService
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, payload, status_code=200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeHTTPClient:
|
||||
def __init__(self, responses):
|
||||
self.responses = list(responses)
|
||||
self.calls = []
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
self.calls.append({"url": url, "kwargs": kwargs})
|
||||
if not self.responses:
|
||||
raise AssertionError(f"未准备响应: GET {url}")
|
||||
return self.responses.pop(0)
|
||||
|
||||
|
||||
def _to_timestamp(value: str) -> float:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00")).astimezone(timezone.utc).timestamp()
|
||||
|
||||
|
||||
def test_get_verification_code_ignores_messages_received_before_otp_sent_at():
|
||||
service = TempmailService({"base_url": "https://api.tempmail.test"})
|
||||
service._email_cache["tester@example.com"] = {"token": "token-1"}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
{
|
||||
"emails": [
|
||||
{
|
||||
"id": "old-mail",
|
||||
"received_at": "2026-03-23T10:00:00Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Old code",
|
||||
"body": "111111",
|
||||
},
|
||||
{
|
||||
"id": "new-mail",
|
||||
"received_at": "2026-03-23T10:00:05Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "New code",
|
||||
"body": "222222",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
])
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=_to_timestamp("2026-03-23T10:00:02Z"),
|
||||
)
|
||||
|
||||
assert code == "222222"
|
||||
|
||||
|
||||
def test_get_verification_code_requires_received_at_when_otp_sent_at_is_present():
|
||||
service = TempmailService({"base_url": "https://api.tempmail.test"})
|
||||
service._email_cache["tester@example.com"] = {"token": "token-1"}
|
||||
service.http_client = FakeHTTPClient([
|
||||
FakeResponse(
|
||||
{
|
||||
"emails": [
|
||||
{
|
||||
"id": "legacy-mail",
|
||||
"date": "2026-03-23T10:00:06Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Legacy code",
|
||||
"body": "333333",
|
||||
},
|
||||
{
|
||||
"id": "received-mail",
|
||||
"received_at": "2026-03-23T10:00:07Z",
|
||||
"from": "noreply@openai.com",
|
||||
"subject": "Received code",
|
||||
"body": "444444",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
])
|
||||
|
||||
code = service.get_verification_code(
|
||||
email="tester@example.com",
|
||||
timeout=1,
|
||||
otp_sent_at=_to_timestamp("2026-03-23T10:00:05Z"),
|
||||
)
|
||||
|
||||
assert code == "444444"
|
||||
Reference in New Issue
Block a user