mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-06-28 10:41:36 +08:00
Merge branch 'master' into fix/worker-mail-otp-extraction
This commit is contained in:
@@ -56,7 +56,7 @@ APP_DESCRIPTION = "自动注册 OpenAI/Codex CLI 账号的系统"
|
||||
OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
OAUTH_AUTH_URL = "https://auth.openai.com/oauth/authorize"
|
||||
OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
OAUTH_REDIRECT_URI = "http://localhost:1455/auth/callback"
|
||||
OAUTH_REDIRECT_URI = "http://localhost:15555/auth/callback"
|
||||
OAUTH_SCOPE = "openid email profile offline_access"
|
||||
|
||||
# OpenAI API 端点
|
||||
@@ -267,7 +267,7 @@ DEFAULT_SETTINGS = [
|
||||
("registration.timeout", "120", "超时时间(秒)", "registration"),
|
||||
("registration.default_password_length", "12", "默认密码长度", "registration"),
|
||||
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
|
||||
("webui.port", "8000", "Web UI 监听端口", "webui"),
|
||||
("webui.port", "15555", "Web UI 监听端口", "webui"),
|
||||
("webui.debug", "true", "调试模式", "webui"),
|
||||
]
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
|
||||
),
|
||||
"webui_port": SettingDefinition(
|
||||
db_key="webui.port",
|
||||
default_value=8000,
|
||||
default_value=15555,
|
||||
category=SettingCategory.WEBUI,
|
||||
description="Web UI 监听端口"
|
||||
),
|
||||
@@ -609,7 +609,7 @@ class Settings(BaseModel):
|
||||
|
||||
# Web UI 配置
|
||||
webui_host: str = "0.0.0.0"
|
||||
webui_port: int = 8000
|
||||
webui_port: int = 15555
|
||||
webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
|
||||
webui_access_password: SecretStr = SecretStr("admin123")
|
||||
|
||||
|
||||
457
src/core/login.py
Normal file
457
src/core/login.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
登录流程引擎
|
||||
从 register.py 中拆分的登录专属方法
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import base64
|
||||
import json as json_module
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from .register import RegistrationEngine, RegistrationResult
|
||||
from ..config.constants import OPENAI_API_ENDPOINTS
|
||||
|
||||
|
||||
class LoginEngine(RegistrationEngine):
|
||||
"""
|
||||
登录引擎
|
||||
继承 RegistrationEngine,包含登录流程专属方法:
|
||||
- _follow_login_redirects
|
||||
- _submit_login_form
|
||||
- _send_verification_code_passwordless
|
||||
- _get_workspace_id
|
||||
- _select_workspace
|
||||
- _follow_redirects
|
||||
- _handle_oauth_callback
|
||||
"""
|
||||
|
||||
def _follow_login_redirects(self, start_url: str) -> bool:
|
||||
"""跟随重定向链,寻找回调 URL"""
|
||||
try:
|
||||
current_url = start_url
|
||||
max_redirects = 6
|
||||
|
||||
for i in range(max_redirects):
|
||||
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
|
||||
|
||||
response = self.session.get(
|
||||
current_url,
|
||||
allow_redirects=False,
|
||||
timeout=15
|
||||
)
|
||||
|
||||
location = response.headers.get("Location") or ""
|
||||
|
||||
# 如果不是重定向状态码,停止
|
||||
if response.status_code == 200:
|
||||
self._log(f"非重定向状态码: {response.status_code}")
|
||||
return True
|
||||
|
||||
if not location:
|
||||
self._log("重定向响应缺少 Location 头")
|
||||
break
|
||||
|
||||
# 构建下一个 URL
|
||||
next_url = urllib.parse.urljoin(current_url, location)
|
||||
|
||||
# 检查是否包含回调参数
|
||||
if "code=" in next_url and "state=" in next_url:
|
||||
self._log(f"找到回调 URL: {next_url[:100]}...")
|
||||
|
||||
current_url = next_url
|
||||
|
||||
self._log("未能在重定向链中找到最终 URL")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"跟随重定向失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _submit_login_form(self, did: str, sen_token) -> bool:
|
||||
"""处理 免密登录"""
|
||||
try:
|
||||
self._log("处理免密登录...")
|
||||
login_body = f'{{"username":{{"value":"{self.email}","kind":"email"}}}}'
|
||||
headers = {
|
||||
"referer": "https://auth.openai.com/log-in",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
if sen_token:
|
||||
sentinel = (
|
||||
f'{{"p": "", "t": "", "c": "{sen_token}", '
|
||||
f'"id": "{did}", "flow": "authorize_continue"}}'
|
||||
)
|
||||
headers["openai-sentinel-token"] = sentinel
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["signup"],
|
||||
headers=headers,
|
||||
data=login_body,
|
||||
)
|
||||
self._log(f"提交登录表单状态: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"处理登录失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _send_verification_code_passwordless(self) -> bool:
|
||||
"""发送验证码"""
|
||||
try:
|
||||
self._otp_sent_at = time.time()
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["passwordless_send_otp"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/log-in/password",
|
||||
"accept": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
self._log(f"验证码发送状态: {response.status_code}")
|
||||
return response.status_code == 200
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"发送验证码失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _decode_workspace_id(self, auth_cookie: str) -> str:
|
||||
"""从授权 Cookie 中解析 Workspace ID"""
|
||||
segments = auth_cookie.split(".")
|
||||
if len(segments) < 1:
|
||||
raise ValueError("授权 Cookie 格式错误")
|
||||
|
||||
payload = segments[0]
|
||||
pad = "=" * ((4 - (len(payload) % 4)) % 4)
|
||||
decoded = base64.urlsafe_b64decode((payload + pad).encode("ascii"))
|
||||
auth_json = json_module.loads(decoded.decode("utf-8"))
|
||||
|
||||
workspaces = auth_json.get("workspaces") or []
|
||||
if not workspaces:
|
||||
raise ValueError("授权 Cookie 里没有 workspace 信息")
|
||||
|
||||
workspace_id = str((workspaces[0] or {}).get("id") or "").strip()
|
||||
if not workspace_id:
|
||||
raise ValueError("无法解析 workspace_id")
|
||||
|
||||
return workspace_id
|
||||
|
||||
def _get_workspace_id(self) -> Optional[str]:
|
||||
"""获取 Workspace ID"""
|
||||
backoff_seconds = (1, 2, 4)
|
||||
max_attempts = len(backoff_seconds) + 1
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
auth_cookie = self.session.cookies.get("oai-client-auth-session")
|
||||
if auth_cookie:
|
||||
workspace_id = self._decode_workspace_id(auth_cookie)
|
||||
self._log(f"Workspace ID: {workspace_id}")
|
||||
return workspace_id
|
||||
|
||||
raise ValueError("未能获取到授权 Cookie")
|
||||
except Exception as e:
|
||||
level = "warning" if attempt < max_attempts else "error"
|
||||
self._log(
|
||||
f"获取 Workspace ID 失败: {e} (第 {attempt}/{max_attempts} 次)",
|
||||
level,
|
||||
)
|
||||
|
||||
if attempt < max_attempts:
|
||||
wait_seconds = backoff_seconds[attempt - 1]
|
||||
self._log(f"等待 {wait_seconds} 秒后重试 Workspace ID", "warning")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
return None
|
||||
|
||||
def _select_workspace(self, workspace_id: str) -> Optional[str]:
|
||||
"""选择 Workspace"""
|
||||
try:
|
||||
select_body = f'{{"workspace_id":"{workspace_id}"}}'
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["select_workspace"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=select_body,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self._log(f"选择 workspace 失败: {response.status_code}", "error")
|
||||
self._log(f"响应: {response.text[:200]}", "warning")
|
||||
return None
|
||||
|
||||
continue_url = str((response.json() or {}).get("continue_url") or "").strip()
|
||||
if not continue_url:
|
||||
self._log("workspace/select 响应里缺少 continue_url", "error")
|
||||
return None
|
||||
|
||||
self._log(f"Continue URL: {continue_url[:100]}...")
|
||||
return continue_url
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"选择 Workspace 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _follow_redirects(self, start_url: str) -> Optional[str]:
|
||||
"""跟随重定向链,寻找回调 URL"""
|
||||
try:
|
||||
current_url = start_url
|
||||
max_redirects = 6
|
||||
|
||||
for i in range(max_redirects):
|
||||
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
|
||||
|
||||
response = self.session.get(
|
||||
current_url,
|
||||
allow_redirects=False,
|
||||
timeout=15
|
||||
)
|
||||
|
||||
location = response.headers.get("Location") or ""
|
||||
|
||||
# 如果不是重定向状态码,停止
|
||||
if response.status_code not in [301, 302, 303, 307, 308]:
|
||||
self._log(f"非重定向状态码: {response.status_code}")
|
||||
break
|
||||
|
||||
if not location:
|
||||
self._log("重定向响应缺少 Location 头")
|
||||
break
|
||||
|
||||
# 构建下一个 URL
|
||||
next_url = urllib.parse.urljoin(current_url, location)
|
||||
|
||||
# 检查是否包含回调参数
|
||||
if "code=" in next_url and "state=" in next_url:
|
||||
self._log(f"找到回调 URL: {next_url[:100]}...")
|
||||
return next_url
|
||||
|
||||
current_url = next_url
|
||||
|
||||
self._log("未能在重定向链中找到回调 URL", "error")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"跟随重定向失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]:
|
||||
"""处理 OAuth 回调"""
|
||||
try:
|
||||
if not self.oauth_start:
|
||||
self._log("OAuth 流程未初始化", "error")
|
||||
return None
|
||||
|
||||
self._log("处理 OAuth 回调...")
|
||||
token_info = self.oauth_manager.handle_callback(
|
||||
callback_url=callback_url,
|
||||
expected_state=self.oauth_start.state,
|
||||
code_verifier=self.oauth_start.code_verifier
|
||||
)
|
||||
|
||||
self._log("OAuth 授权成功")
|
||||
return token_info
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"处理 OAuth 回调失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def run(self) -> RegistrationResult:
|
||||
"""
|
||||
执行完整的注册流程
|
||||
|
||||
支持已注册账号自动登录:
|
||||
- 如果检测到邮箱已注册,自动切换到登录流程
|
||||
- 已注册账号跳过:设置密码、发送验证码、创建用户账户
|
||||
- 共用步骤:获取验证码、验证验证码、Workspace 和 OAuth 回调
|
||||
|
||||
Returns:
|
||||
RegistrationResult: 注册结果
|
||||
"""
|
||||
result = RegistrationResult(success=False, logs=self.logs)
|
||||
|
||||
try:
|
||||
self._log("=" * 60)
|
||||
self._log("开始注册流程")
|
||||
self._log("=" * 60)
|
||||
|
||||
self._log("1. 检查 IP 地理位置...")
|
||||
ip_ok, location = self._check_ip_location()
|
||||
if not ip_ok:
|
||||
result.error_message = f"IP 地理位置不支持: {location}"
|
||||
self._log(f"IP 检查失败: {location}", "error")
|
||||
return result
|
||||
|
||||
self._log(f"IP 位置: {location}")
|
||||
|
||||
self._log("2. 创建邮箱...")
|
||||
if not self._create_email():
|
||||
result.error_message = "创建邮箱失败"
|
||||
return result
|
||||
|
||||
result.email = self.email
|
||||
|
||||
self._log("3. 初始化会话...")
|
||||
if not self._init_session():
|
||||
result.error_message = "初始化会话失败"
|
||||
return result
|
||||
|
||||
self._log("4. 开始 OAuth 授权流程...")
|
||||
if not self._start_oauth():
|
||||
result.error_message = "开始 OAuth 流程失败"
|
||||
return result
|
||||
|
||||
self._log("5. 获取 Device ID...")
|
||||
did = self._get_device_id()
|
||||
if not did:
|
||||
result.error_message = "获取 Device ID 失败"
|
||||
return result
|
||||
|
||||
self._log("6. 检查 Sentinel 拦截...")
|
||||
sen_token = self._check_sentinel(did)
|
||||
if sen_token:
|
||||
self._log("Sentinel 检查通过")
|
||||
else:
|
||||
self._log("Sentinel 检查失败或未启用", "warning")
|
||||
|
||||
self._log("7. 提交注册表单...")
|
||||
signup_result = self._submit_signup_form(did, sen_token)
|
||||
if not signup_result.success:
|
||||
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
|
||||
return result
|
||||
|
||||
if self._is_existing_account:
|
||||
self._log(f"8. 邮箱 {self.email} 在 OpenAI 已注册,跳过注册流程", "warning")
|
||||
result.error_message = f"邮箱 {self.email} 已在 OpenAI 注册"
|
||||
return result
|
||||
|
||||
self._log("8. 注册密码...")
|
||||
password_ok, password = self._register_password()
|
||||
if not password_ok:
|
||||
result.error_message = "注册密码失败"
|
||||
return result
|
||||
|
||||
self._log("9. 发送验证码...")
|
||||
if not self._send_verification_code():
|
||||
result.error_message = "发送验证码失败"
|
||||
return result
|
||||
|
||||
self._log("10. 等待验证码...")
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
self._log("10. 验证码超时,重新发送...")
|
||||
if self._send_verification_code():
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
result.error_message = "获取验证码失败"
|
||||
return result
|
||||
|
||||
self._log("11. 验证验证码...")
|
||||
if not self._validate_verification_code(code):
|
||||
result.error_message = "验证验证码失败"
|
||||
return result
|
||||
|
||||
self._log("12. 创建用户账户...")
|
||||
if not self._create_user_account():
|
||||
result.error_message = "创建用户账户失败"
|
||||
return result
|
||||
|
||||
self._log("13-1. 结束注册,启用登录流程...")
|
||||
if not self._follow_login_redirects(self.oauth_start.auth_url):
|
||||
result.error_message = "跟随重定向链失败"
|
||||
return result
|
||||
|
||||
self._log("13-2. 提交登陆表单")
|
||||
if not self._submit_login_form(did, sen_token):
|
||||
result.error_message = "提交登陆表单失败"
|
||||
return result
|
||||
|
||||
self._log("14. 发送验证码...")
|
||||
if not self._send_verification_code_passwordless():
|
||||
result.error_message = "发送验证码失败"
|
||||
return result
|
||||
|
||||
self._log("15. 等待验证码...")
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
self._log("15. 验证码超时,重新发送...")
|
||||
if self._send_verification_code_passwordless():
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
result.error_message = "获取验证码失败"
|
||||
return result
|
||||
|
||||
self._log("16. 验证验证码...")
|
||||
if not self._validate_verification_code(code):
|
||||
result.error_message = "验证验证码失败"
|
||||
return result
|
||||
|
||||
self._log("17. 获取 Workspace ID...")
|
||||
workspace_id = self._get_workspace_id()
|
||||
if not workspace_id:
|
||||
result.error_message = "获取 Workspace ID 失败"
|
||||
return result
|
||||
|
||||
result.workspace_id = workspace_id
|
||||
|
||||
self._log("18. 选择 Workspace...")
|
||||
continue_url = self._select_workspace(workspace_id)
|
||||
if not continue_url:
|
||||
result.error_message = "选择 Workspace 失败"
|
||||
return result
|
||||
|
||||
self._log("19. 跟随重定向链...")
|
||||
callback_url = self._follow_redirects(continue_url)
|
||||
if not callback_url:
|
||||
result.error_message = "跟随重定向链失败"
|
||||
return result
|
||||
|
||||
self._log("20. 处理 OAuth 回调...")
|
||||
token_info = self._handle_oauth_callback(callback_url)
|
||||
if not token_info:
|
||||
result.error_message = "处理 OAuth 回调失败"
|
||||
return result
|
||||
|
||||
result.account_id = token_info.get("account_id", "")
|
||||
result.access_token = token_info.get("access_token", "")
|
||||
result.refresh_token = token_info.get("refresh_token", "")
|
||||
result.id_token = token_info.get("id_token", "")
|
||||
result.password = self.password or ""
|
||||
result.source = "register"
|
||||
|
||||
session_cookie = self.session.cookies.get("__Secure-next-auth.session-token")
|
||||
if session_cookie:
|
||||
self.session_token = session_cookie
|
||||
result.session_token = session_cookie
|
||||
self._log("获取到 Session Token")
|
||||
|
||||
self._log("=" * 60)
|
||||
self._log("注册成功!")
|
||||
self._log(f"邮箱: {result.email}")
|
||||
self._log(f"Account ID: {result.account_id}")
|
||||
self._log(f"Workspace ID: {result.workspace_id}")
|
||||
self._log("=" * 60)
|
||||
|
||||
result.success = True
|
||||
result.metadata = {
|
||||
"email_service": self.email_service.service_type.value,
|
||||
"proxy_used": self.proxy_url,
|
||||
"registered_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
||||
result.error_message = str(e)
|
||||
return result
|
||||
finally:
|
||||
self.close()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,14 +2,24 @@
|
||||
数据库 CRUD 操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from typing import List, Optional, Dict, Any, Union, Iterable, Set
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlalchemy import and_, or_, desc, asc, func
|
||||
|
||||
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
|
||||
|
||||
|
||||
TOKEN_FIELD_NAMES = ("access_token", "refresh_token", "id_token", "session_token")
|
||||
|
||||
|
||||
def _default_token_sync_status(token_values: Dict[str, Any]) -> str:
|
||||
"""根据当前持久化的 token 内容推导同步状态。"""
|
||||
has_token = any(bool(token_values.get(field)) for field in TOKEN_FIELD_NAMES)
|
||||
return "pending" if has_token else "not_ready"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 账户 CRUD
|
||||
# ============================================================================
|
||||
@@ -27,13 +37,21 @@ def create_account(
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
cookies: Optional[str] = None,
|
||||
proxy_used: Optional[str] = None,
|
||||
expires_at: Optional['datetime'] = None,
|
||||
extra_data: Optional[Dict[str, Any]] = None,
|
||||
status: Optional[str] = None,
|
||||
source: Optional[str] = None
|
||||
source: Optional[str] = None,
|
||||
token_sync_status: Optional[str] = None,
|
||||
) -> Account:
|
||||
"""创建新账户"""
|
||||
token_values = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"id_token": id_token,
|
||||
"session_token": session_token,
|
||||
}
|
||||
db_account = Account(
|
||||
email=email,
|
||||
password=password,
|
||||
@@ -46,12 +64,15 @@ def create_account(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
id_token=id_token,
|
||||
cookies=cookies,
|
||||
proxy_used=proxy_used,
|
||||
expires_at=expires_at,
|
||||
extra_data=extra_data or {},
|
||||
status=status or 'active',
|
||||
source=source or 'register',
|
||||
registered_at=datetime.utcnow()
|
||||
registered_at=datetime.utcnow(),
|
||||
token_sync_status=token_sync_status or _default_token_sync_status(token_values),
|
||||
token_sync_updated_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(db_account)
|
||||
db.commit()
|
||||
@@ -108,6 +129,14 @@ def update_account(
|
||||
if not db_account:
|
||||
return None
|
||||
|
||||
touches_token = any(field in kwargs for field in TOKEN_FIELD_NAMES)
|
||||
if touches_token:
|
||||
persisted_token_values = {
|
||||
field: kwargs.get(field, getattr(db_account, field))
|
||||
for field in TOKEN_FIELD_NAMES
|
||||
}
|
||||
kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
|
||||
kwargs["token_sync_updated_at"] = datetime.utcnow()
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_account, key) and value is not None:
|
||||
setattr(db_account, key, value)
|
||||
@@ -326,6 +355,34 @@ def delete_registration_task(db: Session, task_uuid: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def fail_incomplete_registration_tasks(db: Session, error_message: str) -> List[str]:
|
||||
"""将服务重启后遗留的未完成任务标记为失败"""
|
||||
tasks = db.query(RegistrationTask).filter(
|
||||
RegistrationTask.status.in_(("pending", "running"))
|
||||
).all()
|
||||
|
||||
if not tasks:
|
||||
return []
|
||||
|
||||
now = datetime.utcnow()
|
||||
cleaned_task_ids: List[str] = []
|
||||
cleanup_log = f"[系统] {error_message}"
|
||||
|
||||
for task in tasks:
|
||||
task.status = "failed"
|
||||
task.error_message = error_message
|
||||
task.completed_at = now
|
||||
if task.logs:
|
||||
if cleanup_log not in task.logs:
|
||||
task.logs = f"{task.logs}\n{cleanup_log}"
|
||||
else:
|
||||
task.logs = cleanup_log
|
||||
cleaned_task_ids.append(task.task_uuid)
|
||||
|
||||
db.commit()
|
||||
return cleaned_task_ids
|
||||
|
||||
|
||||
# 为 API 路由添加别名
|
||||
get_account = get_account_by_id
|
||||
get_registration_task = get_registration_task_by_uuid
|
||||
@@ -437,9 +494,13 @@ def get_proxies(
|
||||
return query.all()
|
||||
|
||||
|
||||
def get_enabled_proxies(db: Session) -> List[Proxy]:
|
||||
def get_enabled_proxies(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> List[Proxy]:
|
||||
"""获取所有启用的代理"""
|
||||
return db.query(Proxy).filter(Proxy.enabled == True).all()
|
||||
query = db.query(Proxy).filter(Proxy.enabled == True)
|
||||
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
|
||||
if excluded:
|
||||
query = query.filter(~Proxy.id.in_(excluded))
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_proxy(
|
||||
@@ -483,14 +544,18 @@ def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_random_proxy(db: Session) -> Optional[Proxy]:
|
||||
def get_random_proxy(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> Optional[Proxy]:
|
||||
"""随机获取一个启用的代理,优先返回 is_default=True 的代理"""
|
||||
import random
|
||||
excluded: Set[int] = {int(proxy_id) for proxy_id in (exclude_ids or [])}
|
||||
# 优先返回默认代理
|
||||
default_proxy = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True).first()
|
||||
default_query = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True)
|
||||
if excluded:
|
||||
default_query = default_query.filter(~Proxy.id.in_(excluded))
|
||||
default_proxy = default_query.first()
|
||||
if default_proxy:
|
||||
return default_proxy
|
||||
proxies = get_enabled_proxies(db)
|
||||
proxies = get_enabled_proxies(db, exclude_ids=excluded)
|
||||
if not proxies:
|
||||
return None
|
||||
return random.choice(proxies)
|
||||
@@ -713,4 +778,38 @@ def delete_tm_service(db: Session, service_id: int) -> bool:
|
||||
return False
|
||||
db.delete(svc)
|
||||
db.commit()
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
|
||||
"""更新 EmailService.config 中指定邮箱的 refresh_token"""
|
||||
service = db.query(EmailService).filter(EmailService.id == service_id).first()
|
||||
if not service or not isinstance(service.config, dict):
|
||||
return
|
||||
|
||||
normalized_email = (email or "").strip().lower()
|
||||
if not normalized_email or not isinstance(new_refresh_token, str) or not new_refresh_token:
|
||||
return
|
||||
|
||||
config = dict(service.config)
|
||||
updated = False
|
||||
|
||||
# 单账户格式
|
||||
if str(config.get("email", "")).lower() == normalized_email:
|
||||
config["refresh_token"] = new_refresh_token
|
||||
updated = True
|
||||
|
||||
# 多账户列表格式
|
||||
for acc in config.get("accounts", []):
|
||||
if not isinstance(acc, dict):
|
||||
continue
|
||||
if str(acc.get("email", "")).lower() == normalized_email:
|
||||
acc["refresh_token"] = new_refresh_token
|
||||
updated = True
|
||||
|
||||
if not updated:
|
||||
return
|
||||
|
||||
service.config = config
|
||||
flag_modified(service, "config")
|
||||
db.commit()
|
||||
|
||||
@@ -39,6 +39,8 @@ class Account(Base):
|
||||
refresh_token = Column(Text)
|
||||
id_token = Column(Text)
|
||||
session_token = Column(Text) # 会话令牌(优先刷新方式)
|
||||
token_sync_status = Column(String(20), default='not_ready') # 'not_ready', 'pending', 'synced'
|
||||
token_sync_updated_at = Column(DateTime, default=datetime.utcnow)
|
||||
client_id = Column(String(255)) # OAuth Client ID
|
||||
account_id = Column(String(255))
|
||||
workspace_id = Column(String(255))
|
||||
@@ -80,7 +82,9 @@ class Account(Base):
|
||||
'subscription_type': self.subscription_type,
|
||||
'subscription_at': self.subscription_at.isoformat() if self.subscription_at else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
'token_sync_status': self.token_sync_status,
|
||||
'token_sync_updated_at': self.token_sync_updated_at.isoformat() if self.token_sync_updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@@ -227,4 +231,4 @@ class Proxy(Base):
|
||||
if self.username and self.password:
|
||||
auth = f"{self.username}:{self.password}@"
|
||||
|
||||
return f"{scheme}://{auth}{self.host}:{self.port}"
|
||||
return f"{scheme}://{auth}{self.host}:{self.port}"
|
||||
|
||||
@@ -45,7 +45,7 @@ class DatabaseSessionManager:
|
||||
self.database_url = _build_sqlalchemy_url(database_url)
|
||||
self.engine = create_engine(
|
||||
self.database_url,
|
||||
connect_args={"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
|
||||
connect_args={"check_same_thread": False, "timeout": 30} if self.database_url.startswith("sqlite") else {},
|
||||
echo=False, # 设置为 True 可以查看所有 SQL 语句
|
||||
pool_pre_ping=True # 连接池预检查
|
||||
)
|
||||
@@ -110,6 +110,8 @@ class DatabaseSessionManager:
|
||||
("accounts", "subscription_type", "VARCHAR(20)"),
|
||||
("accounts", "subscription_at", "DATETIME"),
|
||||
("accounts", "cookies", "TEXT"),
|
||||
("accounts", "token_sync_status", "VARCHAR(20) DEFAULT 'not_ready'"),
|
||||
("accounts", "token_sync_updated_at", "DATETIME"),
|
||||
("proxies", "is_default", "BOOLEAN DEFAULT 0"),
|
||||
("cpa_services", "include_proxy_url", "BOOLEAN DEFAULT 0"),
|
||||
]
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
import abc
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
@@ -14,12 +16,109 @@ from ..config.constants import EmailServiceType, OTP_CODE_PATTERN, OTP_CODE_SEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMAIL_PROVIDER_BACKOFF_BASE_SECONDS = 30
|
||||
EMAIL_PROVIDER_BACKOFF_MAX_SECONDS = 3600
|
||||
OTP_TIMEOUT_ERROR_PREFIX = "OTP_TIMEOUT"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EmailProviderBackoffState:
|
||||
"""邮箱供应商退避状态"""
|
||||
|
||||
failures: int = 0
|
||||
delay_seconds: int = 0
|
||||
opened_until: float = 0.0
|
||||
retry_after: Optional[int] = None
|
||||
last_error: Optional[str] = None
|
||||
|
||||
def is_open(self, now: Optional[float] = None) -> bool:
|
||||
now_ts = now if now is not None else time.time()
|
||||
return self.opened_until > now_ts
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"failures": self.failures,
|
||||
"delay_seconds": self.delay_seconds,
|
||||
"opened_until": self.opened_until,
|
||||
"retry_after": self.retry_after,
|
||||
"last_error": self.last_error,
|
||||
}
|
||||
|
||||
|
||||
def calculate_adaptive_backoff_delay(
|
||||
failures: int,
|
||||
base_delay: int = EMAIL_PROVIDER_BACKOFF_BASE_SECONDS,
|
||||
max_delay: int = EMAIL_PROVIDER_BACKOFF_MAX_SECONDS,
|
||||
is_timeout: bool = False,
|
||||
) -> int:
|
||||
"""根据连续失败次数计算指数退避时长"""
|
||||
normalized_failures = max(0, failures)
|
||||
if is_timeout and normalized_failures >= 3:
|
||||
return max_delay
|
||||
exponent = max(0, normalized_failures - 1)
|
||||
return min(base_delay * (2 ** exponent), max_delay)
|
||||
|
||||
|
||||
def is_otp_timeout_error(error: object) -> bool:
|
||||
"""识别 OTP 超时类错误码。"""
|
||||
if error is None:
|
||||
return False
|
||||
if isinstance(error, OTPTimeoutEmailServiceError):
|
||||
return True
|
||||
error_code = getattr(error, "error_code", "")
|
||||
if isinstance(error_code, str) and error_code.startswith(OTP_TIMEOUT_ERROR_PREFIX):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def apply_adaptive_backoff(
|
||||
current_state: Optional[EmailProviderBackoffState],
|
||||
error: "EmailServiceError",
|
||||
now: Optional[float] = None,
|
||||
) -> EmailProviderBackoffState:
|
||||
"""在限流场景下推进邮箱供应商退避状态"""
|
||||
state = current_state or EmailProviderBackoffState()
|
||||
now_ts = now if now is not None else time.time()
|
||||
next_failures = state.failures + 1
|
||||
delay_seconds = calculate_adaptive_backoff_delay(
|
||||
next_failures,
|
||||
is_timeout=is_otp_timeout_error(error),
|
||||
)
|
||||
return EmailProviderBackoffState(
|
||||
failures=next_failures,
|
||||
delay_seconds=delay_seconds,
|
||||
opened_until=now_ts + delay_seconds,
|
||||
retry_after=getattr(error, "retry_after", None),
|
||||
last_error=str(error),
|
||||
)
|
||||
|
||||
|
||||
def reset_adaptive_backoff() -> EmailProviderBackoffState:
|
||||
"""重置邮箱供应商退避状态"""
|
||||
return EmailProviderBackoffState()
|
||||
|
||||
|
||||
class EmailServiceError(Exception):
|
||||
"""邮箱服务异常"""
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitedEmailServiceError(EmailServiceError):
|
||||
"""邮箱服务被限流"""
|
||||
|
||||
def __init__(self, message: str, retry_after: Optional[int] = None):
|
||||
super().__init__(message)
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class OTPTimeoutEmailServiceError(EmailServiceError):
|
||||
"""OTP 验证码等待超时。"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = OTP_TIMEOUT_ERROR_PREFIX):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class EmailServiceStatus(Enum):
|
||||
"""邮箱服务状态"""
|
||||
HEALTHY = "healthy"
|
||||
@@ -46,6 +145,7 @@ class BaseEmailService(abc.ABC):
|
||||
self.name = name or f"{service_type.value}_service"
|
||||
self._status = EmailServiceStatus.HEALTHY
|
||||
self._last_error = None
|
||||
self._provider_backoff = reset_adaptive_backoff()
|
||||
|
||||
_EMAIL_ADDRESS_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
|
||||
|
||||
@@ -59,6 +159,15 @@ class BaseEmailService(abc.ABC):
|
||||
"""获取最后一次错误信息"""
|
||||
return self._last_error
|
||||
|
||||
@property
|
||||
def provider_backoff_state(self) -> EmailProviderBackoffState:
|
||||
"""获取当前邮箱供应商退避状态"""
|
||||
return self._provider_backoff
|
||||
|
||||
def apply_provider_backoff_state(self, state: Optional[EmailProviderBackoffState]) -> None:
|
||||
"""注入外部持久化的邮箱供应商退避状态"""
|
||||
self._provider_backoff = state or reset_adaptive_backoff()
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -95,7 +204,7 @@ class BaseEmailService(abc.ABC):
|
||||
email_id: 邮箱服务中的 ID(如果需要)
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式
|
||||
otp_sent_at: OTP 发送时间戳,用于过滤旧邮件
|
||||
otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点的邮件
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果超时或未找到返回 None
|
||||
@@ -309,8 +418,16 @@ class BaseEmailService(abc.ABC):
|
||||
if success:
|
||||
self._status = EmailServiceStatus.HEALTHY
|
||||
self._last_error = None
|
||||
self._provider_backoff = reset_adaptive_backoff()
|
||||
else:
|
||||
self._status = EmailServiceStatus.DEGRADED
|
||||
if isinstance(error, RateLimitedEmailServiceError) or is_otp_timeout_error(error):
|
||||
self._status = EmailServiceStatus.UNAVAILABLE
|
||||
self._provider_backoff = apply_adaptive_backoff(
|
||||
self._provider_backoff,
|
||||
error,
|
||||
)
|
||||
else:
|
||||
self._status = EmailServiceStatus.DEGRADED
|
||||
if error:
|
||||
self._last_error = str(error)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from datetime import datetime, timezone
|
||||
from html import unescape
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
|
||||
@@ -102,7 +102,19 @@ class DuckMailService(BaseEmailService):
|
||||
error_message = f"{error_message} - {error_payload}"
|
||||
except Exception:
|
||||
error_message = f"{error_message} - {response.text[:200]}"
|
||||
raise EmailServiceError(error_message)
|
||||
retry_after = None
|
||||
if response.status_code == 429:
|
||||
retry_after_header = response.headers.get("Retry-After")
|
||||
if retry_after_header:
|
||||
try:
|
||||
retry_after = max(1, int(retry_after_header))
|
||||
except ValueError:
|
||||
retry_after = None
|
||||
error = RateLimitedEmailServiceError(error_message, retry_after=retry_after)
|
||||
else:
|
||||
error = EmailServiceError(error_message)
|
||||
self.update_status(False, error)
|
||||
raise error
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
|
||||
@@ -10,7 +10,7 @@ import random
|
||||
import string
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
@@ -96,8 +96,19 @@ class FreemailService(BaseEmailService):
|
||||
error_msg = f"{error_msg} - {error_data}"
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
self.update_status(False, EmailServiceError(error_msg))
|
||||
raise EmailServiceError(error_msg)
|
||||
retry_after = None
|
||||
if response.status_code == 429:
|
||||
retry_after_header = response.headers.get("Retry-After")
|
||||
if retry_after_header:
|
||||
try:
|
||||
retry_after = max(1, int(retry_after_header))
|
||||
except ValueError:
|
||||
retry_after = None
|
||||
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
|
||||
else:
|
||||
error = EmailServiceError(error_msg)
|
||||
self.update_status(False, error)
|
||||
raise error
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
@@ -148,8 +148,20 @@ class MeoMailEmailService(BaseEmailService):
|
||||
except:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
|
||||
self.update_status(False, EmailServiceError(error_msg))
|
||||
raise EmailServiceError(error_msg)
|
||||
retry_after = None
|
||||
if response.status_code == 429:
|
||||
retry_after_header = response.headers.get("Retry-After")
|
||||
if retry_after_header:
|
||||
try:
|
||||
retry_after = max(1, int(retry_after_header))
|
||||
except ValueError:
|
||||
retry_after = None
|
||||
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
|
||||
else:
|
||||
error = EmailServiceError(error_msg)
|
||||
|
||||
self.update_status(False, error)
|
||||
raise error
|
||||
|
||||
# 解析响应
|
||||
try:
|
||||
@@ -553,4 +565,4 @@ class MeoMailEmailService(BaseEmailService):
|
||||
"system_config": config,
|
||||
"cached_emails_count": len(self._emails_cache),
|
||||
"status": self.status.value,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ from email.policy import default as email_policy
|
||||
from html import unescape
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
@@ -200,8 +200,19 @@ class TempMailService(BaseEmailService):
|
||||
error_msg = f"{error_msg} - {error_data}"
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
self.update_status(False, EmailServiceError(error_msg))
|
||||
raise EmailServiceError(error_msg)
|
||||
retry_after = None
|
||||
if response.status_code == 429:
|
||||
retry_after_header = response.headers.get("Retry-After")
|
||||
if retry_after_header:
|
||||
try:
|
||||
retry_after = max(1, int(retry_after_header))
|
||||
except ValueError:
|
||||
retry_after = None
|
||||
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
|
||||
else:
|
||||
error = EmailServiceError(error_msg)
|
||||
self.update_status(False, error)
|
||||
raise error
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
|
||||
@@ -6,9 +6,7 @@ import re
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
import json
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
@@ -17,6 +15,8 @@ from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OTP_SENT_AT_TOLERANCE_SECONDS = 2
|
||||
|
||||
|
||||
class TempmailService(BaseEmailService):
|
||||
"""
|
||||
@@ -58,10 +58,65 @@ class TempmailService(BaseEmailService):
|
||||
config=http_config
|
||||
)
|
||||
|
||||
# 状态变量
|
||||
# 状态变量(内存缓存,重启后从 DB 按需查询)
|
||||
self._email_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self._last_check_time: float = 0
|
||||
|
||||
def _parse_message_time(self, value: Any) -> Optional[float]:
|
||||
"""解析 Tempmail 邮件时间,兼容 Unix 时间戳与 ISO 8601。"""
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
timestamp = float(value)
|
||||
else:
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
timestamp = float(text)
|
||||
except ValueError:
|
||||
try:
|
||||
normalized = text.replace("Z", "+00:00")
|
||||
timestamp = datetime.fromisoformat(normalized).astimezone(timezone.utc).timestamp()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
while timestamp > 1e11:
|
||||
timestamp /= 1000.0
|
||||
return timestamp if timestamp > 0 else None
|
||||
|
||||
def _get_received_timestamp(self, message: Dict[str, Any]) -> Optional[float]:
|
||||
"""返回 Tempmail 邮件的接收时间戳。"""
|
||||
for field_name in ("received_at", "date", "created_at", "createdAt", "timestamp"):
|
||||
timestamp = self._parse_message_time(message.get(field_name))
|
||||
if timestamp is not None:
|
||||
return timestamp
|
||||
return None
|
||||
|
||||
def _save_token_to_db(self, email: str, token: str) -> None:
|
||||
"""将邮箱 token 持久化到 Setting 表,key=tempmail_token:{email}"""
|
||||
try:
|
||||
from ..database.session import get_db
|
||||
from ..database.crud import set_setting
|
||||
with get_db() as db:
|
||||
set_setting(db, f"tempmail_token:{email}", token, category="tempmail")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存 Tempmail token 到数据库失败: {e}")
|
||||
|
||||
def _load_token_from_db(self, email: str) -> Optional[str]:
|
||||
"""从 Setting 表读取邮箱 token"""
|
||||
try:
|
||||
from ..database.session import get_db
|
||||
from ..database.crud import get_setting
|
||||
with get_db() as db:
|
||||
setting = get_setting(db, f"tempmail_token:{email}")
|
||||
return setting.value if setting else None
|
||||
except Exception as e:
|
||||
logger.warning(f"从数据库读取 Tempmail token 失败: {e}")
|
||||
return None
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
创建新的临时邮箱
|
||||
@@ -107,6 +162,7 @@ class TempmailService(BaseEmailService):
|
||||
"created_at": time.time(),
|
||||
}
|
||||
self._email_cache[email] = email_info
|
||||
self._save_token_to_db(email, token)
|
||||
|
||||
logger.info(f"成功创建 Tempmail.lol 邮箱: {email}")
|
||||
self.update_status(True)
|
||||
@@ -134,19 +190,21 @@ class TempmailService(BaseEmailService):
|
||||
email_id: 邮箱 token(如果不提供,从缓存中查找)
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式
|
||||
otp_sent_at: OTP 发送时间戳(Tempmail 服务暂不使用此参数)
|
||||
otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点减去容差后的邮件
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果超时或未找到返回 None
|
||||
"""
|
||||
token = email_id
|
||||
if not token:
|
||||
# 从缓存中查找 token
|
||||
# 先从内存缓存查找,再从数据库查找
|
||||
if email in self._email_cache:
|
||||
token = self._email_cache[email].get("token")
|
||||
else:
|
||||
logger.warning(f"未找到邮箱 {email} 的 token,无法获取验证码")
|
||||
return None
|
||||
if not token:
|
||||
token = self._load_token_from_db(email)
|
||||
if not token:
|
||||
logger.warning(f"未找到邮箱 {email} 的 token,无法获取验证码")
|
||||
return None
|
||||
|
||||
if not token:
|
||||
logger.warning(f"邮箱 {email} 没有 token,无法获取验证码")
|
||||
@@ -187,11 +245,21 @@ class TempmailService(BaseEmailService):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
# 使用 date 作为唯一标识
|
||||
msg_date = msg.get("date", 0)
|
||||
if not msg_date or msg_date in seen_ids:
|
||||
msg_timestamp = self._get_received_timestamp(msg)
|
||||
if otp_sent_at is not None:
|
||||
min_allowed_timestamp = otp_sent_at - OTP_SENT_AT_TOLERANCE_SECONDS
|
||||
if msg_timestamp is None or msg_timestamp <= min_allowed_timestamp:
|
||||
continue
|
||||
|
||||
message_id = str(
|
||||
msg.get("id")
|
||||
or msg.get("date")
|
||||
or msg.get("createdAt")
|
||||
or f"{msg.get('from', '')}:{msg.get('subject', '')}:{msg_timestamp}"
|
||||
).strip()
|
||||
if not message_id or message_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(msg_date)
|
||||
seen_ids.add(message_id)
|
||||
|
||||
sender = str(msg.get("from", "")).lower()
|
||||
subject = str(msg.get("subject", ""))
|
||||
@@ -397,4 +465,4 @@ class TempmailService(BaseEmailService):
|
||||
"email": email,
|
||||
"message": "等待验证码超时"
|
||||
})
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -15,9 +15,11 @@ from fastapi import FastAPI, Request, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
|
||||
|
||||
from ..config.settings import get_settings
|
||||
from ..database import crud
|
||||
from ..database.session import get_db
|
||||
from .routes import api_router
|
||||
from .routes.websocket import router as ws_router
|
||||
from .task_manager import task_manager
|
||||
@@ -31,6 +33,11 @@ if getattr(sys, 'frozen', False):
|
||||
else:
|
||||
_RESOURCE_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
if __name__ == "__main__":
|
||||
from webui import setup_application as _setup_application
|
||||
|
||||
_setup_application()
|
||||
|
||||
# 静态文件和模板目录
|
||||
STATIC_DIR = _RESOURCE_ROOT / "static"
|
||||
TEMPLATES_DIR = _RESOURCE_ROOT / "templates"
|
||||
@@ -108,8 +115,9 @@ def create_app() -> FastAPI:
|
||||
async def login_page(request: Request, next: Optional[str] = "/"):
|
||||
"""登录页面"""
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "error": "", "next": next or "/"}
|
||||
request=request,
|
||||
name="login.html",
|
||||
context={"request": request, "error": "", "next": next or "/"}
|
||||
)
|
||||
|
||||
@app.post("/login")
|
||||
@@ -118,8 +126,9 @@ def create_app() -> FastAPI:
|
||||
expected = get_settings().webui_access_password.get_secret_value()
|
||||
if not secrets.compare_digest(password, expected):
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "error": "密码错误", "next": next or "/"},
|
||||
request=request,
|
||||
name="login.html",
|
||||
context={"request": request, "error": "密码错误", "next": next or "/"},
|
||||
status_code=401
|
||||
)
|
||||
|
||||
@@ -134,38 +143,48 @@ def create_app() -> FastAPI:
|
||||
response.delete_cookie("webui_auth")
|
||||
return response
|
||||
|
||||
@app.get("/favicon.ico", include_in_schema=False)
|
||||
async def favicon_ico():
|
||||
"""兼容浏览器对根路径 favicon 的默认请求。"""
|
||||
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
|
||||
|
||||
@app.get("/favicon.svg", include_in_schema=False)
|
||||
async def favicon_svg():
|
||||
"""提供统一的站点图标资源。"""
|
||||
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
"""首页 - 注册页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
return templates.TemplateResponse(request=request, name="index.html", context={"request": request})
|
||||
|
||||
@app.get("/accounts", response_class=HTMLResponse)
|
||||
async def accounts_page(request: Request):
|
||||
"""账号管理页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("accounts.html", {"request": request})
|
||||
return templates.TemplateResponse(request=request, name="accounts.html", context={"request": request})
|
||||
|
||||
@app.get("/email-services", response_class=HTMLResponse)
|
||||
async def email_services_page(request: Request):
|
||||
"""邮箱服务管理页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("email_services.html", {"request": request})
|
||||
return templates.TemplateResponse(request=request, name="email_services.html", context={"request": request})
|
||||
|
||||
@app.get("/settings", response_class=HTMLResponse)
|
||||
async def settings_page(request: Request):
|
||||
"""设置页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("settings.html", {"request": request})
|
||||
return templates.TemplateResponse(request=request, name="settings.html", context={"request": request})
|
||||
|
||||
@app.get("/payment", response_class=HTMLResponse)
|
||||
async def payment_page(request: Request):
|
||||
"""支付页面"""
|
||||
return templates.TemplateResponse("payment.html", {"request": request})
|
||||
return templates.TemplateResponse(request=request, name="payment.html", context={"request": request})
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
@@ -183,6 +202,12 @@ def create_app() -> FastAPI:
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager.set_loop(loop)
|
||||
|
||||
stale_error = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
|
||||
with get_db() as db:
|
||||
stale_tasks = crud.fail_incomplete_registration_tasks(db, stale_error)
|
||||
if stale_tasks:
|
||||
logger.warning("已收敛 %s 个僵尸任务: %s", len(stale_tasks), ", ".join(task[:8] for task in stale_tasks))
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info(f"{settings.app_name} v{settings.app_version} 启动中...")
|
||||
logger.info(f"调试模式: {settings.debug}")
|
||||
@@ -199,3 +224,23 @@ def create_app() -> FastAPI:
|
||||
|
||||
# 创建全局应用实例
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
settings = get_settings()
|
||||
logger.info(
|
||||
"通过模块入口启动 Web UI: http://%s:%s",
|
||||
settings.webui_host,
|
||||
settings.webui_port,
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=settings.webui_host,
|
||||
port=settings.webui_port,
|
||||
reload=False,
|
||||
log_level="info" if settings.debug else "warning",
|
||||
access_log=settings.debug,
|
||||
ws="websockets",
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,12 +7,33 @@ import asyncio
|
||||
import logging
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from ...database import crud
|
||||
from ...database.session import get_db
|
||||
from ..task_manager import task_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _restore_task_snapshot(task_uuid: str) -> tuple[dict, list[str]]:
|
||||
"""从数据库恢复任务状态和历史日志,解决服务重启后的监控空白。"""
|
||||
with get_db() as db:
|
||||
task = crud.get_registration_task(db, task_uuid)
|
||||
|
||||
if not task:
|
||||
return {}, []
|
||||
|
||||
status = {"status": task.status}
|
||||
if task.result and task.result.get("email"):
|
||||
status["email"] = task.result["email"]
|
||||
if task.error_message:
|
||||
status["error"] = task.error_message
|
||||
|
||||
logs = task.logs.splitlines() if task.logs else []
|
||||
task_manager.sync_task_state(task_uuid, status=status, logs=logs)
|
||||
return status, logs
|
||||
|
||||
|
||||
@router.websocket("/ws/task/{task_uuid}")
|
||||
async def task_websocket(websocket: WebSocket, task_uuid: str):
|
||||
"""
|
||||
@@ -25,14 +46,15 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
|
||||
- 客户端发送: {"type": "cancel"} - 取消任务
|
||||
"""
|
||||
await websocket.accept()
|
||||
restored_status, restored_logs = _restore_task_snapshot(task_uuid)
|
||||
|
||||
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
|
||||
task_manager.register_websocket(task_uuid, websocket)
|
||||
# 注册连接,并取得注册时刻的历史日志快照,避免与后续实时推送串扰
|
||||
history_logs = task_manager.register_websocket(task_uuid, websocket)
|
||||
logger.info(f"WebSocket 连接已建立: {task_uuid}")
|
||||
|
||||
try:
|
||||
# 发送当前状态
|
||||
status = task_manager.get_status(task_uuid)
|
||||
status = task_manager.get_status(task_uuid) or restored_status
|
||||
if status:
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
@@ -40,9 +62,8 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
|
||||
**status
|
||||
})
|
||||
|
||||
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
|
||||
history_logs = task_manager.get_unsent_logs(task_uuid, websocket)
|
||||
for log in history_logs:
|
||||
# 发送历史日志。服务重启后 _restore_task_snapshot 会先把数据库快照回填到内存。
|
||||
for log in history_logs or restored_logs:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"task_uuid": task_uuid,
|
||||
@@ -107,8 +128,8 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
|
||||
task_manager.register_batch_websocket(batch_id, websocket)
|
||||
# 注册连接,并取得注册时刻的历史日志快照,避免漏发/重复发送
|
||||
history_logs = task_manager.register_batch_websocket(batch_id, websocket)
|
||||
logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}")
|
||||
|
||||
try:
|
||||
@@ -121,8 +142,6 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
|
||||
**status
|
||||
})
|
||||
|
||||
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
|
||||
history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
|
||||
for log in history_logs:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
|
||||
@@ -144,20 +144,22 @@ class TaskManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送状态失败: {e}")
|
||||
|
||||
def register_websocket(self, task_uuid: str, websocket):
|
||||
"""注册 WebSocket 连接"""
|
||||
def register_websocket(self, task_uuid: str, websocket) -> List[str]:
|
||||
"""注册 WebSocket 连接,并返回注册时刻的历史日志快照"""
|
||||
history_logs: List[str] = []
|
||||
with _ws_lock:
|
||||
if task_uuid not in _ws_connections:
|
||||
_ws_connections[task_uuid] = []
|
||||
# 避免重复注册同一个连接
|
||||
if websocket not in _ws_connections[task_uuid]:
|
||||
_ws_connections[task_uuid].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _get_log_lock(task_uuid):
|
||||
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
|
||||
history_logs = _log_queues.get(task_uuid, []).copy()
|
||||
_ws_sent_index[task_uuid][id(websocket)] = len(history_logs)
|
||||
_ws_connections[task_uuid].append(websocket)
|
||||
logger.info(f"WebSocket 连接已注册: {task_uuid}")
|
||||
else:
|
||||
logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
|
||||
return history_logs
|
||||
|
||||
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
|
||||
"""获取未发送给该 WebSocket 的日志"""
|
||||
@@ -190,6 +192,24 @@ class TaskManager:
|
||||
with _get_log_lock(task_uuid):
|
||||
return _log_queues.get(task_uuid, []).copy()
|
||||
|
||||
def sync_task_state(
|
||||
self,
|
||||
task_uuid: str,
|
||||
status: Optional[dict] = None,
|
||||
logs: Optional[List[str]] = None
|
||||
):
|
||||
"""将数据库中的任务快照回填到内存态,便于重连恢复。"""
|
||||
if status:
|
||||
current_status = _task_status.get(task_uuid, {}).copy()
|
||||
current_status.update(status)
|
||||
_task_status[task_uuid] = current_status
|
||||
|
||||
if logs is not None:
|
||||
with _get_log_lock(task_uuid):
|
||||
cached_logs = _log_queues.get(task_uuid, [])
|
||||
if len(logs) >= len(cached_logs):
|
||||
_log_queues[task_uuid] = list(logs)
|
||||
|
||||
def update_status(self, task_uuid: str, status: str, **kwargs):
|
||||
"""更新任务状态"""
|
||||
if task_uuid not in _task_status:
|
||||
@@ -198,6 +218,15 @@ class TaskManager:
|
||||
_task_status[task_uuid]["status"] = status
|
||||
_task_status[task_uuid].update(kwargs)
|
||||
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.broadcast_status(task_uuid, status, **kwargs),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播任务状态失败: {e}")
|
||||
|
||||
def get_status(self, task_uuid: str) -> Optional[dict]:
|
||||
"""获取任务状态"""
|
||||
return _task_status.get(task_uuid)
|
||||
@@ -211,18 +240,25 @@ class TaskManager:
|
||||
|
||||
# ============== 批量任务管理 ==============
|
||||
|
||||
def init_batch(self, batch_id: str, total: int):
|
||||
def init_batch(self, batch_id: str, total: int, **kwargs):
|
||||
"""初始化批量任务"""
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False
|
||||
}
|
||||
with _get_batch_lock(batch_id):
|
||||
previous = _batch_status.get(batch_id, {})
|
||||
status = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": previous.get("skipped", 0),
|
||||
"cancelled": previous.get("cancelled", False),
|
||||
"current_index": 0,
|
||||
"finished": False,
|
||||
}
|
||||
status.update(previous)
|
||||
status.update(kwargs)
|
||||
status["total"] = total
|
||||
_batch_status[batch_id] = status
|
||||
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
|
||||
|
||||
def add_batch_log(self, batch_id: str, log_message: str):
|
||||
@@ -266,11 +302,11 @@ class TaskManager:
|
||||
|
||||
def update_batch_status(self, batch_id: str, **kwargs):
|
||||
"""更新批量任务状态"""
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
|
||||
# 异步广播状态更新
|
||||
if self._loop and self._loop.is_running():
|
||||
@@ -302,7 +338,9 @@ class TaskManager:
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> Optional[dict]:
|
||||
"""获取批量任务状态"""
|
||||
return _batch_status.get(batch_id)
|
||||
with _get_batch_lock(batch_id):
|
||||
status = _batch_status.get(batch_id)
|
||||
return status.copy() if status is not None else None
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
@@ -316,26 +354,29 @@ class TaskManager:
|
||||
|
||||
def cancel_batch(self, batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
with _get_batch_lock(batch_id):
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
|
||||
def register_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注册批量任务 WebSocket 连接"""
|
||||
def register_batch_websocket(self, batch_id: str, websocket) -> List[str]:
|
||||
"""注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照"""
|
||||
key = f"batch_{batch_id}"
|
||||
history_logs: List[str] = []
|
||||
with _ws_lock:
|
||||
if key not in _ws_connections:
|
||||
_ws_connections[key] = []
|
||||
# 避免重复注册同一个连接
|
||||
if websocket not in _ws_connections[key]:
|
||||
_ws_connections[key].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _get_batch_lock(batch_id):
|
||||
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
|
||||
history_logs = _batch_logs.get(batch_id, []).copy()
|
||||
_ws_sent_index[key][id(websocket)] = len(history_logs)
|
||||
_ws_connections[key].append(websocket)
|
||||
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
|
||||
else:
|
||||
logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
|
||||
return history_logs
|
||||
|
||||
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
|
||||
"""获取未发送给该 WebSocket 的批量任务日志"""
|
||||
|
||||
Reference in New Issue
Block a user