mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-06-29 19:21:34 +08:00
fix(logic): isolate atomic batch counters and token sync fields
This commit is contained in:
@@ -7,7 +7,7 @@ import logging
|
||||
import uuid
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from typing import List, Optional, Dict, Tuple, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -599,34 +599,53 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
def _init_batch_state(batch_id: str, task_uuids: List[str]):
|
||||
"""初始化批量任务内存状态"""
|
||||
task_manager.init_batch(batch_id, len(task_uuids))
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(task_uuids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"cancelled": False,
|
||||
"task_uuids": task_uuids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
metadata = batch_tasks.get(batch_id, {}).copy()
|
||||
metadata["task_uuids"] = task_uuids
|
||||
batch_tasks[batch_id] = metadata
|
||||
|
||||
|
||||
def _make_batch_helpers(batch_id: str):
|
||||
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
|
||||
def add_batch_log(msg: str):
|
||||
batch_tasks[batch_id]["logs"].append(msg)
|
||||
task_manager.add_batch_log(batch_id, msg)
|
||||
|
||||
def update_batch_status(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if key in batch_tasks[batch_id]:
|
||||
batch_tasks[batch_id][key] = value
|
||||
task_manager.update_batch_status(batch_id, **kwargs)
|
||||
|
||||
return add_batch_log, update_batch_status
|
||||
|
||||
|
||||
def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""聚合批量任务元数据与实时状态。"""
|
||||
metadata = batch_tasks.get(batch_id)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
status = task_manager.get_batch_status(batch_id) or {}
|
||||
initial_skipped = int(metadata.get("initial_skipped", 0) or 0)
|
||||
total = status.get("total", metadata.get("total", 0))
|
||||
completed = status.get("completed", 0)
|
||||
success = status.get("success", 0)
|
||||
failed = status.get("failed", 0)
|
||||
skipped = status.get("skipped", 0) + initial_skipped
|
||||
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": total,
|
||||
"completed": completed,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"skipped": skipped,
|
||||
"current_index": status.get("current_index", 0),
|
||||
"cancelled": status.get("cancelled", False),
|
||||
"finished": status.get("finished", False),
|
||||
"status": status.get("status", metadata.get("status", "pending")),
|
||||
"logs": task_manager.get_batch_logs(batch_id),
|
||||
"task_uuids": metadata.get("task_uuids", []),
|
||||
"service_ids": metadata.get("service_ids", []),
|
||||
}
|
||||
|
||||
|
||||
async def run_batch_parallel(
|
||||
batch_id: str,
|
||||
task_uuids: List[str],
|
||||
@@ -665,21 +684,19 @@ async def run_batch_parallel(
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{prefix} [成功] 注册成功")
|
||||
elif t.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
task_manager.record_batch_task_result(batch_id, t.status)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
snapshot = task_manager.get_batch_status(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
else:
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
@@ -687,8 +704,6 @@ async def run_batch_parallel(
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_pipeline(
|
||||
@@ -731,22 +746,17 @@ async def run_batch_pipeline(
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{pfx} [成功] 注册成功")
|
||||
elif t.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
task_manager.record_batch_task_result(batch_id, t.status)
|
||||
finally:
|
||||
semaphore.release()
|
||||
|
||||
try:
|
||||
for i, task_uuid in enumerate(task_uuids):
|
||||
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
|
||||
if task_manager.is_batch_cancelled(batch_id):
|
||||
with get_db() as db:
|
||||
for remaining_uuid in task_uuids[i:]:
|
||||
crud.update_registration_task(db, remaining_uuid, status="cancelled")
|
||||
@@ -770,14 +780,15 @@ async def run_batch_pipeline(
|
||||
await asyncio.gather(*running_tasks_list, return_exceptions=True)
|
||||
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
snapshot = task_manager.get_batch_status(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_registration(
|
||||
@@ -910,6 +921,7 @@ async def start_batch_registration(
|
||||
# 创建批量任务
|
||||
batch_id = str(uuid.uuid4())
|
||||
task_uuids = []
|
||||
batch_tasks[batch_id] = {"total": request.count}
|
||||
|
||||
with get_db() as db:
|
||||
for _ in range(request.count):
|
||||
@@ -956,34 +968,33 @@ async def start_batch_registration(
|
||||
@router.get("/batch/{batch_id}")
|
||||
async def get_batch_status(batch_id: str):
|
||||
"""获取批量任务状态"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": batch["total"],
|
||||
"completed": batch["completed"],
|
||||
"success": batch["success"],
|
||||
"failed": batch["failed"],
|
||||
"current_index": batch["current_index"],
|
||||
"cancelled": batch["cancelled"],
|
||||
"finished": batch.get("finished", False),
|
||||
"progress": f"{batch['completed']}/{batch['total']}"
|
||||
"total": snapshot["total"],
|
||||
"completed": snapshot["completed"],
|
||||
"success": snapshot["success"],
|
||||
"failed": snapshot["failed"],
|
||||
"current_index": snapshot["current_index"],
|
||||
"cancelled": snapshot["cancelled"],
|
||||
"finished": snapshot["finished"],
|
||||
"progress": f"{snapshot['completed']}/{snapshot['total']}"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/batch/{batch_id}/cancel")
|
||||
async def cancel_batch(batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
if batch.get("finished"):
|
||||
if snapshot.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
batch["cancelled"] = True
|
||||
task_manager.cancel_batch(batch_id)
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
|
||||
@@ -1464,18 +1475,11 @@ async def start_outlook_batch_registration(
|
||||
# 创建批量任务
|
||||
batch_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化批量任务状态
|
||||
# 记录额外元数据,由 task_manager 维护实时状态
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(actual_service_ids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"cancelled": False,
|
||||
"initial_skipped": skipped_count,
|
||||
"service_ids": actual_service_ids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
|
||||
# 在后台运行批量注册
|
||||
@@ -1509,37 +1513,35 @@ async def start_outlook_batch_registration(
|
||||
@router.get("/outlook-batch/{batch_id}")
|
||||
async def get_outlook_batch_status(batch_id: str):
|
||||
"""获取 Outlook 批量任务状态"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": batch["total"],
|
||||
"completed": batch["completed"],
|
||||
"success": batch["success"],
|
||||
"failed": batch["failed"],
|
||||
"skipped": batch.get("skipped", 0),
|
||||
"current_index": batch["current_index"],
|
||||
"cancelled": batch["cancelled"],
|
||||
"finished": batch.get("finished", False),
|
||||
"logs": batch.get("logs", []),
|
||||
"progress": f"{batch['completed']}/{batch['total']}"
|
||||
"total": snapshot["total"],
|
||||
"completed": snapshot["completed"],
|
||||
"success": snapshot["success"],
|
||||
"failed": snapshot["failed"],
|
||||
"skipped": snapshot["skipped"],
|
||||
"current_index": snapshot["current_index"],
|
||||
"cancelled": snapshot["cancelled"],
|
||||
"finished": snapshot["finished"],
|
||||
"logs": snapshot["logs"],
|
||||
"progress": f"{snapshot['completed']}/{snapshot['total']}"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/outlook-batch/{batch_id}/cancel")
|
||||
async def cancel_outlook_batch(batch_id: str):
|
||||
"""取消 Outlook 批量任务"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
if batch.get("finished"):
|
||||
if snapshot.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
# 同时更新两个系统的取消状态
|
||||
batch["cancelled"] = True
|
||||
task_manager.cancel_batch(batch_id)
|
||||
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
|
||||
@@ -223,16 +223,18 @@ class TaskManager:
|
||||
|
||||
def init_batch(self, batch_id: str, total: int):
|
||||
"""初始化批量任务"""
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False
|
||||
}
|
||||
with _get_batch_lock(batch_id):
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
|
||||
|
||||
def add_batch_log(self, batch_id: str, log_message: str):
|
||||
@@ -276,11 +278,11 @@ class TaskManager:
|
||||
|
||||
def update_batch_status(self, batch_id: str, **kwargs):
|
||||
"""更新批量任务状态"""
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
|
||||
# 异步广播状态更新
|
||||
if self._loop and self._loop.is_running():
|
||||
@@ -292,6 +294,35 @@ class TaskManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]:
|
||||
"""原子记录单个子任务的终态并返回快照。"""
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id)
|
||||
if status is None:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return None
|
||||
|
||||
status["completed"] += 1
|
||||
if task_status == "completed":
|
||||
status["success"] += 1
|
||||
elif task_status == "failed":
|
||||
status["failed"] += 1
|
||||
elif task_status == "cancelled":
|
||||
status["skipped"] += 1
|
||||
|
||||
snapshot = status.copy()
|
||||
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_status(batch_id),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
return snapshot
|
||||
|
||||
async def _broadcast_batch_status(self, batch_id: str):
|
||||
"""广播批量任务状态"""
|
||||
with _ws_lock:
|
||||
@@ -312,7 +343,9 @@ class TaskManager:
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> Optional[dict]:
|
||||
"""获取批量任务状态"""
|
||||
return _batch_status.get(batch_id)
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id)
|
||||
return status.copy() if status else None
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
@@ -321,15 +354,28 @@ class TaskManager:
|
||||
|
||||
def is_batch_cancelled(self, batch_id: str) -> bool:
|
||||
"""检查批量任务是否已取消"""
|
||||
status = _batch_status.get(batch_id, {})
|
||||
return status.get("cancelled", False)
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id, {})
|
||||
return status.get("cancelled", False)
|
||||
|
||||
def cancel_batch(self, batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
changed = False
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
changed = True
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
|
||||
if changed and self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_status(batch_id),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
def register_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注册批量任务 WebSocket 连接"""
|
||||
|
||||
Reference in New Issue
Block a user