diff --git a/src/database/crud.py b/src/database/crud.py index 62d8703..e42da90 100644 --- a/src/database/crud.py +++ b/src/database/crud.py @@ -10,6 +10,15 @@ from sqlalchemy import and_, or_, desc, asc, func from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService +TOKEN_FIELD_NAMES = ("access_token", "refresh_token", "id_token", "session_token") + + +def _default_token_sync_status(token_values: Dict[str, Any]) -> str: + """根据当前持久化的 token 内容推导同步状态。""" + has_token = any(bool(token_values.get(field)) for field in TOKEN_FIELD_NAMES) + return "pending" if has_token else "not_ready" + + # ============================================================================ # 账户 CRUD # ============================================================================ @@ -31,9 +40,16 @@ def create_account( expires_at: Optional['datetime'] = None, extra_data: Optional[Dict[str, Any]] = None, status: Optional[str] = None, - source: Optional[str] = None + source: Optional[str] = None, + token_sync_status: Optional[str] = None, ) -> Account: """创建新账户""" + token_values = { + "access_token": access_token, + "refresh_token": refresh_token, + "id_token": id_token, + "session_token": session_token, + } db_account = Account( email=email, password=password, @@ -51,7 +67,9 @@ def create_account( extra_data=extra_data or {}, status=status or 'active', source=source or 'register', - registered_at=datetime.utcnow() + registered_at=datetime.utcnow(), + token_sync_status=token_sync_status or _default_token_sync_status(token_values), + token_sync_updated_at=datetime.utcnow(), ) db.add(db_account) db.commit() @@ -108,6 +126,15 @@ def update_account( if not db_account: return None + touches_token = any(field in kwargs for field in TOKEN_FIELD_NAMES) + if touches_token: + persisted_token_values = { + field: kwargs.get(field, getattr(db_account, field)) + for field in TOKEN_FIELD_NAMES + } + kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values)) + kwargs["token_sync_updated_at"] = datetime.utcnow() + for key, value in kwargs.items(): if hasattr(db_account, key) and value is not None: setattr(db_account, key, value) @@ -724,15 +751,31 @@ def delete_tm_service(db: Session, service_id: int) -> bool: def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str): """更新 EmailService.config 中指定邮箱的 refresh_token""" service = db.query(EmailService).filter(EmailService.id == service_id).first() - if not service or not service.config: + if not service or not isinstance(service.config, dict): return + + normalized_email = (email or "").strip().lower() + if not normalized_email or not isinstance(new_refresh_token, str) or not new_refresh_token: + return + config = dict(service.config) + updated = False + # 单账户格式 - if config.get("email", "").lower() == email.lower(): + if str(config.get("email", "")).lower() == normalized_email: config["refresh_token"] = new_refresh_token + updated = True + # 多账户列表格式 for acc in config.get("accounts", []): - if acc.get("email", "").lower() == email.lower(): + if not isinstance(acc, dict): + continue + if str(acc.get("email", "")).lower() == normalized_email: acc["refresh_token"] = new_refresh_token + updated = True + + if not updated: + return + service.config = config db.commit() diff --git a/src/database/models.py b/src/database/models.py index 216f7d8..832cfa8 100644 --- a/src/database/models.py +++ b/src/database/models.py @@ -39,6 +39,8 @@ class Account(Base): refresh_token = Column(Text) id_token = Column(Text) session_token = Column(Text) # 会话令牌(优先刷新方式) + token_sync_status = Column(String(20), default='not_ready') # 'not_ready', 'pending', 'synced' + token_sync_updated_at = Column(DateTime, default=datetime.utcnow) client_id = Column(String(255)) # OAuth Client ID account_id = Column(String(255)) workspace_id = Column(String(255)) @@ -80,7 +82,9 @@ class Account(Base): 'subscription_type': self.subscription_type, 'subscription_at': self.subscription_at.isoformat() if self.subscription_at else None, 'created_at': self.created_at.isoformat() if self.created_at else None, - 'updated_at': self.updated_at.isoformat() if self.updated_at else None + 'updated_at': self.updated_at.isoformat() if self.updated_at else None, + 'token_sync_status': self.token_sync_status, + 'token_sync_updated_at': self.token_sync_updated_at.isoformat() if self.token_sync_updated_at else None, } @@ -227,4 +231,4 @@ class Proxy(Base): if self.username and self.password: auth = f"{self.username}:{self.password}@" - return f"{scheme}://{auth}{self.host}:{self.port}" \ No newline at end of file + return f"{scheme}://{auth}{self.host}:{self.port}" diff --git a/src/database/session.py b/src/database/session.py index bb45334..10301de 100644 --- a/src/database/session.py +++ b/src/database/session.py @@ -110,6 +110,8 @@ class DatabaseSessionManager: ("accounts", "subscription_type", "VARCHAR(20)"), ("accounts", "subscription_at", "DATETIME"), ("accounts", "cookies", "TEXT"), + ("accounts", "token_sync_status", "VARCHAR(20) DEFAULT 'not_ready'"), + ("accounts", "token_sync_updated_at", "DATETIME"), ("proxies", "is_default", "BOOLEAN DEFAULT 0"), ("cpa_services", "include_proxy_url", "BOOLEAN DEFAULT 0"), ] diff --git a/src/web/routes/registration.py b/src/web/routes/registration.py index d9bffd0..99807f5 100644 --- a/src/web/routes/registration.py +++ b/src/web/routes/registration.py @@ -7,7 +7,7 @@ import logging import uuid import random from datetime import datetime -from typing import List, Optional, Dict, Tuple +from typing import List, Optional, Dict, Tuple, Any from fastapi import APIRouter, HTTPException, Query, BackgroundTasks from pydantic import BaseModel, Field @@ -599,34 +599,53 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: def _init_batch_state(batch_id: str, task_uuids: List[str]): """初始化批量任务内存状态""" task_manager.init_batch(batch_id, len(task_uuids)) - batch_tasks[batch_id] = { - "total": len(task_uuids), - "completed": 0, - "success": 0, - "failed": 0, - "cancelled": False, - "task_uuids": task_uuids, - "current_index": 0, - "logs": [], - "finished": False - } + metadata = batch_tasks.get(batch_id, {}).copy() + metadata["task_uuids"] = task_uuids + batch_tasks[batch_id] = metadata def _make_batch_helpers(batch_id: str): """返回 add_batch_log 和 update_batch_status 辅助函数""" def add_batch_log(msg: str): - batch_tasks[batch_id]["logs"].append(msg) task_manager.add_batch_log(batch_id, msg) def update_batch_status(**kwargs): - for key, value in kwargs.items(): - if key in batch_tasks[batch_id]: - batch_tasks[batch_id][key] = value task_manager.update_batch_status(batch_id, **kwargs) return add_batch_log, update_batch_status +def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]: + """聚合批量任务元数据与实时状态。""" + metadata = batch_tasks.get(batch_id) + if metadata is None: + return None + + status = task_manager.get_batch_status(batch_id) or {} + initial_skipped = int(metadata.get("initial_skipped", 0) or 0) + total = status.get("total", metadata.get("total", 0)) + completed = status.get("completed", 0) + success = status.get("success", 0) + failed = status.get("failed", 0) + skipped = status.get("skipped", 0) + initial_skipped + + return { + "batch_id": batch_id, + "total": total, + "completed": completed, + "success": success, + "failed": failed, + "skipped": skipped, + "current_index": status.get("current_index", 0), + "cancelled": status.get("cancelled", False), + "finished": status.get("finished", False), + "status": status.get("status", metadata.get("status", "pending")), + "logs": task_manager.get_batch_logs(batch_id), + "task_uuids": metadata.get("task_uuids", []), + "service_ids": metadata.get("service_ids", []), + } + + async def run_batch_parallel( batch_id: str, task_uuids: List[str], @@ -665,21 +684,19 @@ async def run_batch_parallel( t = crud.get_registration_task(db, uuid) if t: async with counter_lock: - new_completed = batch_tasks[batch_id]["completed"] + 1 - new_success = batch_tasks[batch_id]["success"] - new_failed = batch_tasks[batch_id]["failed"] if t.status == "completed": - new_success += 1 add_batch_log(f"{prefix} [成功] 注册成功") elif t.status == "failed": - new_failed += 1 add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}") - update_batch_status(completed=new_completed, success=new_success, failed=new_failed) + task_manager.record_batch_task_result(batch_id, t.status) try: await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True) if not task_manager.is_batch_cancelled(batch_id): - add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}") + snapshot = task_manager.get_batch_status(batch_id) or {} + add_batch_log( + f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}" + ) update_batch_status(finished=True, status="completed") else: update_batch_status(finished=True, status="cancelled") @@ -687,8 +704,6 @@ async def run_batch_parallel( logger.error(f"批量任务 {batch_id} 异常: {e}") add_batch_log(f"[错误] 批量任务异常: {str(e)}") update_batch_status(finished=True, status="failed") - finally: - batch_tasks[batch_id]["finished"] = True async def run_batch_pipeline( @@ -731,22 +746,17 @@ async def run_batch_pipeline( t = crud.get_registration_task(db, uuid) if t: async with counter_lock: - new_completed = batch_tasks[batch_id]["completed"] + 1 - new_success = batch_tasks[batch_id]["success"] - new_failed = batch_tasks[batch_id]["failed"] if t.status == "completed": - new_success += 1 add_batch_log(f"{pfx} [成功] 注册成功") elif t.status == "failed": - new_failed += 1 add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}") - update_batch_status(completed=new_completed, success=new_success, failed=new_failed) + task_manager.record_batch_task_result(batch_id, t.status) finally: semaphore.release() try: for i, task_uuid in enumerate(task_uuids): - if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]: + if task_manager.is_batch_cancelled(batch_id): with get_db() as db: for remaining_uuid in task_uuids[i:]: crud.update_registration_task(db, remaining_uuid, status="cancelled") @@ -770,14 +780,15 @@ async def run_batch_pipeline( await asyncio.gather(*running_tasks_list, return_exceptions=True) if not task_manager.is_batch_cancelled(batch_id): - add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}") + snapshot = task_manager.get_batch_status(batch_id) or {} + add_batch_log( + f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}" + ) update_batch_status(finished=True, status="completed") except Exception as e: logger.error(f"批量任务 {batch_id} 异常: {e}") add_batch_log(f"[错误] 批量任务异常: {str(e)}") update_batch_status(finished=True, status="failed") - finally: - batch_tasks[batch_id]["finished"] = True async def run_batch_registration( @@ -910,6 +921,7 @@ async def start_batch_registration( # 创建批量任务 batch_id = str(uuid.uuid4()) task_uuids = [] + batch_tasks[batch_id] = {"total": request.count} with get_db() as db: for _ in range(request.count): @@ -956,34 +968,33 @@ async def start_batch_registration( @router.get("/batch/{batch_id}") async def get_batch_status(batch_id: str): """获取批量任务状态""" - if batch_id not in batch_tasks: + snapshot = _get_batch_snapshot(batch_id) + if snapshot is None: raise HTTPException(status_code=404, detail="批量任务不存在") - batch = batch_tasks[batch_id] return { "batch_id": batch_id, - "total": batch["total"], - "completed": batch["completed"], - "success": batch["success"], - "failed": batch["failed"], - "current_index": batch["current_index"], - "cancelled": batch["cancelled"], - "finished": batch.get("finished", False), - "progress": f"{batch['completed']}/{batch['total']}" + "total": snapshot["total"], + "completed": snapshot["completed"], + "success": snapshot["success"], + "failed": snapshot["failed"], + "current_index": snapshot["current_index"], + "cancelled": snapshot["cancelled"], + "finished": snapshot["finished"], + "progress": f"{snapshot['completed']}/{snapshot['total']}" } @router.post("/batch/{batch_id}/cancel") async def cancel_batch(batch_id: str): """取消批量任务""" - if batch_id not in batch_tasks: + snapshot = _get_batch_snapshot(batch_id) + if snapshot is None: raise HTTPException(status_code=404, detail="批量任务不存在") - batch = batch_tasks[batch_id] - if batch.get("finished"): + if snapshot.get("finished"): raise HTTPException(status_code=400, detail="批量任务已完成") - batch["cancelled"] = True task_manager.cancel_batch(batch_id) return {"success": True, "message": "批量任务取消请求已提交"} @@ -1464,18 +1475,11 @@ async def start_outlook_batch_registration( # 创建批量任务 batch_id = str(uuid.uuid4()) - # 初始化批量任务状态 + # 记录额外元数据,由 task_manager 维护实时状态 batch_tasks[batch_id] = { "total": len(actual_service_ids), - "completed": 0, - "success": 0, - "failed": 0, - "skipped": 0, - "cancelled": False, + "initial_skipped": skipped_count, "service_ids": actual_service_ids, - "current_index": 0, - "logs": [], - "finished": False } # 在后台运行批量注册 @@ -1509,37 +1513,35 @@ async def start_outlook_batch_registration( @router.get("/outlook-batch/{batch_id}") async def get_outlook_batch_status(batch_id: str): """获取 Outlook 批量任务状态""" - if batch_id not in batch_tasks: + snapshot = _get_batch_snapshot(batch_id) + if snapshot is None: raise HTTPException(status_code=404, detail="批量任务不存在") - batch = batch_tasks[batch_id] return { "batch_id": batch_id, - "total": batch["total"], - "completed": batch["completed"], - "success": batch["success"], - "failed": batch["failed"], - "skipped": batch.get("skipped", 0), - "current_index": batch["current_index"], - "cancelled": batch["cancelled"], - "finished": batch.get("finished", False), - "logs": batch.get("logs", []), - "progress": f"{batch['completed']}/{batch['total']}" + "total": snapshot["total"], + "completed": snapshot["completed"], + "success": snapshot["success"], + "failed": snapshot["failed"], + "skipped": snapshot["skipped"], + "current_index": snapshot["current_index"], + "cancelled": snapshot["cancelled"], + "finished": snapshot["finished"], + "logs": snapshot["logs"], + "progress": f"{snapshot['completed']}/{snapshot['total']}" } @router.post("/outlook-batch/{batch_id}/cancel") async def cancel_outlook_batch(batch_id: str): """取消 Outlook 批量任务""" - if batch_id not in batch_tasks: + snapshot = _get_batch_snapshot(batch_id) + if snapshot is None: raise HTTPException(status_code=404, detail="批量任务不存在") - batch = batch_tasks[batch_id] - if batch.get("finished"): + if snapshot.get("finished"): raise HTTPException(status_code=400, detail="批量任务已完成") - # 同时更新两个系统的取消状态 - batch["cancelled"] = True task_manager.cancel_batch(batch_id) return {"success": True, "message": "批量任务取消请求已提交"} diff --git a/src/web/task_manager.py b/src/web/task_manager.py index ed722d4..b8aefd0 100644 --- a/src/web/task_manager.py +++ b/src/web/task_manager.py @@ -223,16 +223,18 @@ class TaskManager: def init_batch(self, batch_id: str, total: int): """初始化批量任务""" - _batch_status[batch_id] = { - "status": "running", - "total": total, - "completed": 0, - "success": 0, - "failed": 0, - "skipped": 0, - "current_index": 0, - "finished": False - } + with _get_batch_lock(batch_id): + _batch_status[batch_id] = { + "status": "running", + "total": total, + "completed": 0, + "success": 0, + "failed": 0, + "skipped": 0, + "current_index": 0, + "finished": False, + "cancelled": False, + } logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}") def add_batch_log(self, batch_id: str, log_message: str): @@ -276,11 +278,11 @@ class TaskManager: def update_batch_status(self, batch_id: str, **kwargs): """更新批量任务状态""" - if batch_id not in _batch_status: - logger.warning(f"批量任务 {batch_id} 不存在") - return - - _batch_status[batch_id].update(kwargs) + with _get_batch_lock(batch_id): + if batch_id not in _batch_status: + logger.warning(f"批量任务 {batch_id} 不存在") + return + _batch_status[batch_id].update(kwargs) # 异步广播状态更新 if self._loop and self._loop.is_running(): @@ -292,6 +294,35 @@ class TaskManager: except Exception as e: logger.warning(f"广播批量状态失败: {e}") + def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]: + """原子记录单个子任务的终态并返回快照。""" + with _get_batch_lock(batch_id): + status = _batch_status.get(batch_id) + if status is None: + logger.warning(f"批量任务 {batch_id} 不存在") + return None + + status["completed"] += 1 + if task_status == "completed": + status["success"] += 1 + elif task_status == "failed": + status["failed"] += 1 + elif task_status == "cancelled": + status["skipped"] += 1 + + snapshot = status.copy() + + if self._loop and self._loop.is_running(): + try: + asyncio.run_coroutine_threadsafe( + self._broadcast_batch_status(batch_id), + self._loop + ) + except Exception as e: + logger.warning(f"广播批量状态失败: {e}") + + return snapshot + async def _broadcast_batch_status(self, batch_id: str): """广播批量任务状态""" with _ws_lock: @@ -312,7 +343,9 @@ class TaskManager: def get_batch_status(self, batch_id: str) -> Optional[dict]: """获取批量任务状态""" - return _batch_status.get(batch_id) + with _get_batch_lock(batch_id): + status = _batch_status.get(batch_id) + return status.copy() if status else None def get_batch_logs(self, batch_id: str) -> List[str]: """获取批量任务日志""" @@ -321,15 +354,28 @@ class TaskManager: def is_batch_cancelled(self, batch_id: str) -> bool: """检查批量任务是否已取消""" - status = _batch_status.get(batch_id, {}) - return status.get("cancelled", False) + with _get_batch_lock(batch_id): + status = _batch_status.get(batch_id, {}) + return status.get("cancelled", False) def cancel_batch(self, batch_id: str): """取消批量任务""" - if batch_id in _batch_status: - _batch_status[batch_id]["cancelled"] = True - _batch_status[batch_id]["status"] = "cancelling" - logger.info(f"批量任务 {batch_id} 已标记为取消") + changed = False + with _get_batch_lock(batch_id): + if batch_id in _batch_status: + _batch_status[batch_id]["cancelled"] = True + _batch_status[batch_id]["status"] = "cancelling" + changed = True + logger.info(f"批量任务 {batch_id} 已标记为取消") + + if changed and self._loop and self._loop.is_running(): + try: + asyncio.run_coroutine_threadsafe( + self._broadcast_batch_status(batch_id), + self._loop + ) + except Exception as e: + logger.warning(f"广播批量状态失败: {e}") def register_batch_websocket(self, batch_id: str, websocket): """注册批量任务 WebSocket 连接""" diff --git a/tests/test_account_token_sync_status.py b/tests/test_account_token_sync_status.py new file mode 100644 index 0000000..9a13635 --- /dev/null +++ b/tests/test_account_token_sync_status.py @@ -0,0 +1,72 @@ +from src.database import crud +from src.database.session import DatabaseSessionManager + + +def test_create_account_marks_token_sync_pending_when_tokens_persist(tmp_path): + manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db") + manager.create_tables() + manager.migrate_tables() + + with manager.session_scope() as session: + account = crud.create_account( + session, + email="sync@example.com", + email_service="tempmail", + access_token="access-token", + refresh_token="refresh-token", + ) + + assert account.token_sync_status == "pending" + assert account.token_sync_updated_at is not None + + +def test_update_account_marks_token_sync_pending_when_tokens_change(tmp_path): + manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db") + manager.create_tables() + manager.migrate_tables() + + with manager.session_scope() as session: + account = crud.create_account( + session, + email="nosync@example.com", + email_service="tempmail", + ) + + assert account.token_sync_status == "not_ready" + + updated = crud.update_account( + session, + account.id, + access_token="new-access-token", + ) + + assert updated is not None + assert updated.token_sync_status == "pending" + assert updated.token_sync_updated_at is not None + + +def test_update_account_preserves_pending_status_when_other_tokens_remain(tmp_path): + manager = DatabaseSessionManager(f"sqlite:///{tmp_path}/test.db") + manager.create_tables() + manager.migrate_tables() + + with manager.session_scope() as session: + account = crud.create_account( + session, + email="partial-sync@example.com", + email_service="tempmail", + access_token="access-token", + refresh_token="refresh-token", + ) + + updated = crud.update_account( + session, + account.id, + refresh_token="", + ) + + assert updated is not None + assert updated.access_token == "access-token" + assert updated.refresh_token == "" + assert updated.token_sync_status == "pending" + assert updated.token_sync_updated_at is not None diff --git a/tests/test_batch_task_manager.py b/tests/test_batch_task_manager.py new file mode 100644 index 0000000..ef149ac --- /dev/null +++ b/tests/test_batch_task_manager.py @@ -0,0 +1,21 @@ +from concurrent.futures import ThreadPoolExecutor + +from src.web.task_manager import task_manager + + +def test_record_batch_task_result_is_atomic_under_threads(): + batch_id = "batch-atomic-test" + task_manager.init_batch(batch_id, 100) + + statuses = ["completed"] * 60 + ["failed"] * 40 + + with ThreadPoolExecutor(max_workers=16) as executor: + list(executor.map(lambda status: task_manager.record_batch_task_result(batch_id, status), statuses)) + + snapshot = task_manager.get_batch_status(batch_id) + + assert snapshot is not None + assert snapshot["completed"] == 100 + assert snapshot["success"] == 60 + assert snapshot["failed"] == 40 + assert snapshot["skipped"] == 0