mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-27 11:10:30 +08:00
Fix registration OTP anchor and batch task state
This commit is contained in:
@@ -70,3 +70,37 @@ def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_pa
|
||||
assert updated.refresh_token == ""
|
||||
assert updated.token_sync_status == "pending"
|
||||
assert updated.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_outlook_refresh_token_persists_nested_config_changes(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service = crud.create_email_service(
|
||||
session,
|
||||
service_type="outlook",
|
||||
name="outlook-service",
|
||||
config={
|
||||
"accounts": [
|
||||
{"email": "first@example.com", "refresh_token": "old-first"},
|
||||
{"email": "second@example.com", "refresh_token": "old-second"},
|
||||
]
|
||||
},
|
||||
)
|
||||
service_id = service.id
|
||||
|
||||
crud.update_outlook_refresh_token(
|
||||
session,
|
||||
service_id=service_id,
|
||||
email="second@example.com",
|
||||
new_refresh_token="new-second",
|
||||
)
|
||||
|
||||
with manager.session_scope() as session:
|
||||
reloaded = crud.get_email_service_by_id(session, service_id)
|
||||
|
||||
assert reloaded is not None
|
||||
assert reloaded.config["accounts"][0]["refresh_token"] == "old-first"
|
||||
assert reloaded.config["accounts"][1]["refresh_token"] == "new-second"
|
||||
|
||||
@@ -6,24 +6,23 @@ from src.web.routes import registration as registration_routes
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
|
||||
def test_init_batch_state_persists_state_in_task_manager():
|
||||
batch_id = "batch-sync-init"
|
||||
task_uuids = ["task-1", "task-2", "task-3"]
|
||||
|
||||
registration_routes.batch_tasks.pop(batch_id, None)
|
||||
registration_routes._init_batch_state(batch_id, task_uuids)
|
||||
|
||||
batch_snapshot = registration_routes.batch_tasks[batch_id]
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
|
||||
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
|
||||
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
|
||||
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
|
||||
assert batch_snapshot["finished"] is False
|
||||
assert manager_snapshot["task_uuids"] == task_uuids
|
||||
assert manager_snapshot["total"] == 3
|
||||
assert manager_snapshot["completed"] == 0
|
||||
assert manager_snapshot["success"] == 0
|
||||
assert manager_snapshot["failed"] == 0
|
||||
assert manager_snapshot["finished"] is False
|
||||
assert manager_snapshot["status"] == "running"
|
||||
assert task_manager.get_batch_logs(batch_id) == []
|
||||
|
||||
|
||||
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
@@ -61,7 +60,6 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
error_message = None if status == "completed" else f"{task_uuid}-error"
|
||||
return SimpleNamespace(status=status, error_message=error_message)
|
||||
|
||||
registration_routes.batch_tasks.pop(batch_id, None)
|
||||
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
|
||||
@@ -78,13 +76,11 @@ def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
|
||||
)
|
||||
)
|
||||
|
||||
batch_snapshot = registration_routes.batch_tasks[batch_id]
|
||||
manager_snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert manager_snapshot is not None
|
||||
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
|
||||
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
|
||||
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
|
||||
assert batch_snapshot["finished"] is True
|
||||
assert manager_snapshot["completed"] == 3
|
||||
assert manager_snapshot["success"] == 2
|
||||
assert manager_snapshot["failed"] == 1
|
||||
assert manager_snapshot["finished"] is True
|
||||
assert manager_snapshot["status"] == "completed"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import src.services.base as base_module
|
||||
@@ -412,3 +413,128 @@ def test_registration_task_success_clears_email_service_backoff(monkeypatch):
|
||||
)
|
||||
|
||||
assert service_id not in registration_routes.email_service_circuit_breakers
|
||||
|
||||
|
||||
def test_registration_task_backoff_failures_do_not_get_lost_under_concurrency(monkeypatch):
|
||||
runtime_dir = Path("tests_runtime")
|
||||
runtime_dir.mkdir(exist_ok=True)
|
||||
db_path = runtime_dir / "registration_backoff_concurrency.db"
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=manager.engine)
|
||||
|
||||
task_uuids = ["task-backoff-1", "task-backoff-2"]
|
||||
with manager.session_scope() as session:
|
||||
for task_uuid in task_uuids:
|
||||
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
|
||||
session.add(
|
||||
EmailService(
|
||||
service_type="duck_mail",
|
||||
name="duck-primary",
|
||||
config={
|
||||
"base_url": "https://mail-1.example.test",
|
||||
"default_domain": "mail.example.test",
|
||||
},
|
||||
enabled=True,
|
||||
priority=0,
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db():
|
||||
session = manager.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
class DummySettings:
|
||||
pass
|
||||
|
||||
start_lock = threading.Lock()
|
||||
started = {"count": 0}
|
||||
peer_started = threading.Event()
|
||||
|
||||
class FakeRegistrationEngine:
|
||||
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
|
||||
self.email_service = email_service
|
||||
self.phase_history = []
|
||||
|
||||
def run(self):
|
||||
with start_lock:
|
||||
started["count"] += 1
|
||||
if started["count"] == len(task_uuids):
|
||||
peer_started.set()
|
||||
peer_started.wait(timeout=0.1)
|
||||
|
||||
current_state = self.email_service.provider_backoff_state
|
||||
next_failures = current_state.failures + 1
|
||||
delay_seconds = 30 if next_failures == 1 else 60
|
||||
self.phase_history = [
|
||||
PhaseResult(
|
||||
phase="email_prepare",
|
||||
success=False,
|
||||
error_message="创建邮箱失败",
|
||||
error_code="EMAIL_PROVIDER_RATE_LIMITED",
|
||||
retryable=True,
|
||||
next_action="switch_provider",
|
||||
provider_backoff=EmailProviderBackoffState(
|
||||
failures=next_failures,
|
||||
delay_seconds=delay_seconds,
|
||||
opened_until=1000.0 + delay_seconds,
|
||||
last_error="请求失败: 429",
|
||||
),
|
||||
)
|
||||
]
|
||||
return RegistrationResult(
|
||||
success=False,
|
||||
error_message="创建邮箱失败: 请求失败: 429",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
def save_to_database(self, result):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
|
||||
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
|
||||
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
|
||||
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
|
||||
monkeypatch.setattr(
|
||||
registration_routes.EmailServiceFactory,
|
||||
"create",
|
||||
lambda service_type, config, name=None: BackoffAwareEmailService(
|
||||
service_type=service_type,
|
||||
config=config,
|
||||
name=name,
|
||||
),
|
||||
)
|
||||
registration_routes.email_service_circuit_breakers.clear()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
|
||||
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=registration_routes._run_sync_registration_task,
|
||||
kwargs={
|
||||
"task_uuid": task_uuid,
|
||||
"email_service_type": EmailServiceType.DUCK_MAIL.value,
|
||||
"proxy": None,
|
||||
"email_service_config": None,
|
||||
},
|
||||
)
|
||||
for task_uuid in task_uuids
|
||||
]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
state = registration_routes.email_service_circuit_breakers[service_id]
|
||||
assert state.failures == 2
|
||||
assert state.delay_seconds == 60
|
||||
|
||||
@@ -69,3 +69,36 @@ def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
|
||||
assert phase_result.success is False
|
||||
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY
|
||||
|
||||
|
||||
def test_advance_login_authorization_refreshes_otp_anchor_after_password_submit(monkeypatch):
|
||||
email_service = FakeEmailService(code=None)
|
||||
engine = _build_engine(monkeypatch, email_service)
|
||||
engine.oauth_start = object()
|
||||
engine._otp_sent_at = 10.0
|
||||
|
||||
monkeypatch.setattr(register_module.time, "time", lambda: 456.0)
|
||||
monkeypatch.setattr(engine, "_init_session", lambda: True)
|
||||
monkeypatch.setattr(engine, "_start_oauth", lambda: True)
|
||||
monkeypatch.setattr(engine, "_get_device_id", lambda: True)
|
||||
monkeypatch.setattr(engine, "_try_reenter_login_flow", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_submit_login_password_step_and_get_continue_url",
|
||||
lambda: (True, "https://continue.example.test"),
|
||||
)
|
||||
|
||||
seen_anchors = []
|
||||
|
||||
def fake_get_verification_code():
|
||||
seen_anchors.append(engine._otp_sent_at)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
|
||||
|
||||
workspace_id, callback_url = engine._advance_login_authorization()
|
||||
|
||||
assert workspace_id is None
|
||||
assert callback_url is None
|
||||
assert engine._otp_sent_at == 456.0
|
||||
assert seen_anchors == [456.0]
|
||||
|
||||
Reference in New Issue
Block a user