mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-07-04 05:41:51 +08:00
Merge branch 'cnlimiter:master' into master
This commit is contained in:
@@ -351,26 +351,42 @@ async def export_accounts_csv(request: BatchExportRequest):
|
||||
|
||||
@router.post("/export/cpa")
|
||||
async def export_accounts_cpa(request: BatchExportRequest):
|
||||
"""导出账号为 CPA Token JSON 格式"""
|
||||
"""导出账号为 CPA Token JSON 格式(每个账号单独一个 JSON 文件,打包为 ZIP)"""
|
||||
import io
|
||||
import zipfile
|
||||
from ...core.cpa_upload import generate_token_json
|
||||
|
||||
with get_db() as db:
|
||||
accounts = db.query(Account).filter(Account.id.in_(request.ids)).all()
|
||||
|
||||
# 生成 CPA 格式的 Token 数组
|
||||
export_data = [generate_token_json(acc) for acc in accounts]
|
||||
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"cpa_tokens_{timestamp}.json"
|
||||
|
||||
# 返回 JSON 响应
|
||||
content = json.dumps(export_data, ensure_ascii=False, indent=2)
|
||||
if len(accounts) == 1:
|
||||
# 单个账号直接返回 JSON 文件
|
||||
acc = accounts[0]
|
||||
token_data = generate_token_json(acc)
|
||||
content = json.dumps(token_data, ensure_ascii=False, indent=2)
|
||||
filename = f"{acc.email}.json"
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||
)
|
||||
|
||||
# 多个账号打包为 ZIP
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for acc in accounts:
|
||||
token_data = generate_token_json(acc)
|
||||
content = json.dumps(token_data, ensure_ascii=False, indent=2)
|
||||
zf.writestr(f"{acc.email}.json", content)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
zip_filename = f"cpa_tokens_{timestamp}.zip"
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f"attachment; filename={zip_filename}"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -82,6 +82,8 @@ class BatchRegistrationRequest(BaseModel):
|
||||
email_service_id: Optional[int] = None # 使用数据库中已配置的邮箱服务 ID
|
||||
interval_min: int = 5 # 最小间隔秒数
|
||||
interval_max: int = 30 # 最大间隔秒数
|
||||
concurrency: int = 1 # 并发线程数 (1-50)
|
||||
mode: str = "pipeline" # 执行模式: "parallel" 或 "pipeline"
|
||||
|
||||
|
||||
class RegistrationTaskResponse(BaseModel):
|
||||
@@ -142,6 +144,8 @@ class OutlookBatchRegistrationRequest(BaseModel):
|
||||
proxy: Optional[str] = None
|
||||
interval_min: int = 5
|
||||
interval_max: int = 30
|
||||
concurrency: int = 1 # 并发线程数 (1-50)
|
||||
mode: str = "pipeline" # 执行模式: "parallel" 或 "pipeline"
|
||||
|
||||
|
||||
class OutlookBatchRegistrationResponse(BaseModel):
|
||||
@@ -172,7 +176,7 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse:
|
||||
)
|
||||
|
||||
|
||||
def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None):
|
||||
def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = ""):
|
||||
"""
|
||||
在线程池中执行的同步注册任务
|
||||
|
||||
@@ -310,7 +314,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
email_service = EmailServiceFactory.create(service_type, config)
|
||||
|
||||
# 创建注册引擎 - 使用 TaskManager 的日志回调
|
||||
log_callback = task_manager.create_log_callback(task_uuid)
|
||||
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
|
||||
|
||||
engine = RegistrationEngine(
|
||||
email_service=email_service,
|
||||
@@ -373,7 +377,7 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
pass
|
||||
|
||||
|
||||
async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None):
|
||||
async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = ""):
|
||||
"""
|
||||
异步执行注册任务
|
||||
|
||||
@@ -386,10 +390,10 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
|
||||
# 初始化 TaskManager 状态
|
||||
task_manager.update_status(task_uuid, "pending")
|
||||
task_manager.add_log(task_uuid, f"[系统] 任务 {task_uuid[:8]} 已加入队列")
|
||||
task_manager.add_log(task_uuid, f"{log_prefix} [系统] 任务 {task_uuid[:8]} 已加入队列" if log_prefix else f"[系统] 任务 {task_uuid[:8]} 已加入队列")
|
||||
|
||||
try:
|
||||
# 在线程池中执行同步任务
|
||||
# 在线程池中执行同步任务(传入 log_prefix 和 batch_id 供回调使用)
|
||||
await loop.run_in_executor(
|
||||
task_manager.executor,
|
||||
_run_sync_registration_task,
|
||||
@@ -397,7 +401,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
email_service_type,
|
||||
proxy,
|
||||
email_service_config,
|
||||
email_service_id
|
||||
email_service_id,
|
||||
log_prefix,
|
||||
batch_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"线程池执行异常: {task_uuid}, 错误: {e}")
|
||||
@@ -405,6 +411,172 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
task_manager.update_status(task_uuid, "failed", error=str(e))
|
||||
|
||||
|
||||
def _init_batch_state(batch_id: str, task_uuids: List[str]):
|
||||
"""初始化批量任务内存状态"""
|
||||
task_manager.init_batch(batch_id, len(task_uuids))
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(task_uuids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"cancelled": False,
|
||||
"task_uuids": task_uuids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
|
||||
|
||||
def _make_batch_helpers(batch_id: str):
|
||||
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
|
||||
def add_batch_log(msg: str):
|
||||
batch_tasks[batch_id]["logs"].append(msg)
|
||||
task_manager.add_batch_log(batch_id, msg)
|
||||
|
||||
def update_batch_status(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if key in batch_tasks[batch_id]:
|
||||
batch_tasks[batch_id][key] = value
|
||||
task_manager.update_batch_status(batch_id, **kwargs)
|
||||
|
||||
return add_batch_log, update_batch_status
|
||||
|
||||
|
||||
async def run_batch_parallel(
|
||||
batch_id: str,
|
||||
task_uuids: List[str],
|
||||
email_service_type: str,
|
||||
proxy: Optional[str],
|
||||
email_service_config: Optional[dict],
|
||||
email_service_id: Optional[int],
|
||||
concurrency: int
|
||||
):
|
||||
"""
|
||||
并行模式:所有任务同时提交,Semaphore 控制最大并发数
|
||||
"""
|
||||
_init_batch_state(batch_id, task_uuids)
|
||||
add_batch_log, update_batch_status = _make_batch_helpers(batch_id)
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
counter_lock = asyncio.Lock()
|
||||
add_batch_log(f"[系统] 并行模式启动,并发数: {concurrency},总任务: {len(task_uuids)}")
|
||||
|
||||
async def _run_one(idx: int, uuid: str):
|
||||
prefix = f"[任务{idx + 1}]"
|
||||
async with semaphore:
|
||||
await run_registration_task(
|
||||
uuid, email_service_type, proxy, email_service_config, email_service_id,
|
||||
log_prefix=prefix, batch_id=batch_id
|
||||
)
|
||||
with get_db() as db:
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{prefix} [成功] 注册成功")
|
||||
elif t.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
update_batch_status(finished=True, status="completed")
|
||||
else:
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_pipeline(
|
||||
batch_id: str,
|
||||
task_uuids: List[str],
|
||||
email_service_type: str,
|
||||
proxy: Optional[str],
|
||||
email_service_config: Optional[dict],
|
||||
email_service_id: Optional[int],
|
||||
interval_min: int,
|
||||
interval_max: int,
|
||||
concurrency: int
|
||||
):
|
||||
"""
|
||||
流水线模式:每隔 interval 秒启动一个新任务,Semaphore 限制最大并发数
|
||||
"""
|
||||
_init_batch_state(batch_id, task_uuids)
|
||||
add_batch_log, update_batch_status = _make_batch_helpers(batch_id)
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
counter_lock = asyncio.Lock()
|
||||
running_tasks_list = []
|
||||
add_batch_log(f"[系统] 流水线模式启动,并发数: {concurrency},总任务: {len(task_uuids)}")
|
||||
|
||||
async def _run_and_release(idx: int, uuid: str, pfx: str):
|
||||
try:
|
||||
await run_registration_task(
|
||||
uuid, email_service_type, proxy, email_service_config, email_service_id,
|
||||
log_prefix=pfx, batch_id=batch_id
|
||||
)
|
||||
with get_db() as db:
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{pfx} [成功] 注册成功")
|
||||
elif t.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
finally:
|
||||
semaphore.release()
|
||||
|
||||
try:
|
||||
for i, task_uuid in enumerate(task_uuids):
|
||||
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
|
||||
with get_db() as db:
|
||||
for remaining_uuid in task_uuids[i:]:
|
||||
crud.update_registration_task(db, remaining_uuid, status="cancelled")
|
||||
add_batch_log("[取消] 批量任务已取消")
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
break
|
||||
|
||||
update_batch_status(current_index=i)
|
||||
await semaphore.acquire()
|
||||
prefix = f"[任务{i + 1}]"
|
||||
add_batch_log(f"{prefix} 开始注册...")
|
||||
t = asyncio.create_task(_run_and_release(i, task_uuid, prefix))
|
||||
running_tasks_list.append(t)
|
||||
|
||||
if i < len(task_uuids) - 1 and not task_manager.is_batch_cancelled(batch_id):
|
||||
wait_time = random.randint(interval_min, interval_max)
|
||||
logger.info(f"批量任务 {batch_id}: 等待 {wait_time} 秒后启动下一个任务")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
if running_tasks_list:
|
||||
await asyncio.gather(*running_tasks_list, return_exceptions=True)
|
||||
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
update_batch_status(finished=True, status="completed")
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_registration(
|
||||
batch_id: str,
|
||||
task_uuids: List[str],
|
||||
@@ -413,95 +585,22 @@ async def run_batch_registration(
|
||||
email_service_config: Optional[dict],
|
||||
email_service_id: Optional[int],
|
||||
interval_min: int,
|
||||
interval_max: int
|
||||
interval_max: int,
|
||||
concurrency: int = 1,
|
||||
mode: str = "pipeline"
|
||||
):
|
||||
"""
|
||||
异步执行批量注册任务
|
||||
|
||||
使用线程池执行每个注册任务,避免阻塞主事件循环
|
||||
"""
|
||||
# 初始化 TaskManager 批量任务(支持 WebSocket 推送)
|
||||
task_manager.init_batch(batch_id, len(task_uuids))
|
||||
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(task_uuids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"cancelled": False,
|
||||
"task_uuids": task_uuids,
|
||||
"current_index": 0
|
||||
}
|
||||
|
||||
def add_batch_log(msg: str):
|
||||
batch_tasks[batch_id]["logs"] = batch_tasks[batch_id].get("logs", [])
|
||||
batch_tasks[batch_id]["logs"].append(msg)
|
||||
task_manager.add_batch_log(batch_id, msg)
|
||||
|
||||
def update_batch_status(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if key in batch_tasks[batch_id]:
|
||||
batch_tasks[batch_id][key] = value
|
||||
task_manager.update_batch_status(batch_id, **kwargs)
|
||||
|
||||
try:
|
||||
for i, task_uuid in enumerate(task_uuids):
|
||||
# 检查是否已取消
|
||||
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
|
||||
# 取消剩余任务
|
||||
with get_db() as db:
|
||||
for remaining_uuid in task_uuids[i:]:
|
||||
crud.update_registration_task(db, remaining_uuid, status="cancelled")
|
||||
add_batch_log(f"[取消] 批量任务已取消")
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
logger.info(f"批量任务 {batch_id} 已取消")
|
||||
break
|
||||
|
||||
update_batch_status(current_index=i)
|
||||
|
||||
# 运行单个注册任务(使用线程池)
|
||||
await run_registration_task(
|
||||
task_uuid, email_service_type, proxy, email_service_config, email_service_id
|
||||
)
|
||||
|
||||
# 更新统计
|
||||
with get_db() as db:
|
||||
task = crud.get_registration_task(db, task_uuid)
|
||||
if task:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
|
||||
if task.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"[成功] 第 {new_success} 个账号注册成功")
|
||||
elif task.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"[失败] 第 {new_failed} 个账号注册失败: {task.error_message}")
|
||||
|
||||
update_batch_status(
|
||||
completed=new_completed,
|
||||
success=new_success,
|
||||
failed=new_failed
|
||||
)
|
||||
|
||||
# 如果不是最后一个任务,等待随机间隔
|
||||
if i < len(task_uuids) - 1 and not task_manager.is_batch_cancelled(batch_id):
|
||||
wait_time = random.randint(interval_min, interval_max)
|
||||
logger.info(f"批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
update_batch_status(finished=True, status="completed")
|
||||
logger.info(f"批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
"""根据 mode 分发到并行或流水线执行"""
|
||||
if mode == "parallel":
|
||||
await run_batch_parallel(
|
||||
batch_id, task_uuids, email_service_type, proxy,
|
||||
email_service_config, email_service_id, concurrency
|
||||
)
|
||||
else:
|
||||
await run_batch_pipeline(
|
||||
batch_id, task_uuids, email_service_type, proxy,
|
||||
email_service_config, email_service_id,
|
||||
interval_min, interval_max, concurrency
|
||||
)
|
||||
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
@@ -579,6 +678,12 @@ async def start_batch_registration(
|
||||
if request.interval_min < 0 or request.interval_max < request.interval_min:
|
||||
raise HTTPException(status_code=400, detail="间隔时间参数无效")
|
||||
|
||||
if not 1 <= request.concurrency <= 50:
|
||||
raise HTTPException(status_code=400, detail="并发数必须在 1-50 之间")
|
||||
|
||||
if request.mode not in ("parallel", "pipeline"):
|
||||
raise HTTPException(status_code=400, detail="模式必须为 parallel 或 pipeline")
|
||||
|
||||
# 创建批量任务
|
||||
batch_id = str(uuid.uuid4())
|
||||
task_uuids = []
|
||||
@@ -607,7 +712,9 @@ async def start_batch_registration(
|
||||
request.email_service_config,
|
||||
request.email_service_id,
|
||||
request.interval_min,
|
||||
request.interval_max
|
||||
request.interval_max,
|
||||
request.concurrency,
|
||||
request.mode
|
||||
)
|
||||
|
||||
return BatchRegistrationResponse(
|
||||
@@ -903,168 +1010,52 @@ async def get_outlook_accounts_for_registration():
|
||||
)
|
||||
|
||||
|
||||
def _run_sync_outlook_batch_registration(
|
||||
batch_id: str,
|
||||
service_ids: List[int],
|
||||
skip_registered: bool,
|
||||
proxy: Optional[str],
|
||||
interval_min: int,
|
||||
interval_max: int
|
||||
):
|
||||
"""
|
||||
在线程池中执行的同步 Outlook 批量注册任务
|
||||
"""
|
||||
from ...database.models import EmailService as EmailServiceModel
|
||||
from ...database.models import Account
|
||||
|
||||
# 初始化 TaskManager 批量任务
|
||||
task_manager.init_batch(batch_id, len(service_ids))
|
||||
|
||||
# 兼容旧版 batch_tasks(用于 REST API 轮询降级)
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(service_ids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"cancelled": False,
|
||||
"service_ids": service_ids,
|
||||
"current_index": 0,
|
||||
"logs": []
|
||||
}
|
||||
|
||||
def add_batch_log(msg: str):
|
||||
"""同时添加日志到两个系统"""
|
||||
batch_tasks[batch_id]["logs"].append(msg)
|
||||
task_manager.add_batch_log(batch_id, msg)
|
||||
|
||||
def update_batch_status(**kwargs):
|
||||
"""同时更新两个系统的状态"""
|
||||
for key, value in kwargs.items():
|
||||
if key in batch_tasks[batch_id]:
|
||||
batch_tasks[batch_id][key] = value
|
||||
task_manager.update_batch_status(batch_id, **kwargs)
|
||||
|
||||
try:
|
||||
for i, service_id in enumerate(service_ids):
|
||||
# 检查是否已取消
|
||||
if task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[取消] 批量任务已取消")
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
logger.info(f"Outlook 批量任务 {batch_id} 已取消")
|
||||
break
|
||||
|
||||
update_batch_status(current_index=i)
|
||||
|
||||
with get_db() as db:
|
||||
# 获取邮箱服务
|
||||
service = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.id == service_id
|
||||
).first()
|
||||
|
||||
if not service:
|
||||
add_batch_log(f"[跳过] 服务 ID {service_id} 不存在")
|
||||
update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1,
|
||||
completed=batch_tasks[batch_id]["completed"] + 1)
|
||||
continue
|
||||
|
||||
config = service.config or {}
|
||||
email = config.get("email") or service.name
|
||||
|
||||
# 检查是否已注册
|
||||
if skip_registered:
|
||||
existing_account = db.query(Account).filter(
|
||||
Account.email == email
|
||||
).first()
|
||||
|
||||
if existing_account:
|
||||
add_batch_log(f"[跳过] {email} 已注册 (账号 ID: {existing_account.id})")
|
||||
update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1,
|
||||
completed=batch_tasks[batch_id]["completed"] + 1)
|
||||
continue
|
||||
|
||||
# 创建注册任务
|
||||
task_uuid = str(uuid.uuid4())
|
||||
task = crud.create_registration_task(
|
||||
db,
|
||||
task_uuid=task_uuid,
|
||||
proxy=proxy,
|
||||
email_service_id=service_id
|
||||
)
|
||||
|
||||
add_batch_log(f"[注册] 开始注册 {email}...")
|
||||
|
||||
# 同步执行注册任务
|
||||
_run_sync_registration_task(task_uuid, "outlook", proxy, None, service_id)
|
||||
|
||||
# 更新统计
|
||||
with get_db() as db:
|
||||
task = crud.get_registration_task(db, task_uuid)
|
||||
if task:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
|
||||
if task.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"[成功] {email} 注册成功")
|
||||
elif task.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"[失败] {email} 注册失败: {task.error_message}")
|
||||
|
||||
update_batch_status(
|
||||
completed=new_completed,
|
||||
success=new_success,
|
||||
failed=new_failed
|
||||
)
|
||||
|
||||
# 如果不是最后一个任务,等待随机间隔
|
||||
if i < len(service_ids) - 1 and not task_manager.is_batch_cancelled(batch_id):
|
||||
wait_time = random.randint(interval_min, interval_max)
|
||||
logger.info(f"Outlook 批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务")
|
||||
import time
|
||||
time.sleep(wait_time)
|
||||
|
||||
# 完成批量任务
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}, 跳过: {batch_tasks[batch_id]['skipped']}")
|
||||
update_batch_status(finished=True, status="completed")
|
||||
logger.info(f"Outlook 批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}, 跳过 {batch_tasks[batch_id]['skipped']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Outlook 批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
|
||||
|
||||
async def run_outlook_batch_registration(
|
||||
batch_id: str,
|
||||
service_ids: List[int],
|
||||
skip_registered: bool,
|
||||
proxy: Optional[str],
|
||||
interval_min: int,
|
||||
interval_max: int
|
||||
interval_max: int,
|
||||
concurrency: int = 1,
|
||||
mode: str = "pipeline"
|
||||
):
|
||||
"""
|
||||
异步执行 Outlook 批量注册任务
|
||||
异步执行 Outlook 批量注册任务,复用通用并发逻辑
|
||||
|
||||
使用线程池执行,避免阻塞主事件循环
|
||||
将每个 service_id 映射为一个独立的 task_uuid,然后调用
|
||||
run_batch_registration 的并发逻辑
|
||||
"""
|
||||
loop = task_manager.get_loop()
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager.set_loop(loop)
|
||||
|
||||
# 在线程池中执行
|
||||
await loop.run_in_executor(
|
||||
task_manager.executor,
|
||||
_run_sync_outlook_batch_registration,
|
||||
batch_id,
|
||||
service_ids,
|
||||
skip_registered,
|
||||
proxy,
|
||||
interval_min,
|
||||
interval_max
|
||||
# 预先为每个 service_id 创建注册任务记录
|
||||
task_uuids = []
|
||||
with get_db() as db:
|
||||
for service_id in service_ids:
|
||||
task_uuid = str(uuid.uuid4())
|
||||
crud.create_registration_task(
|
||||
db,
|
||||
task_uuid=task_uuid,
|
||||
proxy=proxy,
|
||||
email_service_id=service_id
|
||||
)
|
||||
task_uuids.append(task_uuid)
|
||||
|
||||
# 复用通用并发逻辑(outlook 服务类型,每个任务通过 email_service_id 定位账户)
|
||||
await run_batch_registration(
|
||||
batch_id=batch_id,
|
||||
task_uuids=task_uuids,
|
||||
email_service_type="outlook",
|
||||
proxy=proxy,
|
||||
email_service_config=None,
|
||||
email_service_id=None, # 每个任务已绑定了独立的 email_service_id
|
||||
interval_min=interval_min,
|
||||
interval_max=interval_max,
|
||||
concurrency=concurrency,
|
||||
mode=mode
|
||||
)
|
||||
|
||||
|
||||
@@ -1092,6 +1083,12 @@ async def start_outlook_batch_registration(
|
||||
if request.interval_min < 0 or request.interval_max < request.interval_min:
|
||||
raise HTTPException(status_code=400, detail="间隔时间参数无效")
|
||||
|
||||
if not 1 <= request.concurrency <= 50:
|
||||
raise HTTPException(status_code=400, detail="并发数必须在 1-50 之间")
|
||||
|
||||
if request.mode not in ("parallel", "pipeline"):
|
||||
raise HTTPException(status_code=400, detail="模式必须为 parallel 或 pipeline")
|
||||
|
||||
# 过滤掉已注册的邮箱
|
||||
actual_service_ids = request.service_ids
|
||||
skipped_count = 0
|
||||
@@ -1154,7 +1151,9 @@ async def start_outlook_batch_registration(
|
||||
request.skip_registered,
|
||||
request.proxy,
|
||||
request.interval_min,
|
||||
request.interval_max
|
||||
request.interval_max,
|
||||
request.concurrency,
|
||||
request.mode
|
||||
)
|
||||
|
||||
return OutlookBatchRegistrationResponse(
|
||||
|
||||
@@ -13,12 +13,15 @@ from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局线程池
|
||||
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="reg_worker")
|
||||
# 全局线程池(支持最多 50 个并发注册任务)
|
||||
_executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
|
||||
|
||||
# 全局元锁:保护所有 defaultdict 的首次 key 创建(避免多线程竞态)
|
||||
_meta_lock = threading.Lock()
|
||||
|
||||
# 任务日志队列 (task_uuid -> list of logs)
|
||||
_log_queues: Dict[str, List[str]] = defaultdict(list)
|
||||
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
_log_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# WebSocket 连接管理 (task_uuid -> list of websockets)
|
||||
_ws_connections: Dict[str, List] = defaultdict(list)
|
||||
@@ -36,7 +39,25 @@ _task_cancelled: Dict[str, bool] = {}
|
||||
# 批量任务状态 (batch_id -> dict)
|
||||
_batch_status: Dict[str, dict] = {}
|
||||
_batch_logs: Dict[str, List[str]] = defaultdict(list)
|
||||
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
_batch_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
|
||||
def _get_log_lock(task_uuid: str) -> threading.Lock:
|
||||
"""线程安全地获取或创建任务日志锁"""
|
||||
if task_uuid not in _log_locks:
|
||||
with _meta_lock:
|
||||
if task_uuid not in _log_locks:
|
||||
_log_locks[task_uuid] = threading.Lock()
|
||||
return _log_locks[task_uuid]
|
||||
|
||||
|
||||
def _get_batch_lock(batch_id: str) -> threading.Lock:
|
||||
"""线程安全地获取或创建批量任务日志锁"""
|
||||
if batch_id not in _batch_locks:
|
||||
with _meta_lock:
|
||||
if batch_id not in _batch_locks:
|
||||
_batch_locks[batch_id] = threading.Lock()
|
||||
return _batch_locks[batch_id]
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -77,7 +98,7 @@ class TaskManager:
|
||||
logger.warning(f"推送日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
_log_queues[task_uuid].append(log_message)
|
||||
|
||||
async def _broadcast_log(self, task_uuid: str, log_message: str):
|
||||
@@ -132,7 +153,7 @@ class TaskManager:
|
||||
if websocket not in _ws_connections[task_uuid]:
|
||||
_ws_connections[task_uuid].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
|
||||
logger.info(f"WebSocket 连接已注册: {task_uuid}")
|
||||
else:
|
||||
@@ -144,7 +165,7 @@ class TaskManager:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
|
||||
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
all_logs = _log_queues.get(task_uuid, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
@@ -166,7 +187,7 @@ class TaskManager:
|
||||
|
||||
def get_logs(self, task_uuid: str) -> List[str]:
|
||||
"""获取任务的所有日志"""
|
||||
with _log_locks[task_uuid]:
|
||||
with _get_log_lock(task_uuid):
|
||||
return _log_queues.get(task_uuid, []).copy()
|
||||
|
||||
def update_status(self, task_uuid: str, status: str, **kwargs):
|
||||
@@ -217,7 +238,7 @@ class TaskManager:
|
||||
logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
_batch_logs[batch_id].append(log_message)
|
||||
|
||||
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
|
||||
@@ -285,7 +306,7 @@ class TaskManager:
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
return _batch_logs.get(batch_id, []).copy()
|
||||
|
||||
def is_batch_cancelled(self, batch_id: str) -> bool:
|
||||
@@ -310,7 +331,7 @@ class TaskManager:
|
||||
if websocket not in _ws_connections[key]:
|
||||
_ws_connections[key].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
|
||||
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
|
||||
else:
|
||||
@@ -323,7 +344,7 @@ class TaskManager:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
|
||||
|
||||
with _batch_locks[batch_id]:
|
||||
with _get_batch_lock(batch_id):
|
||||
all_logs = _batch_logs.get(batch_id, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
@@ -344,10 +365,14 @@ class TaskManager:
|
||||
_ws_sent_index[key].pop(id(websocket), None)
|
||||
logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")
|
||||
|
||||
def create_log_callback(self, task_uuid: str) -> Callable[[str], None]:
|
||||
"""创建日志回调函数"""
|
||||
def create_log_callback(self, task_uuid: str, prefix: str = "", batch_id: str = "") -> Callable[[str], None]:
|
||||
"""创建日志回调函数,可附加任务编号前缀,并同时推送到批量任务频道"""
|
||||
def callback(msg: str):
|
||||
self.add_log(task_uuid, msg)
|
||||
full_msg = f"{prefix} {msg}" if prefix else msg
|
||||
self.add_log(task_uuid, full_msg)
|
||||
# 如果属于批量任务,同步推送到 batch 频道,前端可在混合日志中看到详细步骤
|
||||
if batch_id:
|
||||
self.add_batch_log(batch_id, full_msg)
|
||||
return callback
|
||||
|
||||
def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]:
|
||||
|
||||
Reference in New Issue
Block a user