Fix registration OTP anchor and batch task state

This commit is contained in:
Mison
2026-03-24 07:20:38 +08:00
parent b8b1eb72d1
commit 78f2d0accc
9 changed files with 393 additions and 145 deletions

View File

@@ -70,3 +70,37 @@ def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_pa
assert updated.refresh_token == ""
assert updated.token_sync_status == "pending"
assert updated.token_sync_updated_at is not None
def test_update_outlook_refresh_token_persists_nested_config_changes(tmp_path):
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
manager.create_tables()
manager.migrate_tables()
with manager.session_scope() as session:
service = crud.create_email_service(
session,
service_type="outlook",
name="outlook-service",
config={
"accounts": [
{"email": "first@example.com", "refresh_token": "old-first"},
{"email": "second@example.com", "refresh_token": "old-second"},
]
},
)
service_id = service.id
crud.update_outlook_refresh_token(
session,
service_id=service_id,
email="second@example.com",
new_refresh_token="new-second",
)
with manager.session_scope() as session:
reloaded = crud.get_email_service_by_id(session, service_id)
assert reloaded is not None
assert reloaded.config["accounts"][0]["refresh_token"] == "old-first"
assert reloaded.config["accounts"][1]["refresh_token"] == "new-second"

View File

@@ -6,24 +6,23 @@ from src.web.routes import registration as registration_routes
from src.web.task_manager import task_manager
def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
def test_init_batch_state_persists_state_in_task_manager():
batch_id = "batch-sync-init"
task_uuids = ["task-1", "task-2", "task-3"]
registration_routes.batch_tasks.pop(batch_id, None)
registration_routes._init_batch_state(batch_id, task_uuids)
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
assert batch_snapshot["finished"] is False
assert manager_snapshot["task_uuids"] == task_uuids
assert manager_snapshot["total"] == 3
assert manager_snapshot["completed"] == 0
assert manager_snapshot["success"] == 0
assert manager_snapshot["failed"] == 0
assert manager_snapshot["finished"] is False
assert manager_snapshot["status"] == "running"
assert task_manager.get_batch_logs(batch_id) == []
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
@@ -61,7 +60,6 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
error_message = None if status == "completed" else f"{task_uuid}-error"
return SimpleNamespace(status=status, error_message=error_message)
registration_routes.batch_tasks.pop(batch_id, None)
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
@@ -78,13 +76,11 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
)
)
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
assert batch_snapshot["finished"] is True
assert manager_snapshot["completed"] == 3
assert manager_snapshot["success"] == 2
assert manager_snapshot["failed"] == 1
assert manager_snapshot["finished"] is True
assert manager_snapshot["status"] == "completed"

View File

@@ -1,5 +1,6 @@
from contextlib import contextmanager
from pathlib import Path
import threading
from types import SimpleNamespace
import src.services.base as base_module
@@ -412,3 +413,128 @@ def test_registration_task_success_clears_email_service_backoff(monkeypatch):
)
assert service_id not in registration_routes.email_service_circuit_breakers
def test_registration_task_backoff_failures_do_not_get_lost_under_concurrency(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_backoff_concurrency.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuids = ["task-backoff-1", "task-backoff-2"]
with manager.session_scope() as session:
for task_uuid in task_uuids:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
session.add(
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
start_lock = threading.Lock()
started = {"count": 0}
peer_started = threading.Event()
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
def run(self):
with start_lock:
started["count"] += 1
if started["count"] == len(task_uuids):
peer_started.set()
peer_started.wait(timeout=0.1)
current_state = self.email_service.provider_backoff_state
next_failures = current_state.failures + 1
delay_seconds = 30 if next_failures == 1 else 60
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=False,
error_message="创建邮箱失败",
error_code="EMAIL_PROVIDER_RATE_LIMITED",
retryable=True,
next_action="switch_provider",
provider_backoff=EmailProviderBackoffState(
failures=next_failures,
delay_seconds=delay_seconds,
opened_until=1000.0 + delay_seconds,
last_error="请求失败: 429",
),
)
]
return RegistrationResult(
success=False,
error_message="创建邮箱失败: 请求失败: 429",
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: BackoffAwareEmailService(
service_type=service_type,
config=config,
name=name,
),
)
registration_routes.email_service_circuit_breakers.clear()
with manager.session_scope() as session:
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
threads = [
threading.Thread(
target=registration_routes._run_sync_registration_task,
kwargs={
"task_uuid": task_uuid,
"email_service_type": EmailServiceType.DUCK_MAIL.value,
"proxy": None,
"email_service_config": None,
},
)
for task_uuid in task_uuids
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
state = registration_routes.email_service_circuit_breakers[service_id]
assert state.failures == 2
assert state.delay_seconds == 60

View File

@@ -69,3 +69,36 @@ def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
assert phase_result.success is False
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY
def test_advance_login_authorization_refreshes_otp_anchor_after_password_submit(monkeypatch):
email_service = FakeEmailService(code=None)
engine = _build_engine(monkeypatch, email_service)
engine.oauth_start = object()
engine._otp_sent_at = 10.0
monkeypatch.setattr(register_module.time, "time", lambda: 456.0)
monkeypatch.setattr(engine, "_init_session", lambda: True)
monkeypatch.setattr(engine, "_start_oauth", lambda: True)
monkeypatch.setattr(engine, "_get_device_id", lambda: True)
monkeypatch.setattr(engine, "_try_reenter_login_flow", lambda: True)
monkeypatch.setattr(
engine,
"_submit_login_password_step_and_get_continue_url",
lambda: (True, "https://continue.example.test"),
)
seen_anchors = []
def fake_get_verification_code():
seen_anchors.append(engine._otp_sent_at)
return None
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
workspace_id, callback_url = engine._advance_login_authorization()
assert workspace_id is None
assert callback_url is None
assert engine._otp_sent_at == 456.0
assert seen_anchors == [456.0]