# File: codex-register/tests/test_task_recovery.py
#
# Viewer metadata: 144 lines, 4.9 KiB, Python

from contextlib import contextmanager
import asyncio
from fastapi import WebSocketDisconnect
from src.database import crud
from src.database.models import Base, RegistrationTask
from src.database.session import DatabaseSessionManager
from src.web.routes import websocket as websocket_routes
from src.web.task_manager import TaskManager
def test_fail_incomplete_registration_tasks_marks_pending_and_running_failed(tmp_path):
    """Unfinished tasks are marked failed on recovery; completed ones are untouched.

    Seeds a fresh SQLite database with one pending, one running and one
    completed ``RegistrationTask``, runs
    ``crud.fail_incomplete_registration_tasks`` and then checks that:

    * the returned list names exactly the pending and running tasks;
    * both of those tasks now have status ``"failed"``;
    * the failure message is stored as ``error_message`` and appended to the
      existing logs with a ``[系统]`` prefix;
    * ``completed_at`` is populated for the previously running task;
    * the already-completed task keeps its status.
    """
    # Single source of truth for the message passed in and asserted on below.
    failure_message = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"

    db_path = tmp_path / "recovery.db"
    manager = DatabaseSessionManager(f"sqlite:///{db_path}")
    Base.metadata.create_all(bind=manager.engine)

    # Seed one task per state of interest.
    with manager.session_scope() as session:
        session.add_all([
            RegistrationTask(task_uuid="task-pending", status="pending"),
            RegistrationTask(
                task_uuid="task-running",
                status="running",
                logs="[01:00:00] still running",
            ),
            RegistrationTask(task_uuid="task-done", status="completed"),
        ])

    # Run the recovery routine in its own session scope.
    with manager.session_scope() as session:
        cleaned = crud.fail_incomplete_registration_tasks(session, failure_message)

    assert cleaned == ["task-pending", "task-running"]

    # Re-read every task in a new session. The attribute checks stay inside
    # the scope so the ORM objects are still attached when they are read.
    with manager.session_scope() as session:
        pending_task = crud.get_registration_task_by_uuid(session, "task-pending")
        running_task = crud.get_registration_task_by_uuid(session, "task-running")
        done_task = crud.get_registration_task_by_uuid(session, "task-done")

        assert pending_task.status == "failed"
        assert running_task.status == "failed"
        assert pending_task.error_message == failure_message
        assert running_task.completed_at is not None
        assert f"[系统] {failure_message}" in running_task.logs
        assert done_task.status == "completed"
def test_restore_task_snapshot_loads_status_and_logs_from_database(monkeypatch, tmp_path):
    """``_restore_task_snapshot`` rebuilds status and log lines from the stored task.

    A failed ``RegistrationTask`` with two newline-separated log lines, a
    result containing an email, and an error message is written to a
    temporary SQLite database. ``websocket_routes.get_db`` is replaced so the
    route module reads from that database. The restored status dict must
    carry the status, email and error, and the logs must come back as one
    list entry per stored line.
    """
    db_path = tmp_path / "websocket.db"
    manager = DatabaseSessionManager(f"sqlite:///{db_path}")
    Base.metadata.create_all(bind=manager.engine)

    with manager.session_scope() as session:
        session.add(
            RegistrationTask(
                task_uuid="task-websocket",
                status="failed",
                logs="[01:00:00] step 1\n[01:00:01] step 2",
                result={"email": "tester@example.com"},
                error_message="boom",
            )
        )

    @contextmanager
    def fake_get_db():
        """Yield a session bound to the temporary database and always close it."""
        session = manager.SessionLocal()
        try:
            yield session
        finally:
            session.close()

    # Point the websocket routes at the temporary database.
    monkeypatch.setattr(websocket_routes, "get_db", fake_get_db)

    status, logs = websocket_routes._restore_task_snapshot("task-websocket")

    assert status == {
        "status": "failed",
        "email": "tester@example.com",
        "error": "boom",
    }
    assert logs == ["[01:00:00] step 1", "[01:00:01] step 2"]
def test_sync_task_state_prefers_longer_persisted_log_history():
    """A later sync with fewer log lines must not shrink the stored history.

    After syncing a status plus two log lines, a second sync that supplies
    only one log line (and no status) must leave both the status and the
    two-line log history unchanged.
    """
    manager = TaskManager()
    task_uuid = "task-sync"

    manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["a", "b"])
    # Shorter log list, no status: neither stored value may be replaced.
    manager.sync_task_state(task_uuid, logs=["a"])

    assert manager.get_status(task_uuid) == {"status": "running"}
    assert manager.get_logs(task_uuid) == ["a", "b"]
def test_register_websocket_returns_snapshot_and_keeps_live_cursor():
    """Registering a websocket returns existing logs and tracks what is still unsent.

    ``register_websocket`` must hand back the logs already stored for the
    task. Immediately afterwards nothing is reported as unsent for that
    websocket; once a new log line is added, only that new line is unsent.
    """
    manager = TaskManager()
    task_uuid = "task-live"
    # Any object works as the websocket here: it is only passed to the manager.
    websocket = object()

    manager.sync_task_state(
        task_uuid,
        status={"status": "running"},
        logs=["log-1", "log-2"],
    )

    history_logs = manager.register_websocket(task_uuid, websocket)

    assert history_logs == ["log-1", "log-2"]
    assert manager.get_unsent_logs(task_uuid, websocket) == []

    manager.add_log(task_uuid, "log-3")

    assert manager.get_unsent_logs(task_uuid, websocket) == ["log-3"]
class _FakeWebSocket:
    """Minimal in-memory websocket stand-in for the route tests.

    Records whether ``accept`` was called and every payload passed to
    ``send_json``. ``receive_json`` raises ``WebSocketDisconnect`` on every
    call.
    """

    def __init__(self):
        self.messages = []      # payloads passed to send_json, in call order
        self.accepted = False   # set to True by accept()

    async def accept(self):
        """Record that the connection was accepted."""
        self.accepted = True

    async def send_json(self, payload):
        """Store ``payload`` instead of sending it anywhere."""
        self.messages.append(payload)

    async def receive_json(self):
        """Always raise ``WebSocketDisconnect``; no message is ever returned."""
        raise WebSocketDisconnect()
def test_batch_websocket_replays_history_logs_from_registration_snapshot(monkeypatch):
    """The batch websocket handler sends a status message, then the stored logs.

    Two log lines are added to a batch before any websocket connects. After
    ``batch_websocket`` runs to completion against a ``_FakeWebSocket``, the
    connection must have been accepted, the first message must have type
    ``"status"``, and the remaining messages must carry the two log lines in
    their original order.
    """
    manager = TaskManager()
    batch_id = "batch-history"
    websocket = _FakeWebSocket()

    manager.init_batch(batch_id, total=2)
    manager.add_batch_log(batch_id, "[01:00:00] first")
    manager.add_batch_log(batch_id, "[01:00:01] second")

    # Make the route module use this test's manager instead of its own.
    monkeypatch.setattr(websocket_routes, "task_manager", manager)

    asyncio.run(websocket_routes.batch_websocket(websocket, batch_id))

    assert websocket.accepted is True
    assert websocket.messages[0]["type"] == "status"
    assert [msg["message"] for msg in websocket.messages[1:]] == [
        "[01:00:00] first",
        "[01:00:01] second",
    ]