mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-07 08:02:51 +08:00
fix(logic): isolate atomic batch counters and token sync fields
This commit is contained in:
@@ -10,6 +10,15 @@ from sqlalchemy import and_, or_, desc, asc, func
|
||||
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
|
||||
|
||||
|
||||
TOKEN_FIELD_NAMES = ("access_token", "refresh_token", "id_token", "session_token")
|
||||
|
||||
|
||||
def _default_token_sync_status(token_values: Dict[str, Any]) -> str:
|
||||
"""根据当前持久化的 token 内容推导同步状态。"""
|
||||
has_token = any(bool(token_values.get(field)) for field in TOKEN_FIELD_NAMES)
|
||||
return "pending" if has_token else "not_ready"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 账户 CRUD
|
||||
# ============================================================================
|
||||
@@ -31,9 +40,16 @@ def create_account(
|
||||
expires_at: Optional['datetime'] = None,
|
||||
extra_data: Optional[Dict[str, Any]] = None,
|
||||
status: Optional[str] = None,
|
||||
source: Optional[str] = None
|
||||
source: Optional[str] = None,
|
||||
token_sync_status: Optional[str] = None,
|
||||
) -> Account:
|
||||
"""创建新账户"""
|
||||
token_values = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"id_token": id_token,
|
||||
"session_token": session_token,
|
||||
}
|
||||
db_account = Account(
|
||||
email=email,
|
||||
password=password,
|
||||
@@ -51,7 +67,9 @@ def create_account(
|
||||
extra_data=extra_data or {},
|
||||
status=status or 'active',
|
||||
source=source or 'register',
|
||||
registered_at=datetime.utcnow()
|
||||
registered_at=datetime.utcnow(),
|
||||
token_sync_status=token_sync_status or _default_token_sync_status(token_values),
|
||||
token_sync_updated_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(db_account)
|
||||
db.commit()
|
||||
@@ -108,6 +126,15 @@ def update_account(
|
||||
if not db_account:
|
||||
return None
|
||||
|
||||
touches_token = any(field in kwargs for field in TOKEN_FIELD_NAMES)
|
||||
if touches_token:
|
||||
persisted_token_values = {
|
||||
field: kwargs.get(field, getattr(db_account, field))
|
||||
for field in TOKEN_FIELD_NAMES
|
||||
}
|
||||
kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
|
||||
kwargs["token_sync_updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_account, key) and value is not None:
|
||||
setattr(db_account, key, value)
|
||||
@@ -724,15 +751,31 @@ def delete_tm_service(db: Session, service_id: int) -> bool:
|
||||
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
|
||||
"""更新 EmailService.config 中指定邮箱的 refresh_token"""
|
||||
service = db.query(EmailService).filter(EmailService.id == service_id).first()
|
||||
if not service or not service.config:
|
||||
if not service or not isinstance(service.config, dict):
|
||||
return
|
||||
|
||||
normalized_email = (email or "").strip().lower()
|
||||
if not normalized_email or not isinstance(new_refresh_token, str) or not new_refresh_token:
|
||||
return
|
||||
|
||||
config = dict(service.config)
|
||||
updated = False
|
||||
|
||||
# 单账户格式
|
||||
if config.get("email", "").lower() == email.lower():
|
||||
if str(config.get("email", "")).lower() == normalized_email:
|
||||
config["refresh_token"] = new_refresh_token
|
||||
updated = True
|
||||
|
||||
# 多账户列表格式
|
||||
for acc in config.get("accounts", []):
|
||||
if acc.get("email", "").lower() == email.lower():
|
||||
if not isinstance(acc, dict):
|
||||
continue
|
||||
if str(acc.get("email", "")).lower() == normalized_email:
|
||||
acc["refresh_token"] = new_refresh_token
|
||||
updated = True
|
||||
|
||||
if not updated:
|
||||
return
|
||||
|
||||
service.config = config
|
||||
db.commit()
|
||||
|
||||
@@ -39,6 +39,8 @@ class Account(Base):
|
||||
refresh_token = Column(Text)
|
||||
id_token = Column(Text)
|
||||
session_token = Column(Text) # 会话令牌(优先刷新方式)
|
||||
token_sync_status = Column(String(20), default='not_ready') # 'not_ready', 'pending', 'synced'
|
||||
token_sync_updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
client_id = Column(String(255)) # OAuth Client ID
|
||||
account_id = Column(String(255))
|
||||
workspace_id = Column(String(255))
|
||||
@@ -80,7 +82,9 @@ class Account(Base):
|
||||
'subscription_type': self.subscription_type,
|
||||
'subscription_at': self.subscription_at.isoformat() if self.subscription_at else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
'token_sync_status': self.token_sync_status,
|
||||
'token_sync_updated_at': self.token_sync_updated_at.isoformat() if self.token_sync_updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@@ -227,4 +231,4 @@ class Proxy(Base):
|
||||
if self.username and self.password:
|
||||
auth = f"{self.username}:{self.password}@"
|
||||
|
||||
return f"{scheme}://{auth}{self.host}:{self.port}"
|
||||
return f"{scheme}://{auth}{self.host}:{self.port}"
|
||||
|
||||
@@ -110,6 +110,8 @@ class DatabaseSessionManager:
|
||||
("accounts", "subscription_type", "VARCHAR(20)"),
|
||||
("accounts", "subscription_at", "DATETIME"),
|
||||
("accounts", "cookies", "TEXT"),
|
||||
("accounts", "token_sync_status", "VARCHAR(20) DEFAULT 'not_ready'"),
|
||||
("accounts", "token_sync_updated_at", "DATETIME"),
|
||||
("proxies", "is_default", "BOOLEAN DEFAULT 0"),
|
||||
("cpa_services", "include_proxy_url", "BOOLEAN DEFAULT 0"),
|
||||
]
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import uuid
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from typing import List, Optional, Dict, Tuple, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -599,34 +599,53 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
|
||||
def _init_batch_state(batch_id: str, task_uuids: List[str]):
|
||||
"""初始化批量任务内存状态"""
|
||||
task_manager.init_batch(batch_id, len(task_uuids))
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(task_uuids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"cancelled": False,
|
||||
"task_uuids": task_uuids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
metadata = batch_tasks.get(batch_id, {}).copy()
|
||||
metadata["task_uuids"] = task_uuids
|
||||
batch_tasks[batch_id] = metadata
|
||||
|
||||
|
||||
def _make_batch_helpers(batch_id: str):
|
||||
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
|
||||
def add_batch_log(msg: str):
|
||||
batch_tasks[batch_id]["logs"].append(msg)
|
||||
task_manager.add_batch_log(batch_id, msg)
|
||||
|
||||
def update_batch_status(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if key in batch_tasks[batch_id]:
|
||||
batch_tasks[batch_id][key] = value
|
||||
task_manager.update_batch_status(batch_id, **kwargs)
|
||||
|
||||
return add_batch_log, update_batch_status
|
||||
|
||||
|
||||
def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""聚合批量任务元数据与实时状态。"""
|
||||
metadata = batch_tasks.get(batch_id)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
status = task_manager.get_batch_status(batch_id) or {}
|
||||
initial_skipped = int(metadata.get("initial_skipped", 0) or 0)
|
||||
total = status.get("total", metadata.get("total", 0))
|
||||
completed = status.get("completed", 0)
|
||||
success = status.get("success", 0)
|
||||
failed = status.get("failed", 0)
|
||||
skipped = status.get("skipped", 0) + initial_skipped
|
||||
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": total,
|
||||
"completed": completed,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"skipped": skipped,
|
||||
"current_index": status.get("current_index", 0),
|
||||
"cancelled": status.get("cancelled", False),
|
||||
"finished": status.get("finished", False),
|
||||
"status": status.get("status", metadata.get("status", "pending")),
|
||||
"logs": task_manager.get_batch_logs(batch_id),
|
||||
"task_uuids": metadata.get("task_uuids", []),
|
||||
"service_ids": metadata.get("service_ids", []),
|
||||
}
|
||||
|
||||
|
||||
async def run_batch_parallel(
|
||||
batch_id: str,
|
||||
task_uuids: List[str],
|
||||
@@ -665,21 +684,19 @@ async def run_batch_parallel(
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{prefix} [成功] 注册成功")
|
||||
elif t.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
task_manager.record_batch_task_result(batch_id, t.status)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
snapshot = task_manager.get_batch_status(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
else:
|
||||
update_batch_status(finished=True, status="cancelled")
|
||||
@@ -687,8 +704,6 @@ async def run_batch_parallel(
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_pipeline(
|
||||
@@ -731,22 +746,17 @@ async def run_batch_pipeline(
|
||||
t = crud.get_registration_task(db, uuid)
|
||||
if t:
|
||||
async with counter_lock:
|
||||
new_completed = batch_tasks[batch_id]["completed"] + 1
|
||||
new_success = batch_tasks[batch_id]["success"]
|
||||
new_failed = batch_tasks[batch_id]["failed"]
|
||||
if t.status == "completed":
|
||||
new_success += 1
|
||||
add_batch_log(f"{pfx} [成功] 注册成功")
|
||||
elif t.status == "failed":
|
||||
new_failed += 1
|
||||
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
|
||||
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
|
||||
task_manager.record_batch_task_result(batch_id, t.status)
|
||||
finally:
|
||||
semaphore.release()
|
||||
|
||||
try:
|
||||
for i, task_uuid in enumerate(task_uuids):
|
||||
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
|
||||
if task_manager.is_batch_cancelled(batch_id):
|
||||
with get_db() as db:
|
||||
for remaining_uuid in task_uuids[i:]:
|
||||
crud.update_registration_task(db, remaining_uuid, status="cancelled")
|
||||
@@ -770,14 +780,15 @@ async def run_batch_pipeline(
|
||||
await asyncio.gather(*running_tasks_list, return_exceptions=True)
|
||||
|
||||
if not task_manager.is_batch_cancelled(batch_id):
|
||||
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
|
||||
snapshot = task_manager.get_batch_status(batch_id) or {}
|
||||
add_batch_log(
|
||||
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
|
||||
)
|
||||
update_batch_status(finished=True, status="completed")
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 {batch_id} 异常: {e}")
|
||||
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
|
||||
update_batch_status(finished=True, status="failed")
|
||||
finally:
|
||||
batch_tasks[batch_id]["finished"] = True
|
||||
|
||||
|
||||
async def run_batch_registration(
|
||||
@@ -910,6 +921,7 @@ async def start_batch_registration(
|
||||
# 创建批量任务
|
||||
batch_id = str(uuid.uuid4())
|
||||
task_uuids = []
|
||||
batch_tasks[batch_id] = {"total": request.count}
|
||||
|
||||
with get_db() as db:
|
||||
for _ in range(request.count):
|
||||
@@ -956,34 +968,33 @@ async def start_batch_registration(
|
||||
@router.get("/batch/{batch_id}")
|
||||
async def get_batch_status(batch_id: str):
|
||||
"""获取批量任务状态"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": batch["total"],
|
||||
"completed": batch["completed"],
|
||||
"success": batch["success"],
|
||||
"failed": batch["failed"],
|
||||
"current_index": batch["current_index"],
|
||||
"cancelled": batch["cancelled"],
|
||||
"finished": batch.get("finished", False),
|
||||
"progress": f"{batch['completed']}/{batch['total']}"
|
||||
"total": snapshot["total"],
|
||||
"completed": snapshot["completed"],
|
||||
"success": snapshot["success"],
|
||||
"failed": snapshot["failed"],
|
||||
"current_index": snapshot["current_index"],
|
||||
"cancelled": snapshot["cancelled"],
|
||||
"finished": snapshot["finished"],
|
||||
"progress": f"{snapshot['completed']}/{snapshot['total']}"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/batch/{batch_id}/cancel")
|
||||
async def cancel_batch(batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
if batch.get("finished"):
|
||||
if snapshot.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
batch["cancelled"] = True
|
||||
task_manager.cancel_batch(batch_id)
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
|
||||
@@ -1464,18 +1475,11 @@ async def start_outlook_batch_registration(
|
||||
# 创建批量任务
|
||||
batch_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化批量任务状态
|
||||
# 记录额外元数据,由 task_manager 维护实时状态
|
||||
batch_tasks[batch_id] = {
|
||||
"total": len(actual_service_ids),
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"cancelled": False,
|
||||
"initial_skipped": skipped_count,
|
||||
"service_ids": actual_service_ids,
|
||||
"current_index": 0,
|
||||
"logs": [],
|
||||
"finished": False
|
||||
}
|
||||
|
||||
# 在后台运行批量注册
|
||||
@@ -1509,37 +1513,35 @@ async def start_outlook_batch_registration(
|
||||
@router.get("/outlook-batch/{batch_id}")
|
||||
async def get_outlook_batch_status(batch_id: str):
|
||||
"""获取 Outlook 批量任务状态"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
return {
|
||||
"batch_id": batch_id,
|
||||
"total": batch["total"],
|
||||
"completed": batch["completed"],
|
||||
"success": batch["success"],
|
||||
"failed": batch["failed"],
|
||||
"skipped": batch.get("skipped", 0),
|
||||
"current_index": batch["current_index"],
|
||||
"cancelled": batch["cancelled"],
|
||||
"finished": batch.get("finished", False),
|
||||
"logs": batch.get("logs", []),
|
||||
"progress": f"{batch['completed']}/{batch['total']}"
|
||||
"total": snapshot["total"],
|
||||
"completed": snapshot["completed"],
|
||||
"success": snapshot["success"],
|
||||
"failed": snapshot["failed"],
|
||||
"skipped": snapshot["skipped"],
|
||||
"current_index": snapshot["current_index"],
|
||||
"cancelled": snapshot["cancelled"],
|
||||
"finished": snapshot["finished"],
|
||||
"logs": snapshot["logs"],
|
||||
"progress": f"{snapshot['completed']}/{snapshot['total']}"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/outlook-batch/{batch_id}/cancel")
|
||||
async def cancel_outlook_batch(batch_id: str):
|
||||
"""取消 Outlook 批量任务"""
|
||||
if batch_id not in batch_tasks:
|
||||
snapshot = _get_batch_snapshot(batch_id)
|
||||
if snapshot is None:
|
||||
raise HTTPException(status_code=404, detail="批量任务不存在")
|
||||
|
||||
batch = batch_tasks[batch_id]
|
||||
if batch.get("finished"):
|
||||
if snapshot.get("finished"):
|
||||
raise HTTPException(status_code=400, detail="批量任务已完成")
|
||||
|
||||
# 同时更新两个系统的取消状态
|
||||
batch["cancelled"] = True
|
||||
task_manager.cancel_batch(batch_id)
|
||||
|
||||
return {"success": True, "message": "批量任务取消请求已提交"}
|
||||
|
||||
@@ -223,16 +223,18 @@ class TaskManager:
|
||||
|
||||
def init_batch(self, batch_id: str, total: int):
|
||||
"""初始化批量任务"""
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False
|
||||
}
|
||||
with _get_batch_lock(batch_id):
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
|
||||
|
||||
def add_batch_log(self, batch_id: str, log_message: str):
|
||||
@@ -276,11 +278,11 @@ class TaskManager:
|
||||
|
||||
def update_batch_status(self, batch_id: str, **kwargs):
|
||||
"""更新批量任务状态"""
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
|
||||
# 异步广播状态更新
|
||||
if self._loop and self._loop.is_running():
|
||||
@@ -292,6 +294,35 @@ class TaskManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]:
|
||||
"""原子记录单个子任务的终态并返回快照。"""
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id)
|
||||
if status is None:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return None
|
||||
|
||||
status["completed"] += 1
|
||||
if task_status == "completed":
|
||||
status["success"] += 1
|
||||
elif task_status == "failed":
|
||||
status["failed"] += 1
|
||||
elif task_status == "cancelled":
|
||||
status["skipped"] += 1
|
||||
|
||||
snapshot = status.copy()
|
||||
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_status(batch_id),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
return snapshot
|
||||
|
||||
async def _broadcast_batch_status(self, batch_id: str):
|
||||
"""广播批量任务状态"""
|
||||
with _ws_lock:
|
||||
@@ -312,7 +343,9 @@ class TaskManager:
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> Optional[dict]:
|
||||
"""获取批量任务状态"""
|
||||
return _batch_status.get(batch_id)
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id)
|
||||
return status.copy() if status else None
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
@@ -321,15 +354,28 @@ class TaskManager:
|
||||
|
||||
def is_batch_cancelled(self, batch_id: str) -> bool:
|
||||
"""检查批量任务是否已取消"""
|
||||
status = _batch_status.get(batch_id, {})
|
||||
return status.get("cancelled", False)
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id, {})
|
||||
return status.get("cancelled", False)
|
||||
|
||||
def cancel_batch(self, batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
changed = False
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
changed = True
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
|
||||
if changed and self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_status(batch_id),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
def register_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注册批量任务 WebSocket 连接"""
|
||||
|
||||
72
tests/test_account_token_sync_status.py
Normal file
72
tests/test_account_token_sync_status.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from src.database import crud
|
||||
from src.database.session import DatabaseSessionManager
|
||||
|
||||
|
||||
def test_create_account_marks_token_sync_pending_when_tokens_persist(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
account = crud.create_account(
|
||||
session,
|
||||
email="sync@example.com",
|
||||
email_service="tempmail",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
)
|
||||
|
||||
assert account.token_sync_status == "pending"
|
||||
assert account.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_account_marks_token_sync_pending_when_tokens_change(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
account = crud.create_account(
|
||||
session,
|
||||
email="nosync@example.com",
|
||||
email_service="tempmail",
|
||||
)
|
||||
|
||||
assert account.token_sync_status == "not_ready"
|
||||
|
||||
updated = crud.update_account(
|
||||
session,
|
||||
account.id,
|
||||
access_token="new-access-token",
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.token_sync_status == "pending"
|
||||
assert updated.token_sync_updated_at is not None
|
||||
|
||||
|
||||
def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_path):
|
||||
manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
|
||||
manager.create_tables()
|
||||
manager.migrate_tables()
|
||||
|
||||
with manager.session_scope() as session:
|
||||
account = crud.create_account(
|
||||
session,
|
||||
email="partial-sync@example.com",
|
||||
email_service="tempmail",
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
)
|
||||
|
||||
updated = crud.update_account(
|
||||
session,
|
||||
account.id,
|
||||
refresh_token="",
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.access_token == "access-token"
|
||||
assert updated.refresh_token == ""
|
||||
assert updated.token_sync_status == "pending"
|
||||
assert updated.token_sync_updated_at is not None
|
||||
21
tests/test_batch_task_manager.py
Normal file
21
tests/test_batch_task_manager.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from src.web.task_manager import task_manager
|
||||
|
||||
|
||||
def test_record_batch_task_result_is_atomic_under_threads():
|
||||
batch_id = "batch-atomic-test"
|
||||
task_manager.init_batch(batch_id, 100)
|
||||
|
||||
statuses = ["completed"] * 60 + ["failed"] * 40
|
||||
|
||||
with ThreadPoolExecutor(max_workers=16) as executor:
|
||||
list(executor.map(lambda status: task_manager.record_batch_task_result(batch_id, status), statuses))
|
||||
|
||||
snapshot = task_manager.get_batch_status(batch_id)
|
||||
|
||||
assert snapshot is not None
|
||||
assert snapshot["completed"] == 100
|
||||
assert snapshot["success"] == 60
|
||||
assert snapshot["failed"] == 40
|
||||
assert snapshot["skipped"] == 0
|
||||
Reference in New Issue
Block a user