Merge branch 'master' into fix/worker-mail-otp-extraction

This commit is contained in:
kailian zhou
2026-03-24 17:58:01 +08:00
committed by GitHub
53 changed files with 5999 additions and 409 deletions

View File

@@ -56,7 +56,7 @@ APP_DESCRIPTION = "自动注册 OpenAI/Codex CLI 账号的系统"
OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
OAUTH_AUTH_URL = "https://auth.openai.com/oauth/authorize"
OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
OAUTH_REDIRECT_URI = "http://localhost:1455/auth/callback"
OAUTH_REDIRECT_URI = "http://localhost:15555/auth/callback"
OAUTH_SCOPE = "openid email profile offline_access"
# OpenAI API 端点
@@ -267,7 +267,7 @@ DEFAULT_SETTINGS = [
("registration.timeout", "120", "超时时间(秒)", "registration"),
("registration.default_password_length", "12", "默认密码长度", "registration"),
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
("webui.port", "8000", "Web UI 监听端口", "webui"),
("webui.port", "15555", "Web UI 监听端口", "webui"),
("webui.debug", "true", "调试模式", "webui"),
]

View File

@@ -76,7 +76,7 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
),
"webui_port": SettingDefinition(
db_key="webui.port",
default_value=8000,
default_value=15555,
category=SettingCategory.WEBUI,
description="Web UI 监听端口"
),
@@ -609,7 +609,7 @@ class Settings(BaseModel):
# Web UI 配置
webui_host: str = "0.0.0.0"
webui_port: int = 8000
webui_port: int = 15555
webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
webui_access_password: SecretStr = SecretStr("admin123")

457
src/core/login.py Normal file
View File

@@ -0,0 +1,457 @@
"""
登录流程引擎
从 register.py 中拆分的登录专属方法
"""
import urllib.parse
import base64
import json as json_module
import time
from datetime import datetime
from typing import Optional, Dict, Any
from .register import RegistrationEngine, RegistrationResult
from ..config.constants import OPENAI_API_ENDPOINTS
class LoginEngine(RegistrationEngine):
"""
登录引擎
继承 RegistrationEngine包含登录流程专属方法
- _follow_login_redirects
- _submit_login_form
- _send_verification_code_passwordless
- _get_workspace_id
- _select_workspace
- _follow_redirects
- _handle_oauth_callback
"""
def _follow_login_redirects(self, start_url: str) -> bool:
"""跟随重定向链,寻找回调 URL"""
try:
current_url = start_url
max_redirects = 6
for i in range(max_redirects):
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
response = self.session.get(
current_url,
allow_redirects=False,
timeout=15
)
location = response.headers.get("Location") or ""
# 如果不是重定向状态码,停止
if response.status_code == 200:
self._log(f"非重定向状态码: {response.status_code}")
return True
if not location:
self._log("重定向响应缺少 Location 头")
break
# 构建下一个 URL
next_url = urllib.parse.urljoin(current_url, location)
# 检查是否包含回调参数
if "code=" in next_url and "state=" in next_url:
self._log(f"找到回调 URL: {next_url[:100]}...")
current_url = next_url
self._log("未能在重定向链中找到最终 URL")
return False
except Exception as e:
self._log(f"跟随重定向失败: {e}", "error")
return False
def _submit_login_form(self, did: str, sen_token) -> bool:
"""处理 免密登录"""
try:
self._log("处理免密登录...")
login_body = f'{{"username":{{"value":"{self.email}","kind":"email"}}}}'
headers = {
"referer": "https://auth.openai.com/log-in",
"accept": "application/json",
"content-type": "application/json",
}
if sen_token:
sentinel = (
f'{{"p": "", "t": "", "c": "{sen_token}", '
f'"id": "{did}", "flow": "authorize_continue"}}'
)
headers["openai-sentinel-token"] = sentinel
response = self.session.post(
OPENAI_API_ENDPOINTS["signup"],
headers=headers,
data=login_body,
)
self._log(f"提交登录表单状态: {response.status_code}")
if response.status_code == 200:
return True
return False
except Exception as e:
self._log(f"处理登录失败: {e}", "error")
return False
def _send_verification_code_passwordless(self) -> bool:
"""发送验证码"""
try:
self._otp_sent_at = time.time()
response = self.session.post(
OPENAI_API_ENDPOINTS["passwordless_send_otp"],
headers={
"referer": "https://auth.openai.com/log-in/password",
"accept": "application/json"
}
)
self._log(f"验证码发送状态: {response.status_code}")
return response.status_code == 200
except Exception as e:
self._log(f"发送验证码失败: {e}", "error")
return False
def _decode_workspace_id(self, auth_cookie: str) -> str:
"""从授权 Cookie 中解析 Workspace ID"""
segments = auth_cookie.split(".")
if len(segments) < 1:
raise ValueError("授权 Cookie 格式错误")
payload = segments[0]
pad = "=" * ((4 - (len(payload) % 4)) % 4)
decoded = base64.urlsafe_b64decode((payload + pad).encode("ascii"))
auth_json = json_module.loads(decoded.decode("utf-8"))
workspaces = auth_json.get("workspaces") or []
if not workspaces:
raise ValueError("授权 Cookie 里没有 workspace 信息")
workspace_id = str((workspaces[0] or {}).get("id") or "").strip()
if not workspace_id:
raise ValueError("无法解析 workspace_id")
return workspace_id
def _get_workspace_id(self) -> Optional[str]:
"""获取 Workspace ID"""
backoff_seconds = (1, 2, 4)
max_attempts = len(backoff_seconds) + 1
for attempt in range(1, max_attempts + 1):
try:
auth_cookie = self.session.cookies.get("oai-client-auth-session")
if auth_cookie:
workspace_id = self._decode_workspace_id(auth_cookie)
self._log(f"Workspace ID: {workspace_id}")
return workspace_id
raise ValueError("未能获取到授权 Cookie")
except Exception as e:
level = "warning" if attempt < max_attempts else "error"
self._log(
f"获取 Workspace ID 失败: {e} (第 {attempt}/{max_attempts} 次)",
level,
)
if attempt < max_attempts:
wait_seconds = backoff_seconds[attempt - 1]
self._log(f"等待 {wait_seconds} 秒后重试 Workspace ID", "warning")
time.sleep(wait_seconds)
return None
def _select_workspace(self, workspace_id: str) -> Optional[str]:
"""选择 Workspace"""
try:
select_body = f'{{"workspace_id":"{workspace_id}"}}'
response = self.session.post(
OPENAI_API_ENDPOINTS["select_workspace"],
headers={
"referer": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent",
"content-type": "application/json",
},
data=select_body,
)
if response.status_code != 200:
self._log(f"选择 workspace 失败: {response.status_code}", "error")
self._log(f"响应: {response.text[:200]}", "warning")
return None
continue_url = str((response.json() or {}).get("continue_url") or "").strip()
if not continue_url:
self._log("workspace/select 响应里缺少 continue_url", "error")
return None
self._log(f"Continue URL: {continue_url[:100]}...")
return continue_url
except Exception as e:
self._log(f"选择 Workspace 失败: {e}", "error")
return None
def _follow_redirects(self, start_url: str) -> Optional[str]:
"""跟随重定向链,寻找回调 URL"""
try:
current_url = start_url
max_redirects = 6
for i in range(max_redirects):
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
response = self.session.get(
current_url,
allow_redirects=False,
timeout=15
)
location = response.headers.get("Location") or ""
# 如果不是重定向状态码,停止
if response.status_code not in [301, 302, 303, 307, 308]:
self._log(f"非重定向状态码: {response.status_code}")
break
if not location:
self._log("重定向响应缺少 Location 头")
break
# 构建下一个 URL
next_url = urllib.parse.urljoin(current_url, location)
# 检查是否包含回调参数
if "code=" in next_url and "state=" in next_url:
self._log(f"找到回调 URL: {next_url[:100]}...")
return next_url
current_url = next_url
self._log("未能在重定向链中找到回调 URL", "error")
return None
except Exception as e:
self._log(f"跟随重定向失败: {e}", "error")
return None
def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]:
"""处理 OAuth 回调"""
try:
if not self.oauth_start:
self._log("OAuth 流程未初始化", "error")
return None
self._log("处理 OAuth 回调...")
token_info = self.oauth_manager.handle_callback(
callback_url=callback_url,
expected_state=self.oauth_start.state,
code_verifier=self.oauth_start.code_verifier
)
self._log("OAuth 授权成功")
return token_info
except Exception as e:
self._log(f"处理 OAuth 回调失败: {e}", "error")
return None
def run(self) -> RegistrationResult:
"""
执行完整的注册流程
支持已注册账号自动登录:
- 如果检测到邮箱已注册,自动切换到登录流程
- 已注册账号跳过:设置密码、发送验证码、创建用户账户
- 共用步骤获取验证码、验证验证码、Workspace 和 OAuth 回调
Returns:
RegistrationResult: 注册结果
"""
result = RegistrationResult(success=False, logs=self.logs)
try:
self._log("=" * 60)
self._log("开始注册流程")
self._log("=" * 60)
self._log("1. 检查 IP 地理位置...")
ip_ok, location = self._check_ip_location()
if not ip_ok:
result.error_message = f"IP 地理位置不支持: {location}"
self._log(f"IP 检查失败: {location}", "error")
return result
self._log(f"IP 位置: {location}")
self._log("2. 创建邮箱...")
if not self._create_email():
result.error_message = "创建邮箱失败"
return result
result.email = self.email
self._log("3. 初始化会话...")
if not self._init_session():
result.error_message = "初始化会话失败"
return result
self._log("4. 开始 OAuth 授权流程...")
if not self._start_oauth():
result.error_message = "开始 OAuth 流程失败"
return result
self._log("5. 获取 Device ID...")
did = self._get_device_id()
if not did:
result.error_message = "获取 Device ID 失败"
return result
self._log("6. 检查 Sentinel 拦截...")
sen_token = self._check_sentinel(did)
if sen_token:
self._log("Sentinel 检查通过")
else:
self._log("Sentinel 检查失败或未启用", "warning")
self._log("7. 提交注册表单...")
signup_result = self._submit_signup_form(did, sen_token)
if not signup_result.success:
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
return result
if self._is_existing_account:
self._log(f"8. 邮箱 {self.email} 在 OpenAI 已注册,跳过注册流程", "warning")
result.error_message = f"邮箱 {self.email} 已在 OpenAI 注册"
return result
self._log("8. 注册密码...")
password_ok, password = self._register_password()
if not password_ok:
result.error_message = "注册密码失败"
return result
self._log("9. 发送验证码...")
if not self._send_verification_code():
result.error_message = "发送验证码失败"
return result
self._log("10. 等待验证码...")
code = self._get_verification_code()
if not code:
self._log("10. 验证码超时,重新发送...")
if self._send_verification_code():
code = self._get_verification_code()
if not code:
result.error_message = "获取验证码失败"
return result
self._log("11. 验证验证码...")
if not self._validate_verification_code(code):
result.error_message = "验证验证码失败"
return result
self._log("12. 创建用户账户...")
if not self._create_user_account():
result.error_message = "创建用户账户失败"
return result
self._log("13-1. 结束注册,启用登录流程...")
if not self._follow_login_redirects(self.oauth_start.auth_url):
result.error_message = "跟随重定向链失败"
return result
self._log("13-2. 提交登陆表单")
if not self._submit_login_form(did, sen_token):
result.error_message = "提交登陆表单失败"
return result
self._log("14. 发送验证码...")
if not self._send_verification_code_passwordless():
result.error_message = "发送验证码失败"
return result
self._log("15. 等待验证码...")
code = self._get_verification_code()
if not code:
self._log("15. 验证码超时,重新发送...")
if self._send_verification_code_passwordless():
code = self._get_verification_code()
if not code:
result.error_message = "获取验证码失败"
return result
self._log("16. 验证验证码...")
if not self._validate_verification_code(code):
result.error_message = "验证验证码失败"
return result
self._log("17. 获取 Workspace ID...")
workspace_id = self._get_workspace_id()
if not workspace_id:
result.error_message = "获取 Workspace ID 失败"
return result
result.workspace_id = workspace_id
self._log("18. 选择 Workspace...")
continue_url = self._select_workspace(workspace_id)
if not continue_url:
result.error_message = "选择 Workspace 失败"
return result
self._log("19. 跟随重定向链...")
callback_url = self._follow_redirects(continue_url)
if not callback_url:
result.error_message = "跟随重定向链失败"
return result
self._log("20. 处理 OAuth 回调...")
token_info = self._handle_oauth_callback(callback_url)
if not token_info:
result.error_message = "处理 OAuth 回调失败"
return result
result.account_id = token_info.get("account_id", "")
result.access_token = token_info.get("access_token", "")
result.refresh_token = token_info.get("refresh_token", "")
result.id_token = token_info.get("id_token", "")
result.password = self.password or ""
result.source = "register"
session_cookie = self.session.cookies.get("__Secure-next-auth.session-token")
if session_cookie:
self.session_token = session_cookie
result.session_token = session_cookie
self._log("获取到 Session Token")
self._log("=" * 60)
self._log("注册成功!")
self._log(f"邮箱: {result.email}")
self._log(f"Account ID: {result.account_id}")
self._log(f"Workspace ID: {result.workspace_id}")
self._log("=" * 60)
result.success = True
result.metadata = {
"email_service": self.email_service.service_type.value,
"proxy_used": self.proxy_url,
"registered_at": datetime.now().isoformat(),
}
return result
except Exception as e:
self._log(f"注册过程中发生未预期错误: {e}", "error")
result.error_message = str(e)
return result
finally:
self.close()

File diff suppressed because it is too large Load Diff

View File

@@ -2,14 +2,24 @@
数据库 CRUD 操作
"""
from typing import List, Optional, Dict, Any, Union
from typing import List, Optional, Dict, Any, Union, Iterable, Set
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy import and_, or_, desc, asc, func
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
TOKEN_FIELD_NAMES = ("access_token", "refresh_token", "id_token", "session_token")
def _default_token_sync_status(token_values: Dict[str, Any]) -> str:
"""根据当前持久化的 token 内容推导同步状态。"""
has_token = any(bool(token_values.get(field)) for field in TOKEN_FIELD_NAMES)
return "pending" if has_token else "not_ready"
# ============================================================================
# 账户 CRUD
# ============================================================================
@@ -27,13 +37,21 @@ def create_account(
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
id_token: Optional[str] = None,
cookies: Optional[str] = None,
proxy_used: Optional[str] = None,
expires_at: Optional['datetime'] = None,
extra_data: Optional[Dict[str, Any]] = None,
status: Optional[str] = None,
source: Optional[str] = None
source: Optional[str] = None,
token_sync_status: Optional[str] = None,
) -> Account:
"""创建新账户"""
token_values = {
"access_token": access_token,
"refresh_token": refresh_token,
"id_token": id_token,
"session_token": session_token,
}
db_account = Account(
email=email,
password=password,
@@ -46,12 +64,15 @@ def create_account(
access_token=access_token,
refresh_token=refresh_token,
id_token=id_token,
cookies=cookies,
proxy_used=proxy_used,
expires_at=expires_at,
extra_data=extra_data or {},
status=status or 'active',
source=source or 'register',
registered_at=datetime.utcnow()
registered_at=datetime.utcnow(),
token_sync_status=token_sync_status or _default_token_sync_status(token_values),
token_sync_updated_at=datetime.utcnow(),
)
db.add(db_account)
db.commit()
@@ -108,6 +129,14 @@ def update_account(
if not db_account:
return None
touches_token = any(field in kwargs for field in TOKEN_FIELD_NAMES)
if touches_token:
persisted_token_values = {
field: kwargs.get(field, getattr(db_account, field))
for field in TOKEN_FIELD_NAMES
}
kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
kwargs["token_sync_updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(db_account, key) and value is not None:
setattr(db_account, key, value)
@@ -326,6 +355,34 @@ def delete_registration_task(db: Session, task_uuid: str) -> bool:
return True
def fail_incomplete_registration_tasks(db: Session, error_message: str) -> List[str]:
"""将服务重启后遗留的未完成任务标记为失败"""
tasks = db.query(RegistrationTask).filter(
RegistrationTask.status.in_(("pending", "running"))
).all()
if not tasks:
return []
now = datetime.utcnow()
cleaned_task_ids: List[str] = []
cleanup_log = f"[系统] {error_message}"
for task in tasks:
task.status = "failed"
task.error_message = error_message
task.completed_at = now
if task.logs:
if cleanup_log not in task.logs:
task.logs = f"{task.logs}\n{cleanup_log}"
else:
task.logs = cleanup_log
cleaned_task_ids.append(task.task_uuid)
db.commit()
return cleaned_task_ids
# 为 API 路由添加别名
get_account = get_account_by_id
get_registration_task = get_registration_task_by_uuid
@@ -437,9 +494,13 @@ def get_proxies(
return query.all()
def get_enabled_proxies(db: Session) -> List[Proxy]:
def get_enabled_proxies(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> List[Proxy]:
"""获取所有启用的代理"""
return db.query(Proxy).filter(Proxy.enabled == True).all()
query = db.query(Proxy).filter(Proxy.enabled == True)
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
if excluded:
query = query.filter(~Proxy.id.in_(excluded))
return query.all()
def update_proxy(
@@ -483,14 +544,18 @@ def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
return True
def get_random_proxy(db: Session) -> Optional[Proxy]:
def get_random_proxy(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> Optional[Proxy]:
"""随机获取一个启用的代理,优先返回 is_default=True 的代理"""
import random
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
# 优先返回默认代理
default_proxy = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True).first()
default_query = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True)
if excluded:
default_query = default_query.filter(~Proxy.id.in_(excluded))
default_proxy = default_query.first()
if default_proxy:
return default_proxy
proxies = get_enabled_proxies(db)
proxies = get_enabled_proxies(db, exclude_ids=excluded)
if not proxies:
return None
return random.choice(proxies)
@@ -713,4 +778,38 @@ def delete_tm_service(db: Session, service_id: int) -> bool:
return False
db.delete(svc)
db.commit()
return True
return True
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
"""更新 EmailService.config 中指定邮箱的 refresh_token"""
service = db.query(EmailService).filter(EmailService.id == service_id).first()
if not service or not isinstance(service.config, dict):
return
normalized_email = (email or "").strip().lower()
if not normalized_email or not isinstance(new_refresh_token, str) or not new_refresh_token:
return
config = dict(service.config)
updated = False
# 单账户格式
if str(config.get("email", "")).lower() == normalized_email:
config["refresh_token"] = new_refresh_token
updated = True
# 多账户列表格式
for acc in config.get("accounts", []):
if not isinstance(acc, dict):
continue
if str(acc.get("email", "")).lower() == normalized_email:
acc["refresh_token"] = new_refresh_token
updated = True
if not updated:
return
service.config = config
flag_modified(service, "config")
db.commit()

View File

@@ -39,6 +39,8 @@ class Account(Base):
refresh_token = Column(Text)
id_token = Column(Text)
session_token = Column(Text) # 会话令牌(优先刷新方式)
token_sync_status = Column(String(20), default='not_ready') # 'not_ready', 'pending', 'synced'
token_sync_updated_at = Column(DateTime, default=datetime.utcnow)
client_id = Column(String(255)) # OAuth Client ID
account_id = Column(String(255))
workspace_id = Column(String(255))
@@ -80,7 +82,9 @@ class Account(Base):
'subscription_type': self.subscription_type,
'subscription_at': self.subscription_at.isoformat() if self.subscription_at else None,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
'token_sync_status': self.token_sync_status,
'token_sync_updated_at': self.token_sync_updated_at.isoformat() if self.token_sync_updated_at else None,
}
@@ -227,4 +231,4 @@ class Proxy(Base):
if self.username and self.password:
auth = f"{self.username}:{self.password}@"
return f"{scheme}://{auth}{self.host}:{self.port}"
return f"{scheme}://{auth}{self.host}:{self.port}"

View File

@@ -45,7 +45,7 @@ class DatabaseSessionManager:
self.database_url = _build_sqlalchemy_url(database_url)
self.engine = create_engine(
self.database_url,
connect_args={"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
connect_args={"check_same_thread": False, "timeout": 30} if self.database_url.startswith("sqlite") else {},
echo=False, # 设置为 True 可以查看所有 SQL 语句
pool_pre_ping=True # 连接池预检查
)
@@ -110,6 +110,8 @@ class DatabaseSessionManager:
("accounts", "subscription_type", "VARCHAR(20)"),
("accounts", "subscription_at", "DATETIME"),
("accounts", "cookies", "TEXT"),
("accounts", "token_sync_status", "VARCHAR(20) DEFAULT 'not_ready'"),
("accounts", "token_sync_updated_at", "DATETIME"),
("proxies", "is_default", "BOOLEAN DEFAULT 0"),
("cpa_services", "include_proxy_url", "BOOLEAN DEFAULT 0"),
]

View File

@@ -6,6 +6,8 @@
import abc
import logging
import re
import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
from enum import Enum
@@ -14,12 +16,109 @@ from ..config.constants import EmailServiceType, OTP_CODE_PATTERN, OTP_CODE_SEMA
logger = logging.getLogger(__name__)
EMAIL_PROVIDER_BACKOFF_BASE_SECONDS = 30
EMAIL_PROVIDER_BACKOFF_MAX_SECONDS = 3600
OTP_TIMEOUT_ERROR_PREFIX = "OTP_TIMEOUT"
@dataclass(frozen=True)
class EmailProviderBackoffState:
"""邮箱供应商退避状态"""
failures: int = 0
delay_seconds: int = 0
opened_until: float = 0.0
retry_after: Optional[int] = None
last_error: Optional[str] = None
def is_open(self, now: Optional[float] = None) -> bool:
now_ts = now if now is not None else time.time()
return self.opened_until > now_ts
def to_dict(self) -> Dict[str, Any]:
return {
"failures": self.failures,
"delay_seconds": self.delay_seconds,
"opened_until": self.opened_until,
"retry_after": self.retry_after,
"last_error": self.last_error,
}
def calculate_adaptive_backoff_delay(
failures: int,
base_delay: int = EMAIL_PROVIDER_BACKOFF_BASE_SECONDS,
max_delay: int = EMAIL_PROVIDER_BACKOFF_MAX_SECONDS,
is_timeout: bool = False,
) -> int:
"""根据连续失败次数计算指数退避时长"""
normalized_failures = max(0, failures)
if is_timeout and normalized_failures >= 3:
return max_delay
exponent = max(0, normalized_failures - 1)
return min(base_delay * (2 ** exponent), max_delay)
def is_otp_timeout_error(error: object) -> bool:
"""识别 OTP 超时类错误码。"""
if error is None:
return False
if isinstance(error, OTPTimeoutEmailServiceError):
return True
error_code = getattr(error, "error_code", "")
if isinstance(error_code, str) and error_code.startswith(OTP_TIMEOUT_ERROR_PREFIX):
return True
return False
def apply_adaptive_backoff(
current_state: Optional[EmailProviderBackoffState],
error: "EmailServiceError",
now: Optional[float] = None,
) -> EmailProviderBackoffState:
"""在限流场景下推进邮箱供应商退避状态"""
state = current_state or EmailProviderBackoffState()
now_ts = now if now is not None else time.time()
next_failures = state.failures + 1
delay_seconds = calculate_adaptive_backoff_delay(
next_failures,
is_timeout=is_otp_timeout_error(error),
)
return EmailProviderBackoffState(
failures=next_failures,
delay_seconds=delay_seconds,
opened_until=now_ts + delay_seconds,
retry_after=getattr(error, "retry_after", None),
last_error=str(error),
)
def reset_adaptive_backoff() -> EmailProviderBackoffState:
"""重置邮箱供应商退避状态"""
return EmailProviderBackoffState()
class EmailServiceError(Exception):
"""邮箱服务异常"""
pass
class RateLimitedEmailServiceError(EmailServiceError):
"""邮箱服务被限流"""
def __init__(self, message: str, retry_after: Optional[int] = None):
super().__init__(message)
self.retry_after = retry_after
class OTPTimeoutEmailServiceError(EmailServiceError):
"""OTP 验证码等待超时。"""
def __init__(self, message: str, error_code: str = OTP_TIMEOUT_ERROR_PREFIX):
super().__init__(message)
self.error_code = error_code
class EmailServiceStatus(Enum):
"""邮箱服务状态"""
HEALTHY = "healthy"
@@ -46,6 +145,7 @@ class BaseEmailService(abc.ABC):
self.name = name or f"{service_type.value}_service"
self._status = EmailServiceStatus.HEALTHY
self._last_error = None
self._provider_backoff = reset_adaptive_backoff()
_EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
@@ -59,6 +159,15 @@ class BaseEmailService(abc.ABC):
"""获取最后一次错误信息"""
return self._last_error
@property
def provider_backoff_state(self) -> EmailProviderBackoffState:
"""获取当前邮箱供应商退避状态"""
return self._provider_backoff
def apply_provider_backoff_state(self, state: Optional[EmailProviderBackoffState]) -> None:
"""注入外部持久化的邮箱供应商退避状态"""
self._provider_backoff = state or reset_adaptive_backoff()
@abc.abstractmethod
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
@@ -95,7 +204,7 @@ class BaseEmailService(abc.ABC):
email_id: 邮箱服务中的 ID如果需要
timeout: 超时时间(秒)
pattern: 验证码正则表达式
otp_sent_at: OTP 发送时间戳,用于过滤旧邮件
otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点的邮件
Returns:
验证码字符串,如果超时或未找到返回 None
@@ -309,8 +418,16 @@ class BaseEmailService(abc.ABC):
if success:
self._status = EmailServiceStatus.HEALTHY
self._last_error = None
self._provider_backoff = reset_adaptive_backoff()
else:
self._status = EmailServiceStatus.DEGRADED
if isinstance(error, RateLimitedEmailServiceError) or is_otp_timeout_error(error):
self._status = EmailServiceStatus.UNAVAILABLE
self._provider_backoff = apply_adaptive_backoff(
self._provider_backoff,
error,
)
else:
self._status = EmailServiceStatus.DEGRADED
if error:
self._last_error = str(error)

View File

@@ -12,7 +12,7 @@ from datetime import datetime, timezone
from html import unescape
from typing import Any, Dict, List, Optional
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..config.constants import OTP_CODE_PATTERN
from ..core.http_client import HTTPClient, RequestConfig
@@ -102,7 +102,19 @@ class DuckMailService(BaseEmailService):
error_message = f"{error_message} - {error_payload}"
except Exception:
error_message = f"{error_message} - {response.text[:200]}"
raise EmailServiceError(error_message)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_message, retry_after=retry_after)
else:
error = EmailServiceError(error_message)
self.update_status(False, error)
raise error
try:
return response.json()

View File

@@ -10,7 +10,7 @@ import random
import string
from typing import Optional, Dict, Any, List
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN
@@ -96,8 +96,19 @@ class FreemailService(BaseEmailService):
error_msg = f"{error_msg} - {error_data}"
except Exception:
error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg))
raise EmailServiceError(error_msg)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
try:
return response.json()

View File

@@ -10,7 +10,7 @@ import logging
from typing import Optional, Dict, Any, List
from urllib.parse import urljoin
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN
@@ -148,8 +148,20 @@ class MeoMailEmailService(BaseEmailService):
except:
error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg))
raise EmailServiceError(error_msg)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
# 解析响应
try:
@@ -553,4 +565,4 @@ class MeoMailEmailService(BaseEmailService):
"system_config": config,
"cached_emails_count": len(self._emails_cache),
"status": self.status.value,
}
}

View File

@@ -15,7 +15,7 @@ from email.policy import default as email_policy
from html import unescape
from typing import Optional, Dict, Any, List
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN
@@ -200,8 +200,19 @@ class TempMailService(BaseEmailService):
error_msg = f"{error_msg} - {error_data}"
except Exception:
error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg))
raise EmailServiceError(error_msg)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
try:
return response.json()

View File

@@ -6,9 +6,7 @@ import re
import time
import logging
from typing import Optional, Dict, Any, List
import json
from curl_cffi import requests as cffi_requests
from datetime import datetime, timezone
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from ..core.http_client import HTTPClient, RequestConfig
@@ -17,6 +15,8 @@ from ..config.constants import OTP_CODE_PATTERN
logger = logging.getLogger(__name__)
OTP_SENT_AT_TOLERANCE_SECONDS = 2
class TempmailService(BaseEmailService):
"""
@@ -58,10 +58,65 @@ class TempmailService(BaseEmailService):
config=http_config
)
# 状态变量
# 状态变量(内存缓存,重启后从 DB 按需查询)
self._email_cache: Dict[str, Dict[str, Any]] = {}
self._last_check_time: float = 0
def _parse_message_time(self, value: Any) -> Optional[float]:
"""解析 Tempmail 邮件时间,兼容 Unix 时间戳与 ISO 8601。"""
if value is None or value == "":
return None
if isinstance(value, (int, float)):
timestamp = float(value)
else:
text = str(value).strip()
if not text:
return None
try:
timestamp = float(text)
except ValueError:
try:
normalized = text.replace("Z", "+00:00")
timestamp = datetime.fromisoformat(normalized).astimezone(timezone.utc).timestamp()
except Exception:
return None
while timestamp > 1e11:
timestamp /= 1000.0
return timestamp if timestamp > 0 else None
def _get_received_timestamp(self, message: Dict[str, Any]) -> Optional[float]:
"""返回 Tempmail 邮件的接收时间戳。"""
for field_name in ("received_at", "date", "created_at", "createdAt", "timestamp"):
timestamp = self._parse_message_time(message.get(field_name))
if timestamp is not None:
return timestamp
return None
def _save_token_to_db(self, email: str, token: str) -> None:
"""将邮箱 token 持久化到 Setting 表key=tempmail_token:{email}"""
try:
from ..database.session import get_db
from ..database.crud import set_setting
with get_db() as db:
set_setting(db, f"tempmail_token:{email}", token, category="tempmail")
except Exception as e:
logger.warning(f"保存 Tempmail token 到数据库失败: {e}")
def _load_token_from_db(self, email: str) -> Optional[str]:
"""从 Setting 表读取邮箱 token"""
try:
from ..database.session import get_db
from ..database.crud import get_setting
with get_db() as db:
setting = get_setting(db, f"tempmail_token:{email}")
return setting.value if setting else None
except Exception as e:
logger.warning(f"从数据库读取 Tempmail token 失败: {e}")
return None
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
创建新的临时邮箱
@@ -107,6 +162,7 @@ class TempmailService(BaseEmailService):
"created_at": time.time(),
}
self._email_cache[email] = email_info
self._save_token_to_db(email, token)
logger.info(f"成功创建 Tempmail.lol 邮箱: {email}")
self.update_status(True)
@@ -134,19 +190,21 @@ class TempmailService(BaseEmailService):
email_id: 邮箱 token如果不提供从缓存中查找
timeout: 超时时间(秒)
pattern: 验证码正则表达式
otp_sent_at: OTP 发送时间戳Tempmail 服务暂不使用此参数)
otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点减去容差后的邮件
Returns:
验证码字符串,如果超时或未找到返回 None
"""
token = email_id
if not token:
# 缓存查找 token
# 先从内存缓存查找,再从数据库查找
if email in self._email_cache:
token = self._email_cache[email].get("token")
else:
logger.warning(f"未找到邮箱 {email} 的 token无法获取验证码")
return None
if not token:
token = self._load_token_from_db(email)
if not token:
logger.warning(f"未找到邮箱 {email} 的 token无法获取验证码")
return None
if not token:
logger.warning(f"邮箱 {email} 没有 token无法获取验证码")
@@ -187,11 +245,21 @@ class TempmailService(BaseEmailService):
if not isinstance(msg, dict):
continue
# 使用 date 作为唯一标识
msg_date = msg.get("date", 0)
if not msg_date or msg_date in seen_ids:
msg_timestamp = self._get_received_timestamp(msg)
if otp_sent_at is not None:
min_allowed_timestamp = otp_sent_at - OTP_SENT_AT_TOLERANCE_SECONDS
if msg_timestamp is None or msg_timestamp <= min_allowed_timestamp:
continue
message_id = str(
msg.get("id")
or msg.get("date")
or msg.get("createdAt")
or f"{msg.get('from', '')}:{msg.get('subject', '')}:{msg_timestamp}"
).strip()
if not message_id or message_id in seen_ids:
continue
seen_ids.add(msg_date)
seen_ids.add(message_id)
sender = str(msg.get("from", "")).lower()
subject = str(msg.get("subject", ""))
@@ -397,4 +465,4 @@ class TempmailService(BaseEmailService):
"email": email,
"message": "等待验证码超时"
})
return None
return None

View File

@@ -15,9 +15,11 @@ from fastapi import FastAPI, Request, Form
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from ..config.settings import get_settings
from ..database import crud
from ..database.session import get_db
from .routes import api_router
from .routes.websocket import router as ws_router
from .task_manager import task_manager
@@ -31,6 +33,11 @@ if getattr(sys, 'frozen', False):
else:
_RESOURCE_ROOT = Path(__file__).parent.parent.parent
if __name__ == "__main__":
from webui import setup_application as _setup_application
_setup_application()
# 静态文件和模板目录
STATIC_DIR = _RESOURCE_ROOT / "static"
TEMPLATES_DIR = _RESOURCE_ROOT / "templates"
@@ -108,8 +115,9 @@ def create_app() -> FastAPI:
async def login_page(request: Request, next: Optional[str] = "/"):
"""登录页面"""
return templates.TemplateResponse(
"login.html",
{"request": request, "error": "", "next": next or "/"}
request=request,
name="login.html",
context={"request": request, "error": "", "next": next or "/"}
)
@app.post("/login")
@@ -118,8 +126,9 @@ def create_app() -> FastAPI:
expected = get_settings().webui_access_password.get_secret_value()
if not secrets.compare_digest(password, expected):
return templates.TemplateResponse(
"login.html",
{"request": request, "error": "密码错误", "next": next or "/"},
request=request,
name="login.html",
context={"request": request, "error": "密码错误", "next": next or "/"},
status_code=401
)
@@ -134,38 +143,48 @@ def create_app() -> FastAPI:
response.delete_cookie("webui_auth")
return response
@app.get("/favicon.ico", include_in_schema=False)
async def favicon_ico():
"""兼容浏览器对根路径 favicon 的默认请求。"""
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
@app.get("/favicon.svg", include_in_schema=False)
async def favicon_svg():
"""提供统一的站点图标资源。"""
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""首页 - 注册页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request=request, name="index.html", context={"request": request})
@app.get("/accounts", response_class=HTMLResponse)
async def accounts_page(request: Request):
"""账号管理页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse("accounts.html", {"request": request})
return templates.TemplateResponse(request=request, name="accounts.html", context={"request": request})
@app.get("/email-services", response_class=HTMLResponse)
async def email_services_page(request: Request):
"""邮箱服务管理页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse("email_services.html", {"request": request})
return templates.TemplateResponse(request=request, name="email_services.html", context={"request": request})
@app.get("/settings", response_class=HTMLResponse)
async def settings_page(request: Request):
"""设置页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse("settings.html", {"request": request})
return templates.TemplateResponse(request=request, name="settings.html", context={"request": request})
@app.get("/payment", response_class=HTMLResponse)
async def payment_page(request: Request):
"""支付页面"""
return templates.TemplateResponse("payment.html", {"request": request})
return templates.TemplateResponse(request=request, name="payment.html", context={"request": request})
@app.on_event("startup")
async def startup_event():
@@ -183,6 +202,12 @@ def create_app() -> FastAPI:
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
stale_error = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
with get_db() as db:
stale_tasks = crud.fail_incomplete_registration_tasks(db, stale_error)
if stale_tasks:
logger.warning("已收敛 %s 个僵尸任务: %s", len(stale_tasks), ", ".join(task[:8] for task in stale_tasks))
logger.info("=" * 50)
logger.info(f"{settings.app_name} v{settings.app_version} 启动中...")
logger.info(f"调试模式: {settings.debug}")
@@ -199,3 +224,23 @@ def create_app() -> FastAPI:
# 创建全局应用实例
app = create_app()
if __name__ == "__main__":
import uvicorn
settings = get_settings()
logger.info(
"通过模块入口启动 Web UI: http://%s:%s",
settings.webui_host,
settings.webui_port,
)
uvicorn.run(
app,
host=settings.webui_host,
port=settings.webui_port,
reload=False,
log_level="info" if settings.debug else "warning",
access_log=settings.debug,
ws="websockets",
)

File diff suppressed because it is too large Load Diff

View File

@@ -7,12 +7,33 @@ import asyncio
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ...database import crud
from ...database.session import get_db
from ..task_manager import task_manager
logger = logging.getLogger(__name__)
router = APIRouter()
def _restore_task_snapshot(task_uuid: str) -> tuple[dict, list[str]]:
"""从数据库恢复任务状态和历史日志,解决服务重启后的监控空白。"""
with get_db() as db:
task = crud.get_registration_task(db, task_uuid)
if not task:
return {}, []
status = {"status": task.status}
if task.result and task.result.get("email"):
status["email"] = task.result["email"]
if task.error_message:
status["error"] = task.error_message
logs = task.logs.splitlines() if task.logs else []
task_manager.sync_task_state(task_uuid, status=status, logs=logs)
return status, logs
@router.websocket("/ws/task/{task_uuid}")
async def task_websocket(websocket: WebSocket, task_uuid: str):
"""
@@ -25,14 +46,15 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
- 客户端发送: {"type": "cancel"} - 取消任务
"""
await websocket.accept()
restored_status, restored_logs = _restore_task_snapshot(task_uuid)
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
task_manager.register_websocket(task_uuid, websocket)
# 注册连接,并取得注册时刻的历史日志快照,避免与后续实时推送串扰
history_logs = task_manager.register_websocket(task_uuid, websocket)
logger.info(f"WebSocket 连接已建立: {task_uuid}")
try:
# 发送当前状态
status = task_manager.get_status(task_uuid)
status = task_manager.get_status(task_uuid) or restored_status
if status:
await websocket.send_json({
"type": "status",
@@ -40,9 +62,8 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
**status
})
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
history_logs = task_manager.get_unsent_logs(task_uuid, websocket)
for log in history_logs:
# 发送历史日志。服务重启后 _restore_task_snapshot 会先把数据库快照回填到内存。
for log in history_logs or restored_logs:
await websocket.send_json({
"type": "log",
"task_uuid": task_uuid,
@@ -107,8 +128,8 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
"""
await websocket.accept()
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
task_manager.register_batch_websocket(batch_id, websocket)
# 注册连接,并取得注册时刻的历史日志快照,避免漏发/重复发送
history_logs = task_manager.register_batch_websocket(batch_id, websocket)
logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}")
try:
@@ -121,8 +142,6 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
**status
})
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
for log in history_logs:
await websocket.send_json({
"type": "log",

View File

@@ -144,20 +144,22 @@ class TaskManager:
except Exception as e:
logger.warning(f"WebSocket 发送状态失败: {e}")
def register_websocket(self, task_uuid: str, websocket):
"""注册 WebSocket 连接"""
def register_websocket(self, task_uuid: str, websocket) -> List[str]:
"""注册 WebSocket 连接,并返回注册时刻的历史日志快照"""
history_logs: List[str] = []
with _ws_lock:
if task_uuid not in _ws_connections:
_ws_connections[task_uuid] = []
# 避免重复注册同一个连接
if websocket not in _ws_connections[task_uuid]:
_ws_connections[task_uuid].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _get_log_lock(task_uuid):
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
history_logs = _log_queues.get(task_uuid, []).copy()
_ws_sent_index[task_uuid][id(websocket)] = len(history_logs)
_ws_connections[task_uuid].append(websocket)
logger.info(f"WebSocket 连接已注册: {task_uuid}")
else:
logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
return history_logs
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
"""获取未发送给该 WebSocket 的日志"""
@@ -190,6 +192,24 @@ class TaskManager:
with _get_log_lock(task_uuid):
return _log_queues.get(task_uuid, []).copy()
def sync_task_state(
self,
task_uuid: str,
status: Optional[dict] = None,
logs: Optional[List[str]] = None
):
"""将数据库中的任务快照回填到内存态,便于重连恢复。"""
if status:
current_status = _task_status.get(task_uuid, {}).copy()
current_status.update(status)
_task_status[task_uuid] = current_status
if logs is not None:
with _get_log_lock(task_uuid):
cached_logs = _log_queues.get(task_uuid, [])
if len(logs) >= len(cached_logs):
_log_queues[task_uuid] = list(logs)
def update_status(self, task_uuid: str, status: str, **kwargs):
"""更新任务状态"""
if task_uuid not in _task_status:
@@ -198,6 +218,15 @@ class TaskManager:
_task_status[task_uuid]["status"] = status
_task_status[task_uuid].update(kwargs)
if self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self.broadcast_status(task_uuid, status, **kwargs),
self._loop
)
except Exception as e:
logger.warning(f"广播任务状态失败: {e}")
def get_status(self, task_uuid: str) -> Optional[dict]:
"""获取任务状态"""
return _task_status.get(task_uuid)
@@ -211,18 +240,25 @@ class TaskManager:
# ============== 批量任务管理 ==============
def init_batch(self, batch_id: str, total: int):
def init_batch(self, batch_id: str, total: int, **kwargs):
"""初始化批量任务"""
_batch_status[batch_id] = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False
}
with _get_batch_lock(batch_id):
previous = _batch_status.get(batch_id, {})
status = {
"status": "running",
"total": total,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": previous.get("skipped", 0),
"cancelled": previous.get("cancelled", False),
"current_index": 0,
"finished": False,
}
status.update(previous)
status.update(kwargs)
status["total"] = total
_batch_status[batch_id] = status
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
def add_batch_log(self, batch_id: str, log_message: str):
@@ -266,11 +302,11 @@ class TaskManager:
def update_batch_status(self, batch_id: str, **kwargs):
"""更新批量任务状态"""
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
with _get_batch_lock(batch_id):
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
# 异步广播状态更新
if self._loop and self._loop.is_running():
@@ -302,7 +338,9 @@ class TaskManager:
def get_batch_status(self, batch_id: str) -> Optional[dict]:
"""获取批量任务状态"""
return _batch_status.get(batch_id)
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
return status.copy() if status is not None else None
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
@@ -316,26 +354,29 @@ class TaskManager:
def cancel_batch(self, batch_id: str):
"""取消批量任务"""
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
logger.info(f"批量任务 {batch_id} 已标记为取消")
with _get_batch_lock(batch_id):
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
logger.info(f"批量任务 {batch_id} 已标记为取消")
def register_batch_websocket(self, batch_id: str, websocket):
"""注册批量任务 WebSocket 连接"""
def register_batch_websocket(self, batch_id: str, websocket) -> List[str]:
"""注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照"""
key = f"batch_{batch_id}"
history_logs: List[str] = []
with _ws_lock:
if key not in _ws_connections:
_ws_connections[key] = []
# 避免重复注册同一个连接
if websocket not in _ws_connections[key]:
_ws_connections[key].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _get_batch_lock(batch_id):
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
history_logs = _batch_logs.get(batch_id, []).copy()
_ws_sent_index[key][id(websocket)] = len(history_logs)
_ws_connections[key].append(websocket)
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
else:
logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
return history_logs
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
"""获取未发送给该 WebSocket 的批量任务日志"""