fix(logic): isolate atomic batch counters and token sync fields

This commit is contained in:
Mison
2026-03-23 11:23:31 +08:00
parent 16154bb5ae
commit cf571d37c1
7 changed files with 294 additions and 104 deletions

View File

@@ -7,7 +7,7 @@ import logging
import uuid
import random
from datetime import datetime
from typing import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel, Field
@@ -599,34 +599,53 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
def _init_batch_state(batch_id: str, task_uuids: List[str]):
"""初始化批量任务内存状态"""
task_manager.init_batch(batch_id, len(task_uuids))
batch_tasks[batch_id] = {
"total": len(task_uuids),
"completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0,
"logs": [],
"finished": False
}
metadata = batch_tasks.get(batch_id, {}).copy()
metadata["task_uuids"] = task_uuids
batch_tasks[batch_id] = metadata
def _make_batch_helpers(batch_id: str):
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
def add_batch_log(msg: str):
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
return add_batch_log, update_batch_status
def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]:
"""聚合批量任务元数据与实时状态。"""
metadata = batch_tasks.get(batch_id)
if metadata is None:
return None
status = task_manager.get_batch_status(batch_id) or {}
initial_skipped = int(metadata.get("initial_skipped", 0) or 0)
total = status.get("total", metadata.get("total", 0))
completed = status.get("completed", 0)
success = status.get("success", 0)
failed = status.get("failed", 0)
skipped = status.get("skipped", 0) + initial_skipped
return {
"batch_id": batch_id,
"total": total,
"completed": completed,
"success": success,
"failed": failed,
"skipped": skipped,
"current_index": status.get("current_index", 0),
"cancelled": status.get("cancelled", False),
"finished": status.get("finished", False),
"status": status.get("status", metadata.get("status", "pending")),
"logs": task_manager.get_batch_logs(batch_id),
"task_uuids": metadata.get("task_uuids", []),
"service_ids": metadata.get("service_ids", []),
}
async def run_batch_parallel(
batch_id: str,
task_uuids: List[str],
@@ -665,21 +684,19 @@ async def run_batch_parallel(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{prefix} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
task_manager.record_batch_task_result(batch_id, t.status)
try:
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
snapshot = task_manager.get_batch_status(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed")
else:
update_batch_status(finished=True, status="cancelled")
@@ -687,8 +704,6 @@ async def run_batch_parallel(
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_pipeline(
@@ -731,22 +746,17 @@ async def run_batch_pipeline(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{pfx} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
task_manager.record_batch_task_result(batch_id, t.status)
finally:
semaphore.release()
try:
for i, task_uuid in enumerate(task_uuids):
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
if task_manager.is_batch_cancelled(batch_id):
with get_db() as db:
for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled")
@@ -770,14 +780,15 @@ async def run_batch_pipeline(
await asyncio.gather(*running_tasks_list, return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
snapshot = task_manager.get_batch_status(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_registration(
@@ -910,6 +921,7 @@ async def start_batch_registration(
# 创建批量任务
batch_id = str(uuid.uuid4())
task_uuids = []
batch_tasks[batch_id] = {"total": request.count}
with get_db() as db:
for _ in range(request.count):
@@ -956,34 +968,33 @@ async def start_batch_registration(
@router.get("/batch/{batch_id}")
async def get_batch_status(batch_id: str):
"""获取批量任务状态"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return {
"batch_id": batch_id,
"total": batch["total"],
"completed": batch["completed"],
"success": batch["success"],
"failed": batch["failed"],
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"progress": f"{batch['completed']}/{batch['total']}"
"total": snapshot["total"],
"completed": snapshot["completed"],
"success": snapshot["success"],
"failed": snapshot["failed"],
"current_index": snapshot["current_index"],
"cancelled": snapshot["cancelled"],
"finished": snapshot["finished"],
"progress": f"{snapshot['completed']}/{snapshot['total']}"
}
@router.post("/batch/{batch_id}/cancel")
async def cancel_batch(batch_id: str):
"""取消批量任务"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
if batch.get("finished"):
if snapshot.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}
@@ -1464,18 +1475,11 @@ async def start_outlook_batch_registration(
# 创建批量任务
batch_id = str(uuid.uuid4())
# 初始化批量任务状态
# 记录额外元数据,由 task_manager 维护实时状态
batch_tasks[batch_id] = {
"total": len(actual_service_ids),
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"cancelled": False,
"initial_skipped": skipped_count,
"service_ids": actual_service_ids,
"current_index": 0,
"logs": [],
"finished": False
}
# 在后台运行批量注册
@@ -1509,37 +1513,35 @@ async def start_outlook_batch_registration(
@router.get("/outlook-batch/{batch_id}")
async def get_outlook_batch_status(batch_id: str):
"""获取 Outlook 批量任务状态"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return {
"batch_id": batch_id,
"total": batch["total"],
"completed": batch["completed"],
"success": batch["success"],
"failed": batch["failed"],
"skipped": batch.get("skipped", 0),
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"logs": batch.get("logs", []),
"progress": f"{batch['completed']}/{batch['total']}"
"total": snapshot["total"],
"completed": snapshot["completed"],
"success": snapshot["success"],
"failed": snapshot["failed"],
"skipped": snapshot["skipped"],
"current_index": snapshot["current_index"],
"cancelled": snapshot["cancelled"],
"finished": snapshot["finished"],
"logs": snapshot["logs"],
"progress": f"{snapshot['completed']}/{snapshot['total']}"
}
@router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
if batch.get("finished"):
if snapshot.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
# 同时更新两个系统的取消状态
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}

View File

@@ -223,16 +223,18 @@ class TaskManager:
def init_batch(self, batch_id: str, total: int):
"""初始化批量任务"""
_batch_status[batch_id] = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False
}
with _get_batch_lock(batch_id):
_batch_status[batch_id] = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False,
"cancelled": False,
}
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
def add_batch_log(self, batch_id: str, log_message: str):
@@ -276,11 +278,11 @@ class TaskManager:
def update_batch_status(self, batch_id: str, **kwargs):
"""更新批量任务状态"""
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
with _get_batch_lock(batch_id):
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
# 异步广播状态更新
if self._loop and self._loop.is_running():
@@ -292,6 +294,35 @@ class TaskManager:
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]:
"""原子记录单个子任务的终态并返回快照。"""
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
if status is None:
logger.warning(f"批量任务 {batch_id} 不存在")
return None
status["completed"] += 1
if task_status == "completed":
status["success"] += 1
elif task_status == "failed":
status["failed"] += 1
elif task_status == "cancelled":
status["skipped"] += 1
snapshot = status.copy()
if self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
return snapshot
async def _broadcast_batch_status(self, batch_id: str):
"""广播批量任务状态"""
with _ws_lock:
@@ -312,7 +343,9 @@ class TaskManager:
def get_batch_status(self, batch_id: str) -> Optional[dict]:
"""获取批量任务状态"""
return _batch_status.get(batch_id)
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
return status.copy() if status else None
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
@@ -321,15 +354,28 @@ class TaskManager:
def is_batch_cancelled(self, batch_id: str) -> bool:
"""检查批量任务是否已取消"""
status = _batch_status.get(batch_id, {})
return status.get("cancelled", False)
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id, {})
return status.get("cancelled", False)
def cancel_batch(self, batch_id: str):
"""取消批量任务"""
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
logger.info(f"批量任务 {batch_id} 已标记为取消")
changed = False
with _get_batch_lock(batch_id):
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
changed = True
logger.info(f"批量任务 {batch_id} 已标记为取消")
if changed and self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def register_batch_websocket(self, batch_id: str, websocket):
"""注册批量任务 WebSocket 连接"""