mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-07 06:22:56 +08:00
144 lines
4.9 KiB
Python
144 lines
4.9 KiB
Python
from contextlib import contextmanager
|
|
import asyncio
|
|
|
|
from fastapi import WebSocketDisconnect
|
|
|
|
from src.database import crud
|
|
from src.database.models import Base, RegistrationTask
|
|
from src.database.session import DatabaseSessionManager
|
|
from src.web.routes import websocket as websocket_routes
|
|
from src.web.task_manager import TaskManager
|
|
|
|
|
|
def test_fail_incomplete_registration_tasks_marks_pending_and_running_failed(tmp_path):
    """Unfinished tasks are failed on recovery; completed tasks stay untouched.

    Seeds one pending, one running and one completed task in a throwaway
    SQLite database under ``tmp_path``, calls
    ``crud.fail_incomplete_registration_tasks`` and then checks both the
    returned task UUIDs and the persisted state of every row.
    """
    # Single source of truth for the failure reason, so the call and the
    # assertions below cannot drift apart.
    reason = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"

    db_path = tmp_path / "recovery.db"
    manager = DatabaseSessionManager(f"sqlite:///{db_path}")
    Base.metadata.create_all(bind=manager.engine)

    with manager.session_scope() as session:
        session.add_all([
            RegistrationTask(task_uuid="task-pending", status="pending"),
            RegistrationTask(task_uuid="task-running", status="running", logs="[01:00:00] still running"),
            RegistrationTask(task_uuid="task-done", status="completed"),
        ])

    with manager.session_scope() as session:
        cleaned = crud.fail_incomplete_registration_tasks(session, reason)

    assert cleaned == ["task-pending", "task-running"]

    with manager.session_scope() as session:
        pending_task = crud.get_registration_task_by_uuid(session, "task-pending")
        running_task = crud.get_registration_task_by_uuid(session, "task-running")
        done_task = crud.get_registration_task_by_uuid(session, "task-done")

        # Assert while the session is still open so attribute access never
        # touches an instance that has been detached from its session.
        assert pending_task.status == "failed"
        assert running_task.status == "failed"
        assert pending_task.error_message == reason
        assert running_task.completed_at is not None
        assert f"[系统] {reason}" in running_task.logs
        assert done_task.status == "completed"
|
|
|
|
|
|
def test_restore_task_snapshot_loads_status_and_logs_from_database(monkeypatch, tmp_path):
    """``_restore_task_snapshot`` returns the stored status and split log lines.

    A failed task with two newline-separated log entries, a result payload
    and an error message is written to a temporary SQLite database. The
    websocket route module's ``get_db`` is replaced so the lookup reads from
    that database.
    """
    db_path = tmp_path / "websocket.db"
    manager = DatabaseSessionManager(f"sqlite:///{db_path}")
    Base.metadata.create_all(bind=manager.engine)

    with manager.session_scope() as session:
        session.add(
            RegistrationTask(
                task_uuid="task-websocket",
                status="failed",
                logs="[01:00:00] step 1\n[01:00:01] step 2",
                result={"email": "tester@example.com"},
                error_message="boom",
            )
        )

    @contextmanager
    def fake_get_db():
        """Yield a session on the temporary database and always close it."""
        session = manager.SessionLocal()
        try:
            yield session
        finally:
            session.close()

    # Point the route module at the temporary database for this test only.
    monkeypatch.setattr(websocket_routes, "get_db", fake_get_db)

    status, logs = websocket_routes._restore_task_snapshot("task-websocket")

    assert status == {
        "status": "failed",
        "email": "tester@example.com",
        "error": "boom",
    }
    assert logs == ["[01:00:00] step 1", "[01:00:01] step 2"]
|
|
|
|
|
|
def test_sync_task_state_prefers_longer_persisted_log_history():
    """A later sync with fewer log lines must not shrink the stored history.

    The second ``sync_task_state`` call passes a shorter log list and no
    status; both the previously synced status and the longer log list are
    expected to survive.
    """
    manager = TaskManager()
    task_uuid = "task-sync"

    manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["a", "b"])
    manager.sync_task_state(task_uuid, logs=["a"])

    assert manager.get_status(task_uuid) == {"status": "running"}
    assert manager.get_logs(task_uuid) == ["a", "b"]
|
|
|
|
|
|
def test_register_websocket_returns_snapshot_and_keeps_live_cursor():
    """Registering a socket returns existing logs and only new ones count as unsent.

    After registration the history is handed back in full and nothing is
    pending for the socket; a log added afterwards is the only unsent entry.
    """
    manager = TaskManager()
    task_uuid = "task-live"
    # A bare object is enough here: the calls below only pass it to the
    # manager as the connection handle.
    websocket = object()

    manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["log-1", "log-2"])

    history_logs = manager.register_websocket(task_uuid, websocket)

    assert history_logs == ["log-1", "log-2"]
    assert manager.get_unsent_logs(task_uuid, websocket) == []

    manager.add_log(task_uuid, "log-3")

    assert manager.get_unsent_logs(task_uuid, websocket) == ["log-3"]
|
|
|
|
|
|
class _FakeWebSocket:
    """Minimal in-memory stand-in for a websocket connection.

    Records whether ``accept`` was awaited and every payload passed to
    ``send_json``; ``receive_json`` always raises ``WebSocketDisconnect``.
    """

    def __init__(self):
        # Payloads passed to send_json, in call order.
        self.messages = []
        # Flipped to True once accept() has been awaited.
        self.accepted = False

    async def accept(self):
        """Mark the connection as accepted."""
        self.accepted = True

    async def send_json(self, payload):
        """Record ``payload`` instead of sending it anywhere."""
        self.messages.append(payload)

    async def receive_json(self):
        """Simulate a client that disconnects on the first receive.

        Raises:
            WebSocketDisconnect: always.
        """
        raise WebSocketDisconnect()
|
|
|
|
|
|
def test_batch_websocket_replays_history_logs_from_registration_snapshot(monkeypatch):
    """The batch websocket handler sends a status message, then the logged history.

    Two batch log lines are recorded before the socket connects. The handler
    is expected to accept the socket, send a ``status`` message first and
    then one message per historical log line, in order.
    """
    manager = TaskManager()
    batch_id = "batch-history"
    websocket = _FakeWebSocket()

    manager.init_batch(batch_id, total=2)
    manager.add_batch_log(batch_id, "[01:00:00] first")
    manager.add_batch_log(batch_id, "[01:00:01] second")

    # Make the route module use this test's manager instead of its own.
    monkeypatch.setattr(websocket_routes, "task_manager", manager)

    # The fake socket raises WebSocketDisconnect on receive; presumably the
    # handler treats that as a normal close and returns - confirm in the route.
    asyncio.run(websocket_routes.batch_websocket(websocket, batch_id))

    assert websocket.accepted is True
    assert websocket.messages[0]["type"] == "status"
    assert [msg["message"] for msg in websocket.messages[1:]] == [
        "[01:00:00] first",
        "[01:00:01] second",
    ]
|