fix(logic): isolate atomic batch counters and token sync fields

This commit is contained in:
Mison
2026-03-23 11:23:31 +08:00
parent 16154bb5ae
commit cf571d37c1
7 changed files with 294 additions and 104 deletions

View File

@@ -10,6 +10,15 @@ from sqlalchemy import and_, or_, desc, asc, func
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
TOKEN_FIELD_NAMES = ("access_token", "refresh_token", "id_token", "session_token")
def _default_token_sync_status(token_values: Dict[str, Any]) -> str:
"""根据当前持久化的 token 内容推导同步状态。"""
has_token = any(bool(token_values.get(field)) for field in TOKEN_FIELD_NAMES)
return "pending" if has_token else "not_ready"
# ============================================================================
# 账户 CRUD
# ============================================================================
@@ -31,9 +40,16 @@ def create_account(
expires_at: Optional['datetime'] = None,
extra_data: Optional[Dict[str, Any]] = None,
status: Optional[str] = None,
source: Optional[str] = None
source: Optional[str] = None,
token_sync_status: Optional[str] = None,
) -> Account:
"""创建新账户"""
token_values = {
"access_token": access_token,
"refresh_token": refresh_token,
"id_token": id_token,
"session_token": session_token,
}
db_account = Account(
email=email,
password=password,
@@ -51,7 +67,9 @@ def create_account(
extra_data=extra_data or {},
status=status or 'active',
source=source or 'register',
registered_at=datetime.utcnow()
registered_at=datetime.utcnow(),
token_sync_status=token_sync_status or _default_token_sync_status(token_values),
token_sync_updated_at=datetime.utcnow(),
)
db.add(db_account)
db.commit()
@@ -108,6 +126,15 @@ def update_account(
if not db_account:
return None
touches_token = any(field in kwargs for field in TOKEN_FIELD_NAMES)
if touches_token:
persisted_token_values = {
field: kwargs.get(field, getattr(db_account, field))
for field in TOKEN_FIELD_NAMES
}
kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
kwargs["token_sync_updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(db_account, key) and value is not None:
setattr(db_account, key, value)
@@ -724,15 +751,31 @@ def delete_tm_service(db: Session, service_id: int) -> bool:
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
"""更新 EmailService.config 中指定邮箱的 refresh_token"""
service = db.query(EmailService).filter(EmailService.id == service_id).first()
if not service or not service.config:
if not service or not isinstance(service.config, dict):
return
normalized_email = (email or "").strip().lower()
if not normalized_email or not isinstance(new_refresh_token, str) or not new_refresh_token:
return
config = dict(service.config)
updated = False
# 单账户格式
if config.get("email", "").lower() == email.lower():
if str(config.get("email", "")).lower() == normalized_email:
config["refresh_token"] = new_refresh_token
updated = True
# 多账户列表格式
for acc in config.get("accounts", []):
if acc.get("email", "").lower() == email.lower():
if not isinstance(acc, dict):
continue
if str(acc.get("email", "")).lower() == normalized_email:
acc["refresh_token"] = new_refresh_token
updated = True
if not updated:
return
service.config = config
db.commit()

View File

@@ -39,6 +39,8 @@ class Account(Base):
refresh_token = Column(Text)
id_token = Column(Text)
session_token = Column(Text) # 会话令牌(优先刷新方式)
token_sync_status = Column(String(20), default='not_ready') # 'not_ready', 'pending', 'synced'
token_sync_updated_at = Column(DateTime, default=datetime.utcnow)
client_id = Column(String(255)) # OAuth Client ID
account_id = Column(String(255))
workspace_id = Column(String(255))
@@ -80,7 +82,9 @@ class Account(Base):
'subscription_type': self.subscription_type,
'subscription_at': self.subscription_at.isoformat() if self.subscription_at else None,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
'token_sync_status': self.token_sync_status,
'token_sync_updated_at': self.token_sync_updated_at.isoformat() if self.token_sync_updated_at else None,
}
@@ -227,4 +231,4 @@ class Proxy(Base):
if self.username and self.password:
auth = f"{self.username}:{self.password}@"
return f"{scheme}://{auth}{self.host}:{self.port}"
return f"{scheme}://{auth}{self.host}:{self.port}"

View File

@@ -110,6 +110,8 @@ class DatabaseSessionManager:
("accounts", "subscription_type", "VARCHAR(20)"),
("accounts", "subscription_at", "DATETIME"),
("accounts", "cookies", "TEXT"),
("accounts", "token_sync_status", "VARCHAR(20) DEFAULT 'not_ready'"),
("accounts", "token_sync_updated_at", "DATETIME"),
("proxies", "is_default", "BOOLEAN DEFAULT 0"),
("cpa_services", "include_proxy_url", "BOOLEAN DEFAULT 0"),
]

View File

@@ -7,7 +7,7 @@ import logging
import uuid
import random
from datetime import datetime
from typing import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel, Field
@@ -599,34 +599,53 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
def _init_batch_state(batch_id: str, task_uuids: List[str]):
"""初始化批量任务内存状态"""
task_manager.init_batch(batch_id, len(task_uuids))
batch_tasks[batch_id] = {
"total": len(task_uuids),
"completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0,
"logs": [],
"finished": False
}
metadata = batch_tasks.get(batch_id, {}).copy()
metadata["task_uuids"] = task_uuids
batch_tasks[batch_id] = metadata
def _make_batch_helpers(batch_id: str):
    """Return (add_batch_log, update_batch_status) helpers bound to *batch_id*.

    Both helpers mirror writes into the local ``batch_tasks`` metadata dict
    (when the corresponding keys exist) and delegate to ``task_manager``,
    which owns the authoritative live state.
    """
    def add_batch_log(msg: str):
        # NOTE(review): after this commit batch_tasks entries are metadata-only
        # (see _init_batch_state) and may lack a "logs" list; create it lazily
        # instead of assuming it exists.
        batch_tasks.setdefault(batch_id, {}).setdefault("logs", []).append(msg)
        task_manager.add_batch_log(batch_id, msg)

    def update_batch_status(**kwargs):
        state = batch_tasks.get(batch_id)
        if state is not None:
            # Only mirror keys the metadata dict already tracks.
            for key, value in kwargs.items():
                if key in state:
                    state[key] = value
        task_manager.update_batch_status(batch_id, **kwargs)

    return add_batch_log, update_batch_status
def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]:
    """Aggregate stored batch metadata with the live status from task_manager.

    Returns None when the batch id is unknown to the local metadata map;
    otherwise a merged dict where live counters take precedence over the
    metadata defaults, and "skipped" additionally includes any tasks skipped
    before the batch started ("initial_skipped").
    """
    metadata = batch_tasks.get(batch_id)
    if metadata is None:
        return None
    live = task_manager.get_batch_status(batch_id) or {}
    # Tasks skipped up-front are recorded only in metadata, never in live state.
    extra_skipped = int(metadata.get("initial_skipped", 0) or 0)
    return {
        "batch_id": batch_id,
        "total": live.get("total", metadata.get("total", 0)),
        "completed": live.get("completed", 0),
        "success": live.get("success", 0),
        "failed": live.get("failed", 0),
        "skipped": live.get("skipped", 0) + extra_skipped,
        "current_index": live.get("current_index", 0),
        "cancelled": live.get("cancelled", False),
        "finished": live.get("finished", False),
        "status": live.get("status", metadata.get("status", "pending")),
        "logs": task_manager.get_batch_logs(batch_id),
        "task_uuids": metadata.get("task_uuids", []),
        "service_ids": metadata.get("service_ids", []),
    }
async def run_batch_parallel(
batch_id: str,
task_uuids: List[str],
@@ -665,21 +684,19 @@ async def run_batch_parallel(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{prefix} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
task_manager.record_batch_task_result(batch_id, t.status)
try:
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
snapshot = task_manager.get_batch_status(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed")
else:
update_batch_status(finished=True, status="cancelled")
@@ -687,8 +704,6 @@ async def run_batch_parallel(
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_pipeline(
@@ -731,22 +746,17 @@ async def run_batch_pipeline(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{pfx} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
task_manager.record_batch_task_result(batch_id, t.status)
finally:
semaphore.release()
try:
for i, task_uuid in enumerate(task_uuids):
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
if task_manager.is_batch_cancelled(batch_id):
with get_db() as db:
for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled")
@@ -770,14 +780,15 @@ async def run_batch_pipeline(
await asyncio.gather(*running_tasks_list, return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
snapshot = task_manager.get_batch_status(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_registration(
@@ -910,6 +921,7 @@ async def start_batch_registration(
# 创建批量任务
batch_id = str(uuid.uuid4())
task_uuids = []
batch_tasks[batch_id] = {"total": request.count}
with get_db() as db:
for _ in range(request.count):
@@ -956,34 +968,33 @@ async def start_batch_registration(
@router.get("/batch/{batch_id}")
async def get_batch_status(batch_id: str):
"""获取批量任务状态"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return {
"batch_id": batch_id,
"total": batch["total"],
"completed": batch["completed"],
"success": batch["success"],
"failed": batch["failed"],
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"progress": f"{batch['completed']}/{batch['total']}"
"total": snapshot["total"],
"completed": snapshot["completed"],
"success": snapshot["success"],
"failed": snapshot["failed"],
"current_index": snapshot["current_index"],
"cancelled": snapshot["cancelled"],
"finished": snapshot["finished"],
"progress": f"{snapshot['completed']}/{snapshot['total']}"
}
@router.post("/batch/{batch_id}/cancel")
async def cancel_batch(batch_id: str):
"""取消批量任务"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
if batch.get("finished"):
if snapshot.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}
@@ -1464,18 +1475,11 @@ async def start_outlook_batch_registration(
# 创建批量任务
batch_id = str(uuid.uuid4())
# 初始化批量任务状态
# 记录额外元数据,由 task_manager 维护实时状态
batch_tasks[batch_id] = {
"total": len(actual_service_ids),
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"cancelled": False,
"initial_skipped": skipped_count,
"service_ids": actual_service_ids,
"current_index": 0,
"logs": [],
"finished": False
}
# 在后台运行批量注册
@@ -1509,37 +1513,35 @@ async def start_outlook_batch_registration(
@router.get("/outlook-batch/{batch_id}")
async def get_outlook_batch_status(batch_id: str):
"""获取 Outlook 批量任务状态"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return {
"batch_id": batch_id,
"total": batch["total"],
"completed": batch["completed"],
"success": batch["success"],
"failed": batch["failed"],
"skipped": batch.get("skipped", 0),
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"logs": batch.get("logs", []),
"progress": f"{batch['completed']}/{batch['total']}"
"total": snapshot["total"],
"completed": snapshot["completed"],
"success": snapshot["success"],
"failed": snapshot["failed"],
"skipped": snapshot["skipped"],
"current_index": snapshot["current_index"],
"cancelled": snapshot["cancelled"],
"finished": snapshot["finished"],
"logs": snapshot["logs"],
"progress": f"{snapshot['completed']}/{snapshot['total']}"
}
@router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务"""
if batch_id not in batch_tasks:
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
if batch.get("finished"):
if snapshot.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
# 同时更新两个系统的取消状态
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}

View File

@@ -223,16 +223,18 @@ class TaskManager:
def init_batch(self, batch_id: str, total: int):
"""初始化批量任务"""
_batch_status[batch_id] = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False
}
with _get_batch_lock(batch_id):
_batch_status[batch_id] = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False,
"cancelled": False,
}
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
def add_batch_log(self, batch_id: str, log_message: str):
@@ -276,11 +278,11 @@ class TaskManager:
def update_batch_status(self, batch_id: str, **kwargs):
"""更新批量任务状态"""
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
with _get_batch_lock(batch_id):
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
# 异步广播状态更新
if self._loop and self._loop.is_running():
@@ -292,6 +294,35 @@ class TaskManager:
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]:
"""原子记录单个子任务的终态并返回快照。"""
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
if status is None:
logger.warning(f"批量任务 {batch_id} 不存在")
return None
status["completed"] += 1
if task_status == "completed":
status["success"] += 1
elif task_status == "failed":
status["failed"] += 1
elif task_status == "cancelled":
status["skipped"] += 1
snapshot = status.copy()
if self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
return snapshot
async def _broadcast_batch_status(self, batch_id: str):
"""广播批量任务状态"""
with _ws_lock:
@@ -312,7 +343,9 @@ class TaskManager:
def get_batch_status(self, batch_id: str) -> Optional[dict]:
"""获取批量任务状态"""
return _batch_status.get(batch_id)
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
return status.copy() if status else None
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
@@ -321,15 +354,28 @@ class TaskManager:
def is_batch_cancelled(self, batch_id: str) -> bool:
"""检查批量任务是否已取消"""
status = _batch_status.get(batch_id, {})
return status.get("cancelled", False)
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id, {})
return status.get("cancelled", False)
def cancel_batch(self, batch_id: str):
"""取消批量任务"""
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
logger.info(f"批量任务 {batch_id} 已标记为取消")
changed = False
with _get_batch_lock(batch_id):
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
changed = True
logger.info(f"批量任务 {batch_id} 已标记为取消")
if changed and self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def register_batch_websocket(self, batch_id: str, websocket):
"""注册批量任务 WebSocket 连接"""

View File

@@ -0,0 +1,72 @@
from src.database import crud
from src.database.session import DatabaseSessionManager
def test_create_account_marks_token_sync_pending_when_tokens_persist(tmp_path):
    """An account created with token values must start in 'pending' sync state."""
    manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
    manager.create_tables()
    manager.migrate_tables()
    with manager.session_scope() as session:
        created = crud.create_account(
            session,
            email="sync@example.com",
            email_service="tempmail",
            access_token="access-token",
            refresh_token="refresh-token",
        )
        assert created.token_sync_status == "pending"
        assert created.token_sync_updated_at is not None
def test_update_account_marks_token_sync_pending_when_tokens_change(tmp_path):
    """Setting a token on a token-less account flips sync state to 'pending'."""
    manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
    manager.create_tables()
    manager.migrate_tables()
    with manager.session_scope() as session:
        created = crud.create_account(
            session,
            email="nosync@example.com",
            email_service="tempmail",
        )
        # Without any token values the account starts out not ready.
        assert created.token_sync_status == "not_ready"
        result = crud.update_account(
            session,
            created.id,
            access_token="new-access-token",
        )
        assert result is not None
        assert result.token_sync_status == "pending"
        assert result.token_sync_updated_at is not None
def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_path):
    """Clearing one token keeps 'pending' while another token is still stored."""
    manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db")
    manager.create_tables()
    manager.migrate_tables()
    with manager.session_scope() as session:
        created = crud.create_account(
            session,
            email="partial-sync@example.com",
            email_service="tempmail",
            access_token="access-token",
            refresh_token="refresh-token",
        )
        result = crud.update_account(
            session,
            created.id,
            refresh_token="",
        )
        assert result is not None
        # The untouched token survives; the cleared one is persisted as empty.
        assert result.access_token == "access-token"
        assert result.refresh_token == ""
        assert result.token_sync_status == "pending"
        assert result.token_sync_updated_at is not None

View File

@@ -0,0 +1,21 @@
from concurrent.futures import ThreadPoolExecutor
from src.web.task_manager import task_manager
def test_record_batch_task_result_is_atomic_under_threads():
    """Concurrent result recording must not lose any counter increments."""
    batch_id = "batch-atomic-test"
    task_manager.init_batch(batch_id, 100)
    outcomes = ["completed"] * 60 + ["failed"] * 40
    with ThreadPoolExecutor(max_workers=16) as pool:
        futures = [
            pool.submit(task_manager.record_batch_task_result, batch_id, outcome)
            for outcome in outcomes
        ]
        for future in futures:
            future.result()
    state = task_manager.get_batch_status(batch_id)
    assert state is not None
    assert state["completed"] == 100
    assert state["success"] == 60
    assert state["failed"] == 40
    assert state["skipped"] == 0