Merge branch 'cnlimiter:master' into master

This commit is contained in:
pigracing
2026-03-16 16:43:44 +08:00
committed by GitHub
10 changed files with 566 additions and 335 deletions

View File

@@ -351,26 +351,42 @@ async def export_accounts_csv(request: BatchExportRequest):
@router.post("/export/cpa")
async def export_accounts_cpa(request: BatchExportRequest):
"""导出账号为 CPA Token JSON 格式"""
"""导出账号为 CPA Token JSON 格式(每个账号单独一个 JSON 文件,打包为 ZIP"""
import io
import zipfile
from ...core.cpa_upload import generate_token_json
with get_db() as db:
accounts = db.query(Account).filter(Account.id.in_(request.ids)).all()
# 生成 CPA 格式的 Token 数组
export_data = [generate_token_json(acc) for acc in accounts]
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"cpa_tokens_{timestamp}.json"
# 返回 JSON 响应
content = json.dumps(export_data, ensure_ascii=False, indent=2)
if len(accounts) == 1:
# 单个账号直接返回 JSON 文件
acc = accounts[0]
token_data = generate_token_json(acc)
content = json.dumps(token_data, ensure_ascii=False, indent=2)
filename = f"{acc.email}.json"
return StreamingResponse(
iter([content]),
media_type="application/json",
headers={"Content-Disposition": f"attachment; filename={filename}"}
)
# 多个账号打包为 ZIP
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
for acc in accounts:
token_data = generate_token_json(acc)
content = json.dumps(token_data, ensure_ascii=False, indent=2)
zf.writestr(f"{acc.email}.json", content)
zip_buffer.seek(0)
zip_filename = f"cpa_tokens_{timestamp}.zip"
return StreamingResponse(
iter([content]),
media_type="application/json",
headers={"Content-Disposition": f"attachment; filename={filename}"}
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": f"attachment; filename={zip_filename}"}
)

View File

@@ -82,6 +82,8 @@ class BatchRegistrationRequest(BaseModel):
email_service_id: Optional[int] = None # 使用数据库中已配置的邮箱服务 ID
interval_min: int = 5 # 最小间隔秒数
interval_max: int = 30 # 最大间隔秒数
concurrency: int = 1 # 并发线程数 (1-50)
mode: str = "pipeline" # 执行模式: "parallel" 或 "pipeline"
class RegistrationTaskResponse(BaseModel):
@@ -142,6 +144,8 @@ class OutlookBatchRegistrationRequest(BaseModel):
proxy: Optional[str] = None
interval_min: int = 5
interval_max: int = 30
concurrency: int = 1 # 并发线程数 (1-50)
mode: str = "pipeline" # 执行模式: "parallel" 或 "pipeline"
class OutlookBatchRegistrationResponse(BaseModel):
@@ -172,7 +176,7 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse:
)
def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None):
def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = ""):
"""
在线程池中执行的同步注册任务
@@ -310,7 +314,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
email_service = EmailServiceFactory.create(service_type, config)
# 创建注册引擎 - 使用 TaskManager 的日志回调
log_callback = task_manager.create_log_callback(task_uuid)
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
engine = RegistrationEngine(
email_service=email_service,
@@ -373,7 +377,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
pass
async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None):
async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = ""):
"""
异步执行注册任务
@@ -386,10 +390,10 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
# 初始化 TaskManager 状态
task_manager.update_status(task_uuid, "pending")
task_manager.add_log(task_uuid, f"[系统] 任务 {task_uuid[:8]} 已加入队列")
task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列")
try:
# 在线程池中执行同步任务
# 在线程池中执行同步任务(传入 log_prefix 和 batch_id 供回调使用)
await loop.run_in_executor(
task_manager.executor,
_run_sync_registration_task,
@@ -397,7 +401,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
email_service_type,
proxy,
email_service_config,
email_service_id
email_service_id,
log_prefix,
batch_id
)
except Exception as e:
logger.error(f"线程池执行异常: {task_uuid}, 错误: {e}")
@@ -405,6 +411,172 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
task_manager.update_status(task_uuid, "failed", error=str(e))
def _init_batch_state(batch_id: str, task_uuids: List[str]):
"""初始化批量任务内存状态"""
task_manager.init_batch(batch_id, len(task_uuids))
batch_tasks[batch_id] = {
"total": len(task_uuids),
"completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0,
"logs": [],
"finished": False
}
def _make_batch_helpers(batch_id: str):
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
def add_batch_log(msg: str):
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
return add_batch_log, update_batch_status
async def run_batch_parallel(
batch_id: str,
task_uuids: List[str],
email_service_type: str,
proxy: Optional[str],
email_service_config: Optional[dict],
email_service_id: Optional[int],
concurrency: int
):
"""
并行模式所有任务同时提交Semaphore 控制最大并发数
"""
_init_batch_state(batch_id, task_uuids)
add_batch_log, update_batch_status = _make_batch_helpers(batch_id)
semaphore = asyncio.Semaphore(concurrency)
counter_lock = asyncio.Lock()
add_batch_log(f"[系统] 并行模式启动,并发数: {concurrency},总任务: {len(task_uuids)}")
async def _run_one(idx: int, uuid: str):
prefix = f"[任务{idx + 1}]"
async with semaphore:
await run_registration_task(
uuid, email_service_type, proxy, email_service_config, email_service_id,
log_prefix=prefix, batch_id=batch_id
)
with get_db() as db:
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{prefix} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
try:
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
update_batch_status(finished=True, status="completed")
else:
update_batch_status(finished=True, status="cancelled")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_pipeline(
batch_id: str,
task_uuids: List[str],
email_service_type: str,
proxy: Optional[str],
email_service_config: Optional[dict],
email_service_id: Optional[int],
interval_min: int,
interval_max: int,
concurrency: int
):
"""
流水线模式:每隔 interval 秒启动一个新任务Semaphore 限制最大并发数
"""
_init_batch_state(batch_id, task_uuids)
add_batch_log, update_batch_status = _make_batch_helpers(batch_id)
semaphore = asyncio.Semaphore(concurrency)
counter_lock = asyncio.Lock()
running_tasks_list = []
add_batch_log(f"[系统] 流水线模式启动,并发数: {concurrency},总任务: {len(task_uuids)}")
async def _run_and_release(idx: int, uuid: str, pfx: str):
try:
await run_registration_task(
uuid, email_service_type, proxy, email_service_config, email_service_id,
log_prefix=pfx, batch_id=batch_id
)
with get_db() as db:
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{pfx} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
finally:
semaphore.release()
try:
for i, task_uuid in enumerate(task_uuids):
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
with get_db() as db:
for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled")
add_batch_log("[取消] 批量任务已取消")
update_batch_status(finished=True, status="cancelled")
break
update_batch_status(current_index=i)
await semaphore.acquire()
prefix = f"[任务{i + 1}]"
add_batch_log(f"{prefix} 开始注册...")
t = asyncio.create_task(_run_and_release(i, task_uuid, prefix))
running_tasks_list.append(t)
if i < len(task_uuids) - 1 and not task_manager.is_batch_cancelled(batch_id):
wait_time = random.randint(interval_min, interval_max)
logger.info(f"批量任务 {batch_id}: 等待 {wait_time} 秒后启动下一个任务")
await asyncio.sleep(wait_time)
if running_tasks_list:
await asyncio.gather(*running_tasks_list, return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
update_batch_status(finished=True, status="completed")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_registration(
batch_id: str,
task_uuids: List[str],
@@ -413,95 +585,22 @@ async def run_batch_registration(
email_service_config: Optional[dict],
email_service_id: Optional[int],
interval_min: int,
interval_max: int
interval_max: int,
concurrency: int = 1,
mode: str = "pipeline"
):
"""
异步执行批量注册任务
使用线程池执行每个注册任务,避免阻塞主事件循环
"""
# 初始化 TaskManager 批量任务(支持 WebSocket 推送)
task_manager.init_batch(batch_id, len(task_uuids))
batch_tasks[batch_id] = {
"total": len(task_uuids),
"completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0
}
def add_batch_log(msg: str):
batch_tasks[batch_id]["logs"] = batch_tasks[batch_id].get("logs", [])
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
try:
for i, task_uuid in enumerate(task_uuids):
# 检查是否已取消
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
# 取消剩余任务
with get_db() as db:
for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled")
add_batch_log(f"[取消] 批量任务已取消")
update_batch_status(finished=True, status="cancelled")
logger.info(f"批量任务 {batch_id} 已取消")
break
update_batch_status(current_index=i)
# 运行单个注册任务(使用线程池)
await run_registration_task(
task_uuid, email_service_type, proxy, email_service_config, email_service_id
)
# 更新统计
with get_db() as db:
task = crud.get_registration_task(db, task_uuid)
if task:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if task.status == "completed":
new_success += 1
add_batch_log(f"[成功] 第 {new_success} 个账号注册成功")
elif task.status == "failed":
new_failed += 1
add_batch_log(f"[失败] 第 {new_failed} 个账号注册失败: {task.error_message}")
update_batch_status(
completed=new_completed,
success=new_success,
failed=new_failed
)
# 如果不是最后一个任务,等待随机间隔
if i < len(task_uuids) - 1 and not task_manager.is_batch_cancelled(batch_id):
wait_time = random.randint(interval_min, interval_max)
logger.info(f"批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务")
await asyncio.sleep(wait_time)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
update_batch_status(finished=True, status="completed")
logger.info(f"批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
"""根据 mode 分发到并行或流水线执行"""
if mode == "parallel":
await run_batch_parallel(
batch_id, task_uuids, email_service_type, proxy,
email_service_config, email_service_id, concurrency
)
else:
await run_batch_pipeline(
batch_id, task_uuids, email_service_type, proxy,
email_service_config, email_service_id,
interval_min, interval_max, concurrency
)
# ============== API Endpoints ==============
@@ -579,6 +678,12 @@ async def start_batch_registration(
if request.interval_min < 0 or request.interval_max < request.interval_min:
raise HTTPException(status_code=400, detail="间隔时间参数无效")
if not 1 <= request.concurrency <= 50:
raise HTTPException(status_code=400, detail="并发数必须在 1-50 之间")
if request.mode not in ("parallel", "pipeline"):
raise HTTPException(status_code=400, detail="模式必须为 parallel 或 pipeline")
# 创建批量任务
batch_id = str(uuid.uuid4())
task_uuids = []
@@ -607,7 +712,9 @@ async def start_batch_registration(
request.email_service_config,
request.email_service_id,
request.interval_min,
request.interval_max
request.interval_max,
request.concurrency,
request.mode
)
return BatchRegistrationResponse(
@@ -903,168 +1010,52 @@ async def get_outlook_accounts_for_registration():
)
def _run_sync_outlook_batch_registration(
batch_id: str,
service_ids: List[int],
skip_registered: bool,
proxy: Optional[str],
interval_min: int,
interval_max: int
):
"""
在线程池中执行的同步 Outlook 批量注册任务
"""
from ...database.models import EmailService as EmailServiceModel
from ...database.models import Account
# 初始化 TaskManager 批量任务
task_manager.init_batch(batch_id, len(service_ids))
# 兼容旧版 batch_tasks用于 REST API 轮询降级)
batch_tasks[batch_id] = {
"total": len(service_ids),
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"cancelled": False,
"service_ids": service_ids,
"current_index": 0,
"logs": []
}
def add_batch_log(msg: str):
"""同时添加日志到两个系统"""
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
"""同时更新两个系统的状态"""
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
try:
for i, service_id in enumerate(service_ids):
# 检查是否已取消
if task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[取消] 批量任务已取消")
update_batch_status(finished=True, status="cancelled")
logger.info(f"Outlook 批量任务 {batch_id} 已取消")
break
update_batch_status(current_index=i)
with get_db() as db:
# 获取邮箱服务
service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == service_id
).first()
if not service:
add_batch_log(f"[跳过] 服务 ID {service_id} 不存在")
update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1,
completed=batch_tasks[batch_id]["completed"] + 1)
continue
config = service.config or {}
email = config.get("email") or service.name
# 检查是否已注册
if skip_registered:
existing_account = db.query(Account).filter(
Account.email == email
).first()
if existing_account:
add_batch_log(f"[跳过] {email} 已注册 (账号 ID: {existing_account.id})")
update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1,
completed=batch_tasks[batch_id]["completed"] + 1)
continue
# 创建注册任务
task_uuid = str(uuid.uuid4())
task = crud.create_registration_task(
db,
task_uuid=task_uuid,
proxy=proxy,
email_service_id=service_id
)
add_batch_log(f"[注册] 开始注册 {email}...")
# 同步执行注册任务
_run_sync_registration_task(task_uuid, "outlook", proxy, None, service_id)
# 更新统计
with get_db() as db:
task = crud.get_registration_task(db, task_uuid)
if task:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if task.status == "completed":
new_success += 1
add_batch_log(f"[成功] {email} 注册成功")
elif task.status == "failed":
new_failed += 1
add_batch_log(f"[失败] {email} 注册失败: {task.error_message}")
update_batch_status(
completed=new_completed,
success=new_success,
failed=new_failed
)
# 如果不是最后一个任务,等待随机间隔
if i < len(service_ids) - 1 and not task_manager.is_batch_cancelled(batch_id):
wait_time = random.randint(interval_min, interval_max)
logger.info(f"Outlook 批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务")
import time
time.sleep(wait_time)
# 完成批量任务
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}, 跳过: {batch_tasks[batch_id]['skipped']}")
update_batch_status(finished=True, status="completed")
logger.info(f"Outlook 批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}, 跳过 {batch_tasks[batch_id]['skipped']}")
except Exception as e:
logger.error(f"Outlook 批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
async def run_outlook_batch_registration(
batch_id: str,
service_ids: List[int],
skip_registered: bool,
proxy: Optional[str],
interval_min: int,
interval_max: int
interval_max: int,
concurrency: int = 1,
mode: str = "pipeline"
):
"""
异步执行 Outlook 批量注册任务
异步执行 Outlook 批量注册任务,复用通用并发逻辑
使用线程池执行,避免阻塞主事件循环
将每个 service_id 映射为一个独立的 task_uuid然后调用
run_batch_registration 的并发逻辑
"""
loop = task_manager.get_loop()
if loop is None:
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
# 在线程池中执行
await loop.run_in_executor(
task_manager.executor,
_run_sync_outlook_batch_registration,
batch_id,
service_ids,
skip_registered,
proxy,
interval_min,
interval_max
# 预先为每个 service_id 创建注册任务记录
task_uuids = []
with get_db() as db:
for service_id in service_ids:
task_uuid = str(uuid.uuid4())
crud.create_registration_task(
db,
task_uuid=task_uuid,
proxy=proxy,
email_service_id=service_id
)
task_uuids.append(task_uuid)
# 复用通用并发逻辑outlook 服务类型,每个任务通过 email_service_id 定位账户)
await run_batch_registration(
batch_id=batch_id,
task_uuids=task_uuids,
email_service_type="outlook",
proxy=proxy,
email_service_config=None,
email_service_id=None, # 每个任务已绑定了独立的 email_service_id
interval_min=interval_min,
interval_max=interval_max,
concurrency=concurrency,
mode=mode
)
@@ -1092,6 +1083,12 @@ async def start_outlook_batch_registration(
if request.interval_min < 0 or request.interval_max < request.interval_min:
raise HTTPException(status_code=400, detail="间隔时间参数无效")
if not 1 <= request.concurrency <= 50:
raise HTTPException(status_code=400, detail="并发数必须在 1-50 之间")
if request.mode not in ("parallel", "pipeline"):
raise HTTPException(status_code=400, detail="模式必须为 parallel 或 pipeline")
# 过滤掉已注册的邮箱
actual_service_ids = request.service_ids
skipped_count = 0
@@ -1154,7 +1151,9 @@ async def start_outlook_batch_registration(
request.skip_registered,
request.proxy,
request.interval_min,
request.interval_max
request.interval_max,
request.concurrency,
request.mode
)
return OutlookBatchRegistrationResponse(

View File

@@ -13,12 +13,15 @@ from datetime import datetime
logger = logging.getLogger(__name__)
# 全局线程池
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="reg_worker")
# 全局线程池(支持最多 50 个并发注册任务)
_executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
# 全局元锁:保护所有 defaultdict 的首次 key 创建(避免多线程竞态)
_meta_lock = threading.Lock()
# 任务日志队列 (task_uuid -> list of logs)
_log_queues: Dict[str, List[str]] = defaultdict(list)
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
_log_locks: Dict[str, threading.Lock] = {}
# WebSocket 连接管理 (task_uuid -> list of websockets)
_ws_connections: Dict[str, List] = defaultdict(list)
@@ -36,7 +39,25 @@ _task_cancelled: Dict[str, bool] = {}
# 批量任务状态 (batch_id -> dict)
_batch_status: Dict[str, dict] = {}
_batch_logs: Dict[str, List[str]] = defaultdict(list)
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
_batch_locks: Dict[str, threading.Lock] = {}
def _get_log_lock(task_uuid: str) -> threading.Lock:
"""线程安全地获取或创建任务日志锁"""
if task_uuid not in _log_locks:
with _meta_lock:
if task_uuid not in _log_locks:
_log_locks[task_uuid] = threading.Lock()
return _log_locks[task_uuid]
def _get_batch_lock(batch_id: str) -> threading.Lock:
"""线程安全地获取或创建批量任务日志锁"""
if batch_id not in _batch_locks:
with _meta_lock:
if batch_id not in _batch_locks:
_batch_locks[batch_id] = threading.Lock()
return _batch_locks[batch_id]
class TaskManager:
@@ -77,7 +98,7 @@ class TaskManager:
logger.warning(f"推送日志到 WebSocket 失败: {e}")
# 广播后再添加到队列
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
_log_queues[task_uuid].append(log_message)
async def _broadcast_log(self, task_uuid: str, log_message: str):
@@ -132,7 +153,7 @@ class TaskManager:
if websocket not in _ws_connections[task_uuid]:
_ws_connections[task_uuid].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
logger.info(f"WebSocket 连接已注册: {task_uuid}")
else:
@@ -144,7 +165,7 @@ class TaskManager:
ws_id = id(websocket)
sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
all_logs = _log_queues.get(task_uuid, [])
unsent_logs = all_logs[sent_count:]
# 更新已发送索引
@@ -166,7 +187,7 @@ class TaskManager:
def get_logs(self, task_uuid: str) -> List[str]:
"""获取任务的所有日志"""
with _log_locks[task_uuid]:
with _get_log_lock(task_uuid):
return _log_queues.get(task_uuid, []).copy()
def update_status(self, task_uuid: str, status: str, **kwargs):
@@ -217,7 +238,7 @@ class TaskManager:
logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
# 广播后再添加到队列
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
_batch_logs[batch_id].append(log_message)
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
@@ -285,7 +306,7 @@ class TaskManager:
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
return _batch_logs.get(batch_id, []).copy()
def is_batch_cancelled(self, batch_id: str) -> bool:
@@ -310,7 +331,7 @@ class TaskManager:
if websocket not in _ws_connections[key]:
_ws_connections[key].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
else:
@@ -323,7 +344,7 @@ class TaskManager:
ws_id = id(websocket)
sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
with _batch_locks[batch_id]:
with _get_batch_lock(batch_id):
all_logs = _batch_logs.get(batch_id, [])
unsent_logs = all_logs[sent_count:]
# 更新已发送索引
@@ -344,10 +365,14 @@ class TaskManager:
_ws_sent_index[key].pop(id(websocket), None)
logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")
def create_log_callback(self, task_uuid: str) -> Callable[[str], None]:
"""创建日志回调函数"""
def create_log_callback(self, task_uuid: str, prefix: str = "", batch_id: str = "") -> Callable[[str], None]:
"""创建日志回调函数,可附加任务编号前缀,并同时推送到批量任务频道"""
def callback(msg: str):
self.add_log(task_uuid, msg)
full_msg = f"{prefix} {msg}" if prefix else msg
self.add_log(task_uuid, full_msg)
# 如果属于批量任务,同步推送到 batch 频道,前端可在混合日志中看到详细步骤
if batch_id:
self.add_batch_log(batch_id, full_msg)
return callback
def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]: