mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-06 20:02:51 +08:00
fix(task_manager): 修复多线程竞态条件,添加全局元锁保护字典创建
This commit is contained in:
@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
|
||||
# 全局线程池(支持最多 50 个并发注册任务)
|
||||
_executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
|
||||
|
||||
# 全局元锁:保护所有 defaultdict 的首次 key 创建(避免多线程竞态)
|
||||
_meta_lock = threading.Lock()
|
||||
|
||||
# 任务日志队列 (task_uuid -> list of logs)
|
||||
_log_queues: Dict[str, List[str]] = defaultdict(list)
|
||||
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
_log_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# WebSocket 连接管理 (task_uuid -> list of websockets)
|
||||
_ws_connections: Dict[str, List] = defaultdict(list)
|
||||
@@ -36,7 +39,25 @@ _task_cancelled: Dict[str, bool] = {}
|
||||
# 批量任务状态 (batch_id -> dict)
|
||||
_batch_status: Dict[str, dict] = {}
|
||||
_batch_logs: Dict[str, List[str]] = defaultdict(list)
|
||||
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
_batch_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
|
||||
def _get_log_lock(task_uuid: str) -> threading.Lock:
|
||||
"""线程安全地获取或创建任务日志锁"""
|
||||
if task_uuid not in _log_locks:
|
||||
with _meta_lock:
|
||||
if task_uuid not in _log_locks:
|
||||
_log_locks[task_uuid] = threading.Lock()
|
||||
return _log_locks[task_uuid]
|
||||
|
||||
|
||||
def _get_batch_lock(batch_id: str) -> threading.Lock:
|
||||
"""线程安全地获取或创建批量任务日志锁"""
|
||||
if batch_id not in _batch_locks:
|
||||
with _meta_lock:
|
||||
if batch_id not in _batch_locks:
|
||||
_batch_locks[batch_id] = threading.Lock()
|
||||
return _batch_locks[batch_id]
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -77,7 +98,7 @@ class TaskManager:
|
||||
logger.warning(f"推送日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
_log_queues[task_uuid].append(log_message)
|
||||
|
||||
async def _broadcast_log(self, task_uuid: str, log_message: str):
|
||||
@@ -132,7 +153,7 @@ class TaskManager:
|
||||
if websocket not in _ws_connections[task_uuid]:
|
||||
_ws_connections[task_uuid].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
|
||||
logger.info(f"WebSocket 连接已注册: {task_uuid}")
|
||||
else:
|
||||
@@ -144,7 +165,7 @@ class TaskManager:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
|
||||
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
all_logs = _log_queues.get(task_uuid, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
@@ -166,7 +187,7 @@ class TaskManager:
|
||||
|
||||
def get_logs(self, task_uuid: str) -> List[str]:
|
||||
"""获取任务的所有日志"""
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
return _log_queues.get(task_uuid, []).copy()
|
||||
|
||||
def update_status(self, task_uuid: str, status: str, **kwargs):
|
||||
@@ -217,7 +238,7 @@ class TaskManager:
|
||||
logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
_batch_logs[batch_id].append(log_message)
|
||||
|
||||
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
|
||||
@@ -285,7 +306,7 @@ class TaskManager:
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
return _batch_logs.get(batch_id, []).copy()
|
||||
|
||||
def is_batch_cancelled(self, batch_id: str) -> bool:
|
||||
@@ -310,7 +331,7 @@ class TaskManager:
|
||||
if websocket not in _ws_connections[key]:
|
||||
_ws_connections[key].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
|
||||
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
|
||||
else:
|
||||
@@ -323,7 +344,7 @@ class TaskManager:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
|
||||
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
all_logs = _batch_logs.get(batch_id, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
|
||||
Reference in New Issue
Block a user