diff --git a/src/web/task_manager.py b/src/web/task_manager.py
index b9e31b9..31c620b 100644
--- a/src/web/task_manager.py
+++ b/src/web/task_manager.py
@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
 # 全局线程池(支持最多 50 个并发注册任务)
 _executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
 
+# Global meta-lock: serializes first-time creation of the per-key locks in _log_locks / _batch_locks
+_meta_lock = threading.Lock()
+
 # 任务日志队列 (task_uuid -> list of logs)
 _log_queues: Dict[str, List[str]] = defaultdict(list)
-_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
+_log_locks: Dict[str, threading.Lock] = {}
 
 # WebSocket 连接管理 (task_uuid -> list of websockets)
 _ws_connections: Dict[str, List] = defaultdict(list)
@@ -36,7 +39,25 @@ _task_cancelled: Dict[str, bool] = {}
 # 批量任务状态 (batch_id -> dict)
 _batch_status: Dict[str, dict] = {}
 _batch_logs: Dict[str, List[str]] = defaultdict(list)
-_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
+_batch_locks: Dict[str, threading.Lock] = {}
+
+
+def _get_log_lock(task_uuid: str) -> threading.Lock:
+    """Return the log lock for ``task_uuid``, creating it on first use."""
+    lock = _log_locks.get(task_uuid)  # fast path: no meta-lock once the entry exists
+    if lock is None:
+        with _meta_lock:  # setdefault keeps the lock a racing thread may have just stored
+            lock = _log_locks.setdefault(task_uuid, threading.Lock())
+    return lock  # local reference: no second lookup that could raise KeyError
+
+
+def _get_batch_lock(batch_id: str) -> threading.Lock:
+    """Return the log lock for batch ``batch_id``, creating it on first use."""
+    lock = _batch_locks.get(batch_id)  # fast path: no meta-lock once the entry exists
+    if lock is None:
+        with _meta_lock:  # setdefault keeps the lock a racing thread may have just stored
+            lock = _batch_locks.setdefault(batch_id, threading.Lock())
+    return lock  # local reference: no second lookup that could raise KeyError
 
 
 class TaskManager:
@@ -77,7 +98,7 @@ class TaskManager:
                 logger.warning(f"推送日志到 WebSocket 失败: {e}")
 
         # 广播后再添加到队列
-        with _log_locks[task_uuid]:
+        with _get_log_lock(task_uuid):
             _log_queues[task_uuid].append(log_message)
 
     async def _broadcast_log(self, task_uuid: str, log_message: str):
@@ -132,7 +153,7 @@ class TaskManager:
         if websocket not in _ws_connections[task_uuid]:
             _ws_connections[task_uuid].append(websocket)
             # 记录已发送的日志数量,用于发送历史日志时避免重复
-            with _log_locks[task_uuid]:
+            with _get_log_lock(task_uuid):
                 _ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
             logger.info(f"WebSocket 连接已注册: {task_uuid}")
         else:
@@ -144,7 +165,7 @@ class TaskManager:
         ws_id = id(websocket)
         sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
 
-        with _log_locks[task_uuid]:
+        with _get_log_lock(task_uuid):
             all_logs = _log_queues.get(task_uuid, [])
             unsent_logs = all_logs[sent_count:]
             # 更新已发送索引
@@ -166,7 +187,7 @@ class TaskManager:
 
     def get_logs(self, task_uuid: str) -> List[str]:
         """获取任务的所有日志"""
-        with _log_locks[task_uuid]:
+        with _get_log_lock(task_uuid):
             return _log_queues.get(task_uuid, []).copy()
 
     def update_status(self, task_uuid: str, status: str, **kwargs):
@@ -217,7 +238,7 @@ class TaskManager:
                 logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
 
         # 广播后再添加到队列
-        with _batch_locks[batch_id]:
+        with _get_batch_lock(batch_id):
             _batch_logs[batch_id].append(log_message)
 
     async def _broadcast_batch_log(self, batch_id: str, log_message: str):
@@ -285,7 +306,7 @@ class TaskManager:
 
     def get_batch_logs(self, batch_id: str) -> List[str]:
         """获取批量任务日志"""
-        with _batch_locks[batch_id]:
+        with _get_batch_lock(batch_id):
             return _batch_logs.get(batch_id, []).copy()
 
     def is_batch_cancelled(self, batch_id: str) -> bool:
@@ -310,7 +331,7 @@ class TaskManager:
         if websocket not in _ws_connections[key]:
             _ws_connections[key].append(websocket)
             # 记录已发送的日志数量,用于发送历史日志时避免重复
-            with _batch_locks[batch_id]:
+            with _get_batch_lock(batch_id):
                 _ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
             logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
         else:
@@ -323,7 +344,7 @@ class TaskManager:
         ws_id = id(websocket)
         sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
 
-        with _batch_locks[batch_id]:
+        with _get_batch_lock(batch_id):
             all_logs = _batch_logs.get(batch_id, [])
             unsent_logs = all_logs[sent_count:]
             # 更新已发送索引