fix(task_manager): 修复多线程竞态条件,添加全局元锁保护字典创建

This commit is contained in:
cnlimiter
2026-03-16 10:45:25 +08:00
parent bb75fe08dd
commit 07f0a2cca0

View File

@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
# 全局线程池(支持最多 50 个并发注册任务)
_executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
# 全局元锁:保护所有 defaultdict 的首次 key 创建(避免多线程竞态)
_meta_lock = threading.Lock()
# 任务日志队列 (task_uuid -> list of logs)
_log_queues: Dict[str, List[str]] = defaultdict(list)
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
_log_locks: Dict[str, threading.Lock] = {}
# WebSocket 连接管理 (task_uuid -> list of websockets)
_ws_connections: Dict[str, List] = defaultdict(list)
@@ -36,7 +39,25 @@ _task_cancelled: Dict[str, bool] = {}
# 批量任务状态 (batch_id -> dict)
_batch_status: Dict[str, dict] = {}
_batch_logs: Dict[str, List[str]] = defaultdict(list)
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
_batch_locks: Dict[str, threading.Lock] = {}
def _get_log_lock(task_uuid: str) -> threading.Lock:
    """Return the per-task log lock, creating it on first access.

    Uses double-checked locking: the cheap unlocked probe serves the
    common already-created case, while `_meta_lock` serializes first-time
    creation so two threads cannot each install a different Lock.
    """
    lock = _log_locks.get(task_uuid)
    if lock is None:
        with _meta_lock:
            # Re-check inside the meta lock: another thread may have
            # created the entry between our probe and acquisition.
            lock = _log_locks.get(task_uuid)
            if lock is None:
                lock = threading.Lock()
                _log_locks[task_uuid] = lock
    return lock
def _get_batch_lock(batch_id: str) -> threading.Lock:
    """Return the per-batch log lock, creating it on first access.

    Mirrors `_get_log_lock`: a fast unlocked probe, then creation
    guarded by `_meta_lock` so concurrent callers agree on one Lock
    object per batch_id.
    """
    lock = _batch_locks.get(batch_id)
    if lock is None:
        with _meta_lock:
            # Re-check under the meta lock to avoid a lost-update race.
            lock = _batch_locks.get(batch_id)
            if lock is None:
                lock = threading.Lock()
                _batch_locks[batch_id] = lock
    return lock
class TaskManager:
@@ -77,7 +98,7 @@ class TaskManager:
logger.warning(f"推送日志到 WebSocket 失败: {e}")
# 广播后再添加到队列
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
_log_queues[task_uuid].append(log_message)
async def _broadcast_log(self, task_uuid: str, log_message: str):
@@ -132,7 +153,7 @@ class TaskManager:
if websocket not in _ws_connections[task_uuid]:
_ws_connections[task_uuid].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
logger.info(f"WebSocket 连接已注册: {task_uuid}")
else:
@@ -144,7 +165,7 @@ class TaskManager:
ws_id = id(websocket)
sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
all_logs = _log_queues.get(task_uuid, [])
unsent_logs = all_logs[sent_count:]
# 更新已发送索引
@@ -166,7 +187,7 @@ class TaskManager:
def get_logs(self, task_uuid: str) -> List[str]:
"""获取任务的所有日志"""
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
return _log_queues.get(task_uuid, []).copy()
def update_status(self, task_uuid: str, status: str, **kwargs):
@@ -217,7 +238,7 @@ class TaskManager:
logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
# 广播后再添加到队列
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
_batch_logs[batch_id].append(log_message)
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
@@ -285,7 +306,7 @@ class TaskManager:
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
return _batch_logs.get(batch_id, []).copy()
def is_batch_cancelled(self, batch_id: str) -> bool:
@@ -310,7 +331,7 @@ class TaskManager:
if websocket not in _ws_connections[key]:
_ws_connections[key].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
else:
@@ -323,7 +344,7 @@ class TaskManager:
ws_id = id(websocket)
sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
all_logs = _batch_logs.get(batch_id, [])
unsent_logs = all_logs[sent_count:]
# 更新已发送索引