fix: restore protocol baseline, resolve 403/400 registration errors, and fully remove deprecated playwright dependency

This commit is contained in:
Mison
2026-03-23 18:49:51 +08:00
parent a7a6391f0d
commit de2c4aa7ab
60 changed files with 4126 additions and 2834 deletions

View File

@@ -21,13 +21,13 @@ jobs:
include: include:
- os: windows-latest - os: windows-latest
artifact_name: codex-register.exe artifact_name: codex-register.exe
asset_name: codex-register-v2-windows-x64.exe asset_name: codex-register-windows-x64.exe
- os: ubuntu-latest - os: ubuntu-latest
artifact_name: codex-register artifact_name: codex-register
asset_name: codex-register-v2-linux-x64 asset_name: codex-register-linux-x64
- os: macos-latest - os: macos-latest
artifact_name: codex-register artifact_name: codex-register
asset_name: codex-register-v2-macos-arm64 asset_name: codex-register-macos-arm64
steps: steps:
- name: 检出代码 - name: 检出代码
@@ -74,19 +74,19 @@ jobs:
- name: 整理文件并打包 zip - name: 整理文件并打包 zip
run: | run: |
mkdir -p release mkdir -p release
# download-artifact@v4 将每个 artifact 放在 dist/<asset_name>/ 子目录下 # 为每个平台二进制文件打包成 zip
# 遍历子目录,用目录名作为平台标识(即 matrix.asset_name find dist/ -type f | while read f; do
for artifact_dir in dist/*/; do name=$(basename "$f")
platform=$(basename "$artifact_dir") # 根据文件名确定平台标识
# 找到该目录下的二进制文件(只有一个) case "$name" in
binary=$(find "$artifact_dir" -maxdepth 1 -type f | head -n1) *windows*) platform=$(echo "$name" | sed 's/\.[^.]*$//') ;;
if [ -z "$binary" ]; then *linux*) platform="$name" ;;
echo "警告:$artifact_dir 下没有找到文件,跳过" *macos*) platform="$name" ;;
continue *) platform="$name" ;;
fi esac
tmpdir="tmp_${platform}" tmpdir="tmp_${platform}"
mkdir -p "$tmpdir" mkdir -p "$tmpdir"
cp "$binary" "$tmpdir/" cp "$f" "$tmpdir/"
cp README.md "$tmpdir/README.md" cp README.md "$tmpdir/README.md"
cp .env.example "$tmpdir/.env.example" cp .env.example "$tmpdir/.env.example"
[ -f LICENSE ] && cp LICENSE "$tmpdir/LICENSE" || true [ -f LICENSE ] && cp LICENSE "$tmpdir/LICENSE" || true
@@ -103,14 +103,14 @@ jobs:
files: release/* files: release/*
generate_release_notes: true generate_release_notes: true
body: | body: |
## OpenAI 账号管理系统 v2 ## OpenAI 自动注册系统 v2
### 下载说明 ### 下载说明
| 平台 | 文件 | | 平台 | 文件 |
|------|------| |------|------|
| Windows x64 | `codex-register-v2-windows-x64.exe` | | Windows x64 | `codex-register-windows-x64.exe` |
| Linux x64 | `codex-register-v2-linux-x64` | | Linux x64 | `codex-register-linux-x64` |
| macOS ARM64 | `codex-register-v2-macos-arm64` | | macOS ARM64 | `codex-register-macos-arm64` |
### 使用方法 ### 使用方法
```bash ```bash

View File

@@ -64,6 +64,5 @@ jobs:
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
platforms: linux/amd64,linux/arm64
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max

View File

@@ -2,6 +2,8 @@
管理 OpenAI 账号的 Web UI 系统,支持多种邮箱服务、并发批量注册、代理管理和账号管理。 管理 OpenAI 账号的 Web UI 系统,支持多种邮箱服务、并发批量注册、代理管理和账号管理。
# 官方拉闸了,改变了授权流程,各位自行研究吧
> ⚠️ **免责声明**:本工具仅供学习和研究使用,使用本工具产生的一切后果由使用者自行承担。请遵守相关服务的使用条款,不要用于任何违法或不当用途。 如有侵权,请及时联系,会及时删除。 > ⚠️ **免责声明**:本工具仅供学习和研究使用,使用本工具产生的一切后果由使用者自行承担。请遵守相关服务的使用条款,不要用于任何违法或不当用途。 如有侵权,请及时联系,会及时删除。
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
@@ -144,9 +146,9 @@ docker-compose up -d
```bash ```bash
docker run -d \ docker run -d \
-p 15555:15555 \ -p 1455:1455 \
-e WEBUI_HOST=0.0.0.0 \ -e WEBUI_HOST=0.0.0.0 \
-e WEBUI_PORT=15555 \ -e WEBUI_PORT=1455 \
-e WEBUI_ACCESS_PASSWORD=your_secure_password \ -e WEBUI_ACCESS_PASSWORD=your_secure_password \
-v $(pwd)/data:/app/data \ -v $(pwd)/data:/app/data \
--name codex-register \ --name codex-register \
@@ -155,7 +157,7 @@ docker run -d \
环境变量说明: 环境变量说明:
- `WEBUI_HOST`: 监听的主机地址 (默认 `0.0.0.0`) - `WEBUI_HOST`: 监听的主机地址 (默认 `0.0.0.0`)
- `WEBUI_PORT`: 监听的端口 (默认 `15555`) - `WEBUI_PORT`: 监听的端口 (默认 `1455`)
- `WEBUI_ACCESS_PASSWORD`: 设置 Web UI 的访问密码 - `WEBUI_ACCESS_PASSWORD`: 设置 Web UI 的访问密码
- `DEBUG`: 设为 `1``true` 开启调试模式 - `DEBUG`: 设为 `1``true` 开启调试模式
- `LOG_LEVEL`: 日志级别,如 `info`, `debug` - `LOG_LEVEL`: 日志级别,如 `info`, `debug`
@@ -371,8 +373,7 @@ docker-compose build --no-cache
- CPA / Sub2API / Team Manager 上传始终直连,不走代理;其中 CPA 可选把账号记录的代理写入 auth file 的 `proxy_url` - CPA / Sub2API / Team Manager 上传始终直连,不走代理;其中 CPA 可选把账号记录的代理写入 auth file 的 `proxy_url`
- 注册时自动随机生成用户名和生日(年龄范围 18-45 岁) - 注册时自动随机生成用户名和生日(年龄范围 18-45 岁)
- 支付链接生成使用账号 access_token 鉴权,走全局代理配置 - 支付链接生成使用账号 access_token 鉴权,走全局代理配置
- 无痕浏览器优先使用 playwright注入 cookie 直达支付页);未安装时降级为系统 Chrome/Edge 无痕模式 - 无痕打开支付页默认调用系统 Chrome/Edge 的隐私模式
- 安装完整支付功能:`pip install ".[payment]" && playwright install chromium`(可选)
- 订阅状态自动检测调用 `chatgpt.com/backend-api/me`,走全局代理 - 订阅状态自动检测调用 `chatgpt.com/backend-api/me`,走全局代理
- 批量注册并发数上限为 50线程池大小已相应调整 - 批量注册并发数上限为 50线程池大小已相应调整

View File

@@ -1,5 +1,3 @@
version: '3.8'
services: services:
webui: webui:
build: . build: .

View File

@@ -23,9 +23,6 @@ dev = [
"pytest>=7.0.0", "pytest>=7.0.0",
"httpx>=0.24.0", "httpx>=0.24.0",
] ]
payment = [
"playwright>=1.40.0",
]
[project.scripts] [project.scripts]
codex-webui = "webui:main" codex-webui = "webui:main"

View File

@@ -11,5 +11,3 @@ python-multipart>=0.0.6
sqlalchemy>=2.0.0 sqlalchemy>=2.0.0
aiosqlite>=0.19.0 aiosqlite>=0.19.0
psycopg[binary]>=3.1.18 psycopg[binary]>=3.1.18
# 可选:无痕打开支付页需要 playwrightpip install playwright && playwright install chromium
# playwright>=1.40.0

View File

@@ -56,7 +56,7 @@ APP_DESCRIPTION = "自动注册 OpenAI/Codex CLI 账号的系统"
OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
OAUTH_AUTH_URL = "https://auth.openai.com/oauth/authorize" OAUTH_AUTH_URL = "https://auth.openai.com/oauth/authorize"
OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token" OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
OAUTH_REDIRECT_URI = "http://localhost:15555/auth/callback" OAUTH_REDIRECT_URI = "http://localhost:1455/auth/callback"
OAUTH_SCOPE = "openid email profile offline_access" OAUTH_SCOPE = "openid email profile offline_access"
# OpenAI API 端点 # OpenAI API 端点
@@ -65,20 +65,15 @@ OPENAI_API_ENDPOINTS = {
"signup": "https://auth.openai.com/api/accounts/authorize/continue", "signup": "https://auth.openai.com/api/accounts/authorize/continue",
"register": "https://auth.openai.com/api/accounts/user/register", "register": "https://auth.openai.com/api/accounts/user/register",
"send_otp": "https://auth.openai.com/api/accounts/email-otp/send", "send_otp": "https://auth.openai.com/api/accounts/email-otp/send",
"passwordless_send_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
"validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate", "validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate",
"create_account": "https://auth.openai.com/api/accounts/create_account", "create_account": "https://auth.openai.com/api/accounts/create_account",
"add_phone": "https://auth.openai.com/add-phone",
"select_workspace": "https://auth.openai.com/api/accounts/workspace/select", "select_workspace": "https://auth.openai.com/api/accounts/workspace/select",
"send_passwordless_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
"password_verify": "https://auth.openai.com/api/accounts/password/verify",
} }
# OpenAI 页面类型(用于判断账号状态) # OpenAI 页面类型(用于判断账号状态)
OPENAI_PAGE_TYPES = { OPENAI_PAGE_TYPES = {
"LOGIN_PASSWORD": "login_password",
"EMAIL_OTP_VERIFICATION": "email_otp_verification", # 已注册账号,需要 OTP 验证 "EMAIL_OTP_VERIFICATION": "email_otp_verification", # 已注册账号,需要 OTP 验证
"PASSWORD_REGISTRATION": "create_account_password", # 新账号,需要设置密码 "PASSWORD_REGISTRATION": "password", # 新账号,需要设置密码
} }
# ============================================================================ # ============================================================================
@@ -272,7 +267,7 @@ DEFAULT_SETTINGS = [
("registration.timeout", "120", "超时时间(秒)", "registration"), ("registration.timeout", "120", "超时时间(秒)", "registration"),
("registration.default_password_length", "12", "默认密码长度", "registration"), ("registration.default_password_length", "12", "默认密码长度", "registration"),
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"), ("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
("webui.port", "8000", "Web UI 监听端口", "webui"), ("webui.port", "15555", "Web UI 监听端口", "webui"),
("webui.debug", "true", "调试模式", "webui"), ("webui.debug", "true", "调试模式", "webui"),
] ]
@@ -383,8 +378,20 @@ MICROSOFT_TOKEN_ENDPOINTS = {
} }
# IMAP 服务器配置 # IMAP 服务器配置
OUTLOOK_IMAP_SERVER = "outlook.live.com" OUTLOOK_IMAP_SERVERS = {
OUTLOOK_IMAP_PORT = 993 "OLD": "outlook.office365.com", # 旧版 IMAP
"NEW": "outlook.live.com", # 新版 IMAP
}
# Microsoft OAuth2 ScopeIMAP_NEW # Microsoft OAuth2 Scopes
OUTLOOK_IMAP_SCOPE = "https://outlook.office.com/IMAP.AccessAsUser.All offline_access" MICROSOFT_SCOPES = {
# 旧版 IMAP 不需要特定 scope
"IMAP_OLD": "",
# 新版 IMAP 需要的 scope
"IMAP_NEW": "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
# Graph API 需要的 scope
"GRAPH_API": "https://graph.microsoft.com/.default",
}
# Outlook 提供者默认优先级
OUTLOOK_PROVIDER_PRIORITY = ["imap_new", "imap_old", "graph_api"]

View File

@@ -76,7 +76,7 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
), ),
"webui_port": SettingDefinition( "webui_port": SettingDefinition(
db_key="webui.port", db_key="webui.port",
default_value=8000, default_value=15555,
category=SettingCategory.WEBUI, category=SettingCategory.WEBUI,
description="Web UI 监听端口" description="Web UI 监听端口"
), ),
@@ -136,7 +136,7 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
), ),
"openai_redirect_uri": SettingDefinition( "openai_redirect_uri": SettingDefinition(
db_key="openai.redirect_uri", db_key="openai.redirect_uri",
default_value="http://localhost:15555/auth/callback", default_value="http://localhost:1455/auth/callback",
category=SettingCategory.OPENAI, category=SettingCategory.OPENAI,
description="OpenAI OAuth 回调 URI" description="OpenAI OAuth 回调 URI"
), ),
@@ -358,6 +358,12 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
), ),
# Outlook 配置 # Outlook 配置
"outlook_provider_priority": SettingDefinition(
db_key="outlook.provider_priority",
default_value=["imap_old", "imap_new", "graph_api"],
category=SettingCategory.EMAIL,
description="Outlook 提供者优先级"
),
"outlook_health_failure_threshold": SettingDefinition( "outlook_health_failure_threshold": SettingDefinition(
db_key="outlook.health_failure_threshold", db_key="outlook.health_failure_threshold",
default_value=5, default_value=5,
@@ -376,12 +382,6 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
category=SettingCategory.EMAIL, category=SettingCategory.EMAIL,
description="Outlook OAuth 默认 Client ID" description="Outlook OAuth 默认 Client ID"
), ),
"outlook_use_idle": SettingDefinition(
db_key="outlook.use_idle",
default_value=True,
category=SettingCategory.EMAIL,
description="使用 IMAP IDLE 替代轮询获取验证码(降低延迟,默认开启)"
),
} }
# 属性名到数据库键名的映射(用于向后兼容) # 属性名到数据库键名的映射(用于向后兼容)
@@ -407,9 +407,9 @@ SETTING_TYPES: Dict[str, Type] = {
"cpa_enabled": bool, "cpa_enabled": bool,
"email_code_timeout": int, "email_code_timeout": int,
"email_code_poll_interval": int, "email_code_poll_interval": int,
"outlook_provider_priority": list,
"outlook_health_failure_threshold": int, "outlook_health_failure_threshold": int,
"outlook_health_disable_duration": int, "outlook_health_disable_duration": int,
"outlook_use_idle": bool,
} }
# 需要作为 SecretStr 处理的字段 # 需要作为 SecretStr 处理的字段
@@ -609,7 +609,7 @@ class Settings(BaseModel):
# Web UI 配置 # Web UI 配置
webui_host: str = "0.0.0.0" webui_host: str = "0.0.0.0"
webui_port: int = 8000 webui_port: int = 15555
webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production") webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
webui_access_password: SecretStr = SecretStr("admin123") webui_access_password: SecretStr = SecretStr("admin123")
@@ -622,7 +622,7 @@ class Settings(BaseModel):
openai_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann" openai_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann"
openai_auth_url: str = "https://auth.openai.com/oauth/authorize" openai_auth_url: str = "https://auth.openai.com/oauth/authorize"
openai_token_url: str = "https://auth.openai.com/oauth/token" openai_token_url: str = "https://auth.openai.com/oauth/token"
openai_redirect_uri: str = "http://localhost:15555/auth/callback" openai_redirect_uri: str = "http://localhost:1455/auth/callback"
openai_scope: str = "openid email profile offline_access" openai_scope: str = "openid email profile offline_access"
# 代理配置 # 代理配置
@@ -694,10 +694,10 @@ class Settings(BaseModel):
email_code_poll_interval: int = 3 email_code_poll_interval: int = 3
# Outlook 配置 # Outlook 配置
outlook_provider_priority: List[str] = ["imap_old", "imap_new", "graph_api"]
outlook_health_failure_threshold: int = 5 outlook_health_failure_threshold: int = 5
outlook_health_disable_duration: int = 60 outlook_health_disable_duration: int = 60
outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a" outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a"
outlook_use_idle: bool = True
# 全局配置实例 # 全局配置实例

View File

@@ -12,7 +12,6 @@ from .http_client import (
create_openai_client, create_openai_client,
) )
from .register import RegistrationEngine, RegistrationResult from .register import RegistrationEngine, RegistrationResult
from .login import LoginEngine
from .utils import setup_logging, get_data_dir from .utils import setup_logging, get_data_dir
__all__ = [ __all__ = [
@@ -28,7 +27,6 @@ __all__ = [
'create_openai_client', 'create_openai_client',
'RegistrationEngine', 'RegistrationEngine',
'RegistrationResult', 'RegistrationResult',
'LoginEngine',
'setup_logging', 'setup_logging',
'get_data_dir', 'get_data_dir',
] ]

View File

@@ -282,13 +282,13 @@ class OpenAIHTTPClient(HTTPClient):
loc = loc_match.group(1) if loc_match else None loc = loc_match.group(1) if loc_match else None
# 检查是否支持 # 检查是否支持
if loc in ["CN", "HK", "MO"]: if loc in ["CN", "HK", "MO", "TW"]:
return False, loc return False, loc
return True, loc return True, loc
except Exception as e: except Exception as e:
logger.error(f"检查 IP 地理位置失败: {e}") logger.error(f"检查 IP 地理位置失败: {e}")
return False, str(e) return False, None
def send_openai_request( def send_openai_request(
self, self,
@@ -417,4 +417,4 @@ def create_openai_client(
Returns: Returns:
OpenAIHTTPClient 实例 OpenAIHTTPClient 实例
""" """
return OpenAIHTTPClient(proxy_url, config) return OpenAIHTTPClient(proxy_url, config)

View File

@@ -81,7 +81,10 @@ class LoginEngine(RegistrationEngine):
} }
if sen_token: if sen_token:
sentinel = f'{{"p": "", "t": "", "c": "{sen_token}", "id": "{did}", "flow": "authorize_continue"}}' sentinel = (
f'{{"p": "", "t": "", "c": "{sen_token}", '
f'"id": "{did}", "flow": "authorize_continue"}}'
)
headers["openai-sentinel-token"] = sentinel headers["openai-sentinel-token"] = sentinel
response = self.session.post( response = self.session.post(
@@ -101,7 +104,6 @@ class LoginEngine(RegistrationEngine):
def _send_verification_code_passwordless(self) -> bool: def _send_verification_code_passwordless(self) -> bool:
"""发送验证码""" """发送验证码"""
try: try:
# 记录发送时间戳
self._otp_sent_at = time.time() self._otp_sent_at = time.time()
response = self.session.post( response = self.session.post(
OPENAI_API_ENDPOINTS["passwordless_send_otp"], OPENAI_API_ENDPOINTS["passwordless_send_otp"],
@@ -281,7 +283,6 @@ class LoginEngine(RegistrationEngine):
self._log("开始注册流程") self._log("开始注册流程")
self._log("=" * 60) self._log("=" * 60)
# 1. 检查 IP 地理位置
self._log("1. 检查 IP 地理位置...") self._log("1. 检查 IP 地理位置...")
ip_ok, location = self._check_ip_location() ip_ok, location = self._check_ip_location()
if not ip_ok: if not ip_ok:
@@ -291,7 +292,6 @@ class LoginEngine(RegistrationEngine):
self._log(f"IP 位置: {location}") self._log(f"IP 位置: {location}")
# 2. 创建邮箱
self._log("2. 创建邮箱...") self._log("2. 创建邮箱...")
if not self._create_email(): if not self._create_email():
result.error_message = "创建邮箱失败" result.error_message = "创建邮箱失败"
@@ -299,26 +299,22 @@ class LoginEngine(RegistrationEngine):
result.email = self.email result.email = self.email
# 3. 初始化会话
self._log("3. 初始化会话...") self._log("3. 初始化会话...")
if not self._init_session(): if not self._init_session():
result.error_message = "初始化会话失败" result.error_message = "初始化会话失败"
return result return result
# 4. 开始 OAuth 流程
self._log("4. 开始 OAuth 授权流程...") self._log("4. 开始 OAuth 授权流程...")
if not self._start_oauth(): if not self._start_oauth():
result.error_message = "开始 OAuth 流程失败" result.error_message = "开始 OAuth 流程失败"
return result return result
# 5. 获取 Device ID
self._log("5. 获取 Device ID...") self._log("5. 获取 Device ID...")
did = self._get_device_id() did = self._get_device_id()
if not did: if not did:
result.error_message = "获取 Device ID 失败" result.error_message = "获取 Device ID 失败"
return result return result
# 6. 检查 Sentinel 拦截
self._log("6. 检查 Sentinel 拦截...") self._log("6. 检查 Sentinel 拦截...")
sen_token = self._check_sentinel(did) sen_token = self._check_sentinel(did)
if sen_token: if sen_token:
@@ -326,32 +322,28 @@ class LoginEngine(RegistrationEngine):
else: else:
self._log("Sentinel 检查失败或未启用", "warning") self._log("Sentinel 检查失败或未启用", "warning")
# 7. 提交注册表单 + 解析响应判断账号状态
self._log("7. 提交注册表单...") self._log("7. 提交注册表单...")
signup_result = self._submit_signup_form(did, sen_token) signup_result = self._submit_signup_form(did, sen_token)
if not signup_result.success: if not signup_result.success:
result.error_message = f"提交注册表单失败: {signup_result.error_message}" result.error_message = f"提交注册表单失败: {signup_result.error_message}"
return result return result
# 8. 检测到已注册账号 → 直接终止任务
if self._is_existing_account: if self._is_existing_account:
self._log(f"8. 邮箱 {self.email} 在 OpenAI 已注册,跳过注册流程", "warning") self._log(f"8. 邮箱 {self.email} 在 OpenAI 已注册,跳过注册流程", "warning")
result.error_message = f"邮箱 {self.email} 已在 OpenAI 注册" result.error_message = f"邮箱 {self.email} 已在 OpenAI 注册"
return result return result
else:
self._log("8. 注册密码...")
password_ok, password = self._register_password()
if not password_ok:
result.error_message = "注册密码失败"
return result
# 9. 发送验证码 self._log("8. 注册密码...")
password_ok, password = self._register_password()
if not password_ok:
result.error_message = "注册密码失败"
return result
self._log("9. 发送验证码...") self._log("9. 发送验证码...")
if not self._send_verification_code(): if not self._send_verification_code():
result.error_message = "发送验证码失败" result.error_message = "发送验证码失败"
return result return result
# 10. 获取验证码(超时后重发一次)
self._log("10. 等待验证码...") self._log("10. 等待验证码...")
code = self._get_verification_code() code = self._get_verification_code()
if not code: if not code:
@@ -362,13 +354,11 @@ class LoginEngine(RegistrationEngine):
result.error_message = "获取验证码失败" result.error_message = "获取验证码失败"
return result return result
# 11. 验证验证码
self._log("11. 验证验证码...") self._log("11. 验证验证码...")
if not self._validate_verification_code(code): if not self._validate_verification_code(code):
result.error_message = "验证验证码失败" result.error_message = "验证验证码失败"
return result return result
# 12. 创建用户账户
self._log("12. 创建用户账户...") self._log("12. 创建用户账户...")
if not self._create_user_account(): if not self._create_user_account():
result.error_message = "创建用户账户失败" result.error_message = "创建用户账户失败"
@@ -404,7 +394,6 @@ class LoginEngine(RegistrationEngine):
result.error_message = "验证验证码失败" result.error_message = "验证验证码失败"
return result return result
# 13. 获取 Workspace ID
self._log("17. 获取 Workspace ID...") self._log("17. 获取 Workspace ID...")
workspace_id = self._get_workspace_id() workspace_id = self._get_workspace_id()
if not workspace_id: if not workspace_id:
@@ -413,45 +402,37 @@ class LoginEngine(RegistrationEngine):
result.workspace_id = workspace_id result.workspace_id = workspace_id
# 14. 选择 Workspace
self._log("18. 选择 Workspace...") self._log("18. 选择 Workspace...")
continue_url = self._select_workspace(workspace_id) continue_url = self._select_workspace(workspace_id)
if not continue_url: if not continue_url:
result.error_message = "选择 Workspace 失败" result.error_message = "选择 Workspace 失败"
return result return result
# 15. 跟随重定向链
self._log("19. 跟随重定向链...") self._log("19. 跟随重定向链...")
callback_url = self._follow_redirects(continue_url) callback_url = self._follow_redirects(continue_url)
if not callback_url: if not callback_url:
result.error_message = "跟随重定向链失败" result.error_message = "跟随重定向链失败"
return result return result
# 16. 处理 OAuth 回调
self._log("20. 处理 OAuth 回调...") self._log("20. 处理 OAuth 回调...")
token_info = self._handle_oauth_callback(callback_url) token_info = self._handle_oauth_callback(callback_url)
if not token_info: if not token_info:
result.error_message = "处理 OAuth 回调失败" result.error_message = "处理 OAuth 回调失败"
return result return result
# 提取账户信息
result.account_id = token_info.get("account_id", "") result.account_id = token_info.get("account_id", "")
result.access_token = token_info.get("access_token", "") result.access_token = token_info.get("access_token", "")
result.refresh_token = token_info.get("refresh_token", "") result.refresh_token = token_info.get("refresh_token", "")
result.id_token = token_info.get("id_token", "") result.id_token = token_info.get("id_token", "")
result.password = self.password or "" # 保存密码(已注册账号为空) result.password = self.password or ""
# 设置来源标记
result.source = "register" result.source = "register"
# 尝试获取 session_token 从 cookie
session_cookie = self.session.cookies.get("__Secure-next-auth.session-token") session_cookie = self.session.cookies.get("__Secure-next-auth.session-token")
if session_cookie: if session_cookie:
self.session_token = session_cookie self.session_token = session_cookie
result.session_token = session_cookie result.session_token = session_cookie
self._log(f"获取到 Session Token") self._log("获取到 Session Token")
# 17. 完成
self._log("=" * 60) self._log("=" * 60)
self._log("注册成功!") self._log("注册成功!")
self._log(f"邮箱: {result.email}") self._log(f"邮箱: {result.email}")

File diff suppressed because it is too large Load Diff

View File

@@ -15,7 +15,6 @@ import base64
import re import re
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from html.parser import HTMLParser
from typing import Any, Dict, List, Optional, Union, Callable from typing import Any, Dict, List, Optional, Union, Callable
from pathlib import Path from pathlib import Path
@@ -569,49 +568,3 @@ class Timer:
if self.start_time is not None: if self.start_time is not None:
return time.time() - self.start_time return time.time() - self.start_time
return 0.0 return 0.0
class BootstrapExtractor(HTMLParser):
"""内部解析器,专门提取 id="client-bootstrap" 的 script 内容"""
def __init__(self):
super().__init__()
self._in_target = False
self.json_text = None
def handle_starttag(self, tag, attrs):
if tag == 'script':
attrs_dict = dict(attrs)
if attrs_dict.get('id') == 'client-bootstrap':
self._in_target = True
def handle_endtag(self, tag):
if tag == 'script' and self._in_target:
self._in_target = False
def handle_data(self, data):
if self._in_target and self.json_text is None:
self.json_text = data.strip()
def extract_client_bootstrap_json(html: str):
"""
从 HTML 字符串中提取 id="client-bootstrap" 的 script 标签内容并解析为 JSON。
返回 dict 或 None未找到或解析失败
"""
parser = BootstrapExtractor()
parser.feed(html)
if parser.json_text:
try:
return json.loads(parser.json_text)
except json.JSONDecodeError:
return None
return None
def base64_payload_decode(payload_b64):
import base64
import json as json_module
padding = 4 - (len(payload_b64) % 4)
if padding != 4:
payload_b64 += '=' * padding
# 解码Base64URL 使用 - 和 _ 替代 + 和 /
payload_bytes = base64.urlsafe_b64decode(payload_b64)
return json_module.loads(payload_bytes)

View File

@@ -36,6 +36,7 @@ def create_account(
access_token: Optional[str] = None, access_token: Optional[str] = None,
refresh_token: Optional[str] = None, refresh_token: Optional[str] = None,
id_token: Optional[str] = None, id_token: Optional[str] = None,
cookies: Optional[str] = None,
proxy_used: Optional[str] = None, proxy_used: Optional[str] = None,
expires_at: Optional['datetime'] = None, expires_at: Optional['datetime'] = None,
extra_data: Optional[Dict[str, Any]] = None, extra_data: Optional[Dict[str, Any]] = None,
@@ -62,6 +63,7 @@ def create_account(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
id_token=id_token, id_token=id_token,
cookies=cookies,
proxy_used=proxy_used, proxy_used=proxy_used,
expires_at=expires_at, expires_at=expires_at,
extra_data=extra_data or {}, extra_data=extra_data or {},
@@ -134,7 +136,6 @@ def update_account(
} }
kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values)) kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
kwargs["token_sync_updated_at"] = datetime.utcnow() kwargs["token_sync_updated_at"] = datetime.utcnow()
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(db_account, key) and value is not None: if hasattr(db_account, key) and value is not None:
setattr(db_account, key, value) setattr(db_account, key, value)
@@ -353,6 +354,34 @@ def delete_registration_task(db: Session, task_uuid: str) -> bool:
return True return True
def fail_incomplete_registration_tasks(db: Session, error_message: str) -> List[str]:
"""将服务重启后遗留的未完成任务标记为失败"""
tasks = db.query(RegistrationTask).filter(
RegistrationTask.status.in_(("pending", "running"))
).all()
if not tasks:
return []
now = datetime.utcnow()
cleaned_task_ids: List[str] = []
cleanup_log = f"[系统] {error_message}"
for task in tasks:
task.status = "failed"
task.error_message = error_message
task.completed_at = now
if task.logs:
if cleanup_log not in task.logs:
task.logs = f"{task.logs}\n{cleanup_log}"
else:
task.logs = cleanup_log
cleaned_task_ids.append(task.task_uuid)
db.commit()
return cleaned_task_ids
# 为 API 路由添加别名 # 为 API 路由添加别名
get_account = get_account_by_id get_account = get_account_by_id
get_registration_task = get_registration_task_by_uuid get_registration_task = get_registration_task_by_uuid
@@ -503,13 +532,6 @@ def delete_proxy(db: Session, proxy_id: int) -> bool:
return True return True
def delete_disabled_proxies(db: Session) -> int:
"""删除所有已禁用代理"""
deleted = db.query(Proxy).filter(Proxy.enabled == False).delete(synchronize_session=False)
db.commit()
return deleted
def update_proxy_last_used(db: Session, proxy_id: int) -> bool: def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
"""更新代理最后使用时间""" """更新代理最后使用时间"""
db_proxy = get_proxy_by_id(db, proxy_id) db_proxy = get_proxy_by_id(db, proxy_id)

View File

@@ -38,7 +38,9 @@ from .outlook.base import (
from .outlook.account import OutlookAccount from .outlook.account import OutlookAccount
from .outlook.providers import ( from .outlook.providers import (
OutlookProvider, OutlookProvider,
IMAPOldProvider,
IMAPNewProvider, IMAPNewProvider,
GraphAPIProvider,
) )
__all__ = [ __all__ = [
@@ -65,5 +67,7 @@ __all__ = [
'ProviderStatus', 'ProviderStatus',
'OutlookAccount', 'OutlookAccount',
'OutlookProvider', 'OutlookProvider',
'IMAPOldProvider',
'IMAPNewProvider', 'IMAPNewProvider',
'GraphAPIProvider',
] ]

View File

@@ -5,6 +5,8 @@
import abc import abc
import logging import logging
import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from enum import Enum from enum import Enum
@@ -13,12 +15,109 @@ from ..config.constants import EmailServiceType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EMAIL_PROVIDER_BACKOFF_BASE_SECONDS = 30
EMAIL_PROVIDER_BACKOFF_MAX_SECONDS = 3600
OTP_TIMEOUT_ERROR_PREFIX = "OTP_TIMEOUT"
@dataclass(frozen=True)
class EmailProviderBackoffState:
"""邮箱供应商退避状态"""
failures: int = 0
delay_seconds: int = 0
opened_until: float = 0.0
retry_after: Optional[int] = None
last_error: Optional[str] = None
def is_open(self, now: Optional[float] = None) -> bool:
now_ts = now if now is not None else time.time()
return self.opened_until > now_ts
def to_dict(self) -> Dict[str, Any]:
return {
"failures": self.failures,
"delay_seconds": self.delay_seconds,
"opened_until": self.opened_until,
"retry_after": self.retry_after,
"last_error": self.last_error,
}
def calculate_adaptive_backoff_delay(
failures: int,
base_delay: int = EMAIL_PROVIDER_BACKOFF_BASE_SECONDS,
max_delay: int = EMAIL_PROVIDER_BACKOFF_MAX_SECONDS,
is_timeout: bool = False,
) -> int:
"""根据连续失败次数计算指数退避时长"""
normalized_failures = max(0, failures)
if is_timeout and normalized_failures >= 3:
return max_delay
exponent = max(0, normalized_failures - 1)
return min(base_delay * (2 ** exponent), max_delay)
def is_otp_timeout_error(error: object) -> bool:
"""识别 OTP 超时类错误码。"""
if error is None:
return False
if isinstance(error, OTPTimeoutEmailServiceError):
return True
error_code = getattr(error, "error_code", "")
if isinstance(error_code, str) and error_code.startswith(OTP_TIMEOUT_ERROR_PREFIX):
return True
return False
def apply_adaptive_backoff(
current_state: Optional[EmailProviderBackoffState],
error: "EmailServiceError",
now: Optional[float] = None,
) -> EmailProviderBackoffState:
"""在限流场景下推进邮箱供应商退避状态"""
state = current_state or EmailProviderBackoffState()
now_ts = now if now is not None else time.time()
next_failures = state.failures + 1
delay_seconds = calculate_adaptive_backoff_delay(
next_failures,
is_timeout=is_otp_timeout_error(error),
)
return EmailProviderBackoffState(
failures=next_failures,
delay_seconds=delay_seconds,
opened_until=now_ts + delay_seconds,
retry_after=getattr(error, "retry_after", None),
last_error=str(error),
)
def reset_adaptive_backoff() -> EmailProviderBackoffState:
"""重置邮箱供应商退避状态"""
return EmailProviderBackoffState()
class EmailServiceError(Exception): class EmailServiceError(Exception):
"""邮箱服务异常""" """邮箱服务异常"""
pass pass
class RateLimitedEmailServiceError(EmailServiceError):
"""邮箱服务被限流"""
def __init__(self, message: str, retry_after: Optional[int] = None):
super().__init__(message)
self.retry_after = retry_after
class OTPTimeoutEmailServiceError(EmailServiceError):
"""OTP 验证码等待超时。"""
def __init__(self, message: str, error_code: str = OTP_TIMEOUT_ERROR_PREFIX):
super().__init__(message)
self.error_code = error_code
class EmailServiceStatus(Enum): class EmailServiceStatus(Enum):
"""邮箱服务状态""" """邮箱服务状态"""
HEALTHY = "healthy" HEALTHY = "healthy"
@@ -45,6 +144,7 @@ class BaseEmailService(abc.ABC):
self.name = name or f"{service_type.value}_service" self.name = name or f"{service_type.value}_service"
self._status = EmailServiceStatus.HEALTHY self._status = EmailServiceStatus.HEALTHY
self._last_error = None self._last_error = None
self._provider_backoff = reset_adaptive_backoff()
@property @property
def status(self) -> EmailServiceStatus: def status(self) -> EmailServiceStatus:
@@ -56,6 +156,15 @@ class BaseEmailService(abc.ABC):
"""获取最后一次错误信息""" """获取最后一次错误信息"""
return self._last_error return self._last_error
@property
def provider_backoff_state(self) -> EmailProviderBackoffState:
"""获取当前邮箱供应商退避状态"""
return self._provider_backoff
def apply_provider_backoff_state(self, state: Optional[EmailProviderBackoffState]) -> None:
"""注入外部持久化的邮箱供应商退避状态"""
self._provider_backoff = state or reset_adaptive_backoff()
@abc.abstractmethod @abc.abstractmethod
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]: def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
""" """
@@ -92,7 +201,7 @@ class BaseEmailService(abc.ABC):
email_id: 邮箱服务中的 ID如果需要 email_id: 邮箱服务中的 ID如果需要
timeout: 超时时间(秒) timeout: 超时时间(秒)
pattern: 验证码正则表达式 pattern: 验证码正则表达式
otp_sent_at: OTP 发送时间戳,用于过滤旧邮件 otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点的邮件
Returns: Returns:
验证码字符串,如果超时或未找到返回 None 验证码字符串,如果超时或未找到返回 None
@@ -282,8 +391,16 @@ class BaseEmailService(abc.ABC):
if success: if success:
self._status = EmailServiceStatus.HEALTHY self._status = EmailServiceStatus.HEALTHY
self._last_error = None self._last_error = None
self._provider_backoff = reset_adaptive_backoff()
else: else:
self._status = EmailServiceStatus.DEGRADED if isinstance(error, RateLimitedEmailServiceError) or is_otp_timeout_error(error):
self._status = EmailServiceStatus.UNAVAILABLE
self._provider_backoff = apply_adaptive_backoff(
self._provider_backoff,
error,
)
else:
self._status = EmailServiceStatus.DEGRADED
if error: if error:
self._last_error = str(error) self._last_error = str(error)
@@ -383,4 +500,4 @@ def create_email_service(
Returns: Returns:
邮箱服务实例 邮箱服务实例
""" """
return EmailServiceFactory.create(service_type, config, name) return EmailServiceFactory.create(service_type, config, name)

View File

@@ -12,7 +12,7 @@ from datetime import datetime, timezone
from html import unescape from html import unescape
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from .base import BaseEmailService, EmailServiceError, EmailServiceType from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..config.constants import OTP_CODE_PATTERN from ..config.constants import OTP_CODE_PATTERN
from ..core.http_client import HTTPClient, RequestConfig from ..core.http_client import HTTPClient, RequestConfig
@@ -102,7 +102,19 @@ class DuckMailService(BaseEmailService):
error_message = f"{error_message} - {error_payload}" error_message = f"{error_message} - {error_payload}"
except Exception: except Exception:
error_message = f"{error_message} - {response.text[:200]}" error_message = f"{error_message} - {response.text[:200]}"
raise EmailServiceError(error_message) retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_message, retry_after=retry_after)
else:
error = EmailServiceError(error_message)
self.update_status(False, error)
raise error
try: try:
return response.json() return response.json()

View File

@@ -10,7 +10,7 @@ import random
import string import string
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from .base import BaseEmailService, EmailServiceError, EmailServiceType from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN from ..config.constants import OTP_CODE_PATTERN
@@ -96,8 +96,19 @@ class FreemailService(BaseEmailService):
error_msg = f"{error_msg} - {error_data}" error_msg = f"{error_msg} - {error_data}"
except Exception: except Exception:
error_msg = f"{error_msg} - {response.text[:200]}" error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg)) retry_after = None
raise EmailServiceError(error_msg) if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
try: try:
return response.json() return response.json()

View File

@@ -10,7 +10,7 @@ import logging
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from urllib.parse import urljoin from urllib.parse import urljoin
from .base import BaseEmailService, EmailServiceError, EmailServiceType from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN from ..config.constants import OTP_CODE_PATTERN
@@ -148,8 +148,20 @@ class MeoMailEmailService(BaseEmailService):
except: except:
error_msg = f"{error_msg} - {response.text[:200]}" error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg)) retry_after = None
raise EmailServiceError(error_msg) if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
# 解析响应 # 解析响应
try: try:
@@ -553,4 +565,4 @@ class MeoMailEmailService(BaseEmailService):
"system_config": config, "system_config": config,
"cached_emails_count": len(self._emails_cache), "cached_emails_count": len(self._emails_cache),
"status": self.status.value, "status": self.status.value,
} }

View File

@@ -3,7 +3,7 @@ Outlook 账户数据类
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Any from typing import Dict, Any, Optional
@dataclass @dataclass
@@ -16,6 +16,7 @@ class OutlookAccount:
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount": def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount":
"""从配置创建账户"""
return cls( return cls(
email=config.get("email", ""), email=config.get("email", ""),
password=config.get("password", ""), password=config.get("password", ""),
@@ -24,12 +25,15 @@ class OutlookAccount:
) )
def has_oauth(self) -> bool: def has_oauth(self) -> bool:
"""是否支持 OAuth2"""
return bool(self.client_id and self.refresh_token) return bool(self.client_id and self.refresh_token)
def validate(self) -> bool: def validate(self) -> bool:
"""验证账户信息是否有效"""
return bool(self.email and self.password) or self.has_oauth() return bool(self.email and self.password) or self.has_oauth()
def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]: def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
"""转换为字典"""
result = { result = {
"email": self.email, "email": self.email,
"has_oauth": self.has_oauth(), "has_oauth": self.has_oauth(),
@@ -43,4 +47,5 @@ class OutlookAccount:
return result return result
def __str__(self) -> str: def __str__(self) -> str:
"""字符串表示"""
return f"OutlookAccount({self.email})" return f"OutlookAccount({self.email})"

View File

@@ -1,5 +1,6 @@
""" """
Outlook 邮箱服务基础定义 Outlook 服务基础定义
包含枚举类型和数据类
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -9,38 +10,49 @@ from typing import Optional, Dict, Any, List
class ProviderType(str, Enum): class ProviderType(str, Enum):
"""Outlook 提供者类型(仅 IMAP_NEW""" """Outlook 提供者类型"""
IMAP_NEW = "imap_new" IMAP_OLD = "imap_old" # 旧版 IMAP (outlook.office365.com)
IMAP_NEW = "imap_new" # 新版 IMAP (outlook.live.com)
GRAPH_API = "graph_api" # Microsoft Graph API
class TokenEndpoint(str, Enum): class TokenEndpoint(str, Enum):
"""Token 端点""" """Token 端点"""
LIVE = "https://login.live.com/oauth20_token.srf"
CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token" CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
COMMON = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
class IMAPServer(str, Enum):
"""IMAP 服务器"""
OLD = "outlook.office365.com"
NEW = "outlook.live.com"
class ProviderStatus(str, Enum): class ProviderStatus(str, Enum):
"""提供者状态""" """提供者状态"""
HEALTHY = "healthy" HEALTHY = "healthy" # 健康
DEGRADED = "degraded" DEGRADED = "degraded" # 降级
DISABLED = "disabled" DISABLED = "disabled" # 禁用
@dataclass @dataclass
class EmailMessage: class EmailMessage:
"""邮件消息数据类""" """邮件消息数据类"""
id: str id: str # 消息 ID
subject: str subject: str # 主题
sender: str sender: str # 发件人
recipients: List[str] = field(default_factory=list) recipients: List[str] = field(default_factory=list) # 收件人列表
body: str = "" body: str = "" # 正文内容
body_preview: str = "" body_preview: str = "" # 正文预览
received_at: Optional[datetime] = None received_at: Optional[datetime] = None # 接收时间
received_timestamp: int = 0 received_timestamp: int = 0 # 接收时间戳
is_read: bool = False is_read: bool = False # 是否已读
has_attachments: bool = False has_attachments: bool = False # 是否有附件
raw_data: Optional[bytes] = None raw_data: Optional[bytes] = None # 原始数据(用于调试)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return { return {
"id": self.id, "id": self.id,
"subject": self.subject, "subject": self.subject,
@@ -59,17 +71,19 @@ class EmailMessage:
class TokenInfo: class TokenInfo:
"""Token 信息数据类""" """Token 信息数据类"""
access_token: str access_token: str
expires_at: float expires_at: float # 过期时间戳
token_type: str = "Bearer" token_type: str = "Bearer"
scope: str = "" scope: str = ""
refresh_token: Optional[str] = None refresh_token: Optional[str] = None
def is_expired(self, buffer_seconds: int = 120) -> bool: def is_expired(self, buffer_seconds: int = 120) -> bool:
"""检查 Token 是否已过期"""
import time import time
return time.time() >= (self.expires_at - buffer_seconds) return time.time() >= (self.expires_at - buffer_seconds)
@classmethod @classmethod
def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo": def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo":
"""从 API 响应创建"""
import time import time
return cls( return cls(
access_token=data.get("access_token", ""), access_token=data.get("access_token", ""),
@@ -85,42 +99,49 @@ class ProviderHealth:
"""提供者健康状态""" """提供者健康状态"""
provider_type: ProviderType provider_type: ProviderType
status: ProviderStatus = ProviderStatus.HEALTHY status: ProviderStatus = ProviderStatus.HEALTHY
failure_count: int = 0 failure_count: int = 0 # 连续失败次数
last_success: Optional[datetime] = None last_success: Optional[datetime] = None # 最后成功时间
last_failure: Optional[datetime] = None last_failure: Optional[datetime] = None # 最后失败时间
last_error: str = "" last_error: str = "" # 最后错误信息
disabled_until: Optional[datetime] = None disabled_until: Optional[datetime] = None # 禁用截止时间
def record_success(self): def record_success(self):
"""记录成功"""
self.status = ProviderStatus.HEALTHY self.status = ProviderStatus.HEALTHY
self.failure_count = 0 self.failure_count = 0
self.last_success = datetime.now() self.last_success = datetime.now()
self.disabled_until = None self.disabled_until = None
def record_failure(self, error: str): def record_failure(self, error: str):
"""记录失败"""
self.failure_count += 1 self.failure_count += 1
self.last_failure = datetime.now() self.last_failure = datetime.now()
self.last_error = error self.last_error = error
def should_disable(self, threshold: int = 3) -> bool: def should_disable(self, threshold: int = 3) -> bool:
"""判断是否应该禁用"""
return self.failure_count >= threshold return self.failure_count >= threshold
def is_disabled(self) -> bool: def is_disabled(self) -> bool:
"""检查是否被禁用"""
if self.disabled_until and datetime.now() < self.disabled_until: if self.disabled_until and datetime.now() < self.disabled_until:
return True return True
return False return False
def disable(self, duration_seconds: int = 300): def disable(self, duration_seconds: int = 300):
"""禁用提供者"""
from datetime import timedelta from datetime import timedelta
self.status = ProviderStatus.DISABLED self.status = ProviderStatus.DISABLED
self.disabled_until = datetime.now() + timedelta(seconds=duration_seconds) self.disabled_until = datetime.now() + timedelta(seconds=duration_seconds)
def enable(self): def enable(self):
"""启用提供者"""
self.status = ProviderStatus.HEALTHY self.status = ProviderStatus.HEALTHY
self.disabled_until = None self.disabled_until = None
self.failure_count = 0 self.failure_count = 0
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return { return {
"provider_type": self.provider_type.value, "provider_type": self.provider_type.value,
"status": self.status.value, "status": self.status.value,

View File

@@ -1,13 +1,15 @@
""" """
健康检查管理(简化版,单 Provider 健康检查和故障切换管理
""" """
import logging import logging
import threading import threading
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Optional from typing import Dict, List, Optional, Any
from .base import ProviderHealth, ProviderStatus, ProviderType from .base import ProviderType, ProviderHealth, ProviderStatus
from .providers.base import OutlookProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,49 +17,296 @@ logger = logging.getLogger(__name__)
class HealthChecker: class HealthChecker:
""" """
单 Provider 健康检查器 健康检查管理
跟踪 IMAP_NEW 的健康状态 跟踪各提供者的健康状态,管理故障切换
""" """
def __init__( def __init__(
self, self,
failure_threshold: int = 3, failure_threshold: int = 3,
disable_duration: int = 300, disable_duration: int = 300,
recovery_check_interval: int = 60,
): ):
"""
初始化健康检查器
Args:
failure_threshold: 连续失败次数阈值,超过后禁用
disable_duration: 禁用时长(秒)
recovery_check_interval: 恢复检查间隔(秒)
"""
self.failure_threshold = failure_threshold self.failure_threshold = failure_threshold
self.disable_duration = disable_duration self.disable_duration = disable_duration
self._health = ProviderHealth(provider_type=ProviderType.IMAP_NEW) self.recovery_check_interval = recovery_check_interval
# 提供者健康状态: ProviderType -> ProviderHealth
self._health_status: Dict[ProviderType, ProviderHealth] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
def record_success(self): # 初始化所有提供者的健康状态
with self._lock: for provider_type in ProviderType:
self._health.record_success() self._health_status[provider_type] = ProviderHealth(
provider_type=provider_type
)
def record_failure(self, error: str): def get_health(self, provider_type: ProviderType) -> ProviderHealth:
"""获取提供者的健康状态"""
with self._lock: with self._lock:
self._health.record_failure(error) return self._health_status.get(provider_type, ProviderHealth(provider_type=provider_type))
if self._health.should_disable(self.failure_threshold):
self._health.disable(self.disable_duration) def record_success(self, provider_type: ProviderType):
logger.warning( """记录成功操作"""
f"IMAP_NEW 已禁用 {self.disable_duration}s原因: {error}" with self._lock:
health = self._health_status.get(provider_type)
if health:
health.record_success()
logger.debug(f"{provider_type.value} 记录成功")
def record_failure(self, provider_type: ProviderType, error: str):
"""记录失败操作"""
with self._lock:
health = self._health_status.get(provider_type)
if health:
health.record_failure(error)
# 检查是否需要禁用
if health.should_disable(self.failure_threshold):
health.disable(self.disable_duration)
logger.warning(
f"{provider_type.value} 已禁用 {self.disable_duration} 秒,"
f"原因: {error}"
)
def is_available(self, provider_type: ProviderType) -> bool:
"""
检查提供者是否可用
Args:
provider_type: 提供者类型
Returns:
是否可用
"""
health = self.get_health(provider_type)
# 检查是否被禁用
if health.is_disabled():
remaining = (health.disabled_until - datetime.now()).total_seconds()
logger.debug(
f"{provider_type.value} 已被禁用,剩余 {int(remaining)}"
)
return False
return health.status != ProviderStatus.DISABLED
def get_available_providers(
self,
priority_order: Optional[List[ProviderType]] = None,
) -> List[ProviderType]:
"""
获取可用的提供者列表
Args:
priority_order: 优先级顺序,默认为 [IMAP_NEW, IMAP_OLD, GRAPH_API]
Returns:
可用的提供者列表
"""
if priority_order is None:
priority_order = [
ProviderType.IMAP_NEW,
ProviderType.IMAP_OLD,
ProviderType.GRAPH_API,
]
available = []
for provider_type in priority_order:
if self.is_available(provider_type):
available.append(provider_type)
return available
def get_next_available_provider(
self,
priority_order: Optional[List[ProviderType]] = None,
) -> Optional[ProviderType]:
"""
获取下一个可用的提供者
Args:
priority_order: 优先级顺序
Returns:
可用的提供者类型,如果没有返回 None
"""
available = self.get_available_providers(priority_order)
return available[0] if available else None
def force_disable(self, provider_type: ProviderType, duration: Optional[int] = None):
"""
强制禁用提供者
Args:
provider_type: 提供者类型
duration: 禁用时长(秒),默认使用配置值
"""
with self._lock:
health = self._health_status.get(provider_type)
if health:
health.disable(duration or self.disable_duration)
logger.warning(f"{provider_type.value} 已强制禁用")
def force_enable(self, provider_type: ProviderType):
"""
强制启用提供者
Args:
provider_type: 提供者类型
"""
with self._lock:
health = self._health_status.get(provider_type)
if health:
health.enable()
logger.info(f"{provider_type.value} 已启用")
def get_all_health_status(self) -> Dict[str, Any]:
"""
获取所有提供者的健康状态
Returns:
健康状态字典
"""
with self._lock:
return {
provider_type.value: health.to_dict()
for provider_type, health in self._health_status.items()
}
def check_and_recover(self):
"""
检查并恢复被禁用的提供者
如果禁用时间已过,自动恢复提供者
"""
with self._lock:
for provider_type, health in self._health_status.items():
if health.is_disabled():
# 检查是否可以恢复
if health.disabled_until and datetime.now() >= health.disabled_until:
health.enable()
logger.info(f"{provider_type.value} 已自动恢复")
def reset_all(self):
"""重置所有提供者的健康状态"""
with self._lock:
for provider_type in ProviderType:
self._health_status[provider_type] = ProviderHealth(
provider_type=provider_type
) )
logger.info("已重置所有提供者的健康状态")
def is_available(self) -> bool:
with self._lock:
if self._health.is_disabled():
remaining = (
(self._health.disabled_until - datetime.now()).total_seconds()
if self._health.disabled_until
else 0
)
logger.debug(f"IMAP_NEW 已被禁用,剩余 {int(remaining)}s")
return False
return self._health.status != ProviderStatus.DISABLED
def reset(self): class FailoverManager:
"""
故障切换管理器
管理提供者之间的自动切换
"""
def __init__(
self,
health_checker: HealthChecker,
priority_order: Optional[List[ProviderType]] = None,
):
"""
初始化故障切换管理器
Args:
health_checker: 健康检查器
priority_order: 提供者优先级顺序
"""
self.health_checker = health_checker
self.priority_order = priority_order or [
ProviderType.IMAP_NEW,
ProviderType.IMAP_OLD,
ProviderType.GRAPH_API,
]
# 当前使用的提供者索引
self._current_index = 0
self._lock = threading.Lock()
def get_current_provider(self) -> Optional[ProviderType]:
"""
获取当前提供者
Returns:
当前提供者类型,如果没有可用的返回 None
"""
available = self.health_checker.get_available_providers(self.priority_order)
if not available:
return None
with self._lock: with self._lock:
self._health = ProviderHealth(provider_type=ProviderType.IMAP_NEW) # 尝试使用当前索引
if self._current_index < len(available):
return available[self._current_index]
return available[0]
def switch_to_next(self) -> Optional[ProviderType]:
"""
切换到下一个提供者
Returns:
下一个提供者类型,如果没有可用的返回 None
"""
available = self.health_checker.get_available_providers(self.priority_order)
if not available:
return None
with self._lock:
self._current_index = (self._current_index + 1) % len(available)
next_provider = available[self._current_index]
logger.info(f"切换到提供者: {next_provider.value}")
return next_provider
def on_provider_success(self, provider_type: ProviderType):
"""
提供者成功时调用
Args:
provider_type: 提供者类型
"""
self.health_checker.record_success(provider_type)
# 重置索引到成功的提供者
with self._lock:
available = self.health_checker.get_available_providers(self.priority_order)
if provider_type in available:
self._current_index = available.index(provider_type)
def on_provider_failure(self, provider_type: ProviderType, error: str):
"""
提供者失败时调用
Args:
provider_type: 提供者类型
error: 错误信息
"""
self.health_checker.record_failure(provider_type, error)
def get_status(self) -> Dict[str, Any]: def get_status(self) -> Dict[str, Any]:
with self._lock: """
return self._health.to_dict() 获取故障切换状态
Returns:
状态字典
"""
current = self.get_current_provider()
return {
"current_provider": current.value if current else None,
"priority_order": [p.value for p in self.priority_order],
"available_providers": [
p.value for p in self.health_checker.get_available_providers(self.priority_order)
],
"health_status": self.health_checker.get_all_health_status(),
}

View File

@@ -3,18 +3,27 @@ Outlook 提供者模块
""" """
from .base import OutlookProvider, ProviderConfig from .base import OutlookProvider, ProviderConfig
from .imap_old import IMAPOldProvider
from .imap_new import IMAPNewProvider from .imap_new import IMAPNewProvider
from .graph_api import GraphAPIProvider
__all__ = [ __all__ = [
'OutlookProvider', 'OutlookProvider',
'ProviderConfig', 'ProviderConfig',
'IMAPOldProvider',
'IMAPNewProvider', 'IMAPNewProvider',
'GraphAPIProvider',
] ]
# 提供者注册表
PROVIDER_REGISTRY = { PROVIDER_REGISTRY = {
'imap_old': IMAPOldProvider,
'imap_new': IMAPNewProvider, 'imap_new': IMAPNewProvider,
'graph_api': GraphAPIProvider,
} }
def get_provider_class(provider_type: str): def get_provider_class(provider_type: str):
"""获取提供者类"""
return PROVIDER_REGISTRY.get(provider_type) return PROVIDER_REGISTRY.get(provider_type)

View File

@@ -5,7 +5,7 @@ Outlook 提供者抽象基类
import abc import abc
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Dict, Any, List, Optional
from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus
from ..account import OutlookAccount from ..account import OutlookAccount
@@ -18,36 +18,56 @@ logger = logging.getLogger(__name__)
class ProviderConfig: class ProviderConfig:
"""提供者配置""" """提供者配置"""
timeout: int = 30 timeout: int = 30
max_retries: int = 3
proxy_url: Optional[str] = None proxy_url: Optional[str] = None
service_id: Optional[int] = None
# 健康检查配置
health_failure_threshold: int = 3 health_failure_threshold: int = 3
health_disable_duration: int = 300 health_disable_duration: int = 300 # 秒
class OutlookProvider(abc.ABC): class OutlookProvider(abc.ABC):
"""Outlook 提供者抽象基类""" """
Outlook 提供者抽象基类
定义所有提供者必须实现的接口
"""
def __init__( def __init__(
self, self,
account: OutlookAccount, account: OutlookAccount,
config: Optional[ProviderConfig] = None, config: Optional[ProviderConfig] = None,
): ):
"""
初始化提供者
Args:
account: Outlook 账户
config: 提供者配置
"""
self.account = account self.account = account
self.config = config or ProviderConfig() self.config = config or ProviderConfig()
self._health = ProviderHealth(provider_type=ProviderType.IMAP_NEW)
# 健康状态
self._health = ProviderHealth(provider_type=self.provider_type)
# 连接状态
self._connected = False self._connected = False
self._last_error: Optional[str] = None self._last_error: Optional[str] = None
@property @property
@abc.abstractmethod
def provider_type(self) -> ProviderType: def provider_type(self) -> ProviderType:
return ProviderType.IMAP_NEW """获取提供者类型"""
pass
@property @property
def health(self) -> ProviderHealth: def health(self) -> ProviderHealth:
"""获取健康状态"""
return self._health return self._health
@property @property
def is_healthy(self) -> bool: def is_healthy(self) -> bool:
"""检查是否健康"""
return ( return (
self._health.status == ProviderStatus.HEALTHY self._health.status == ProviderStatus.HEALTHY
and not self._health.is_disabled() and not self._health.is_disabled()
@@ -55,14 +75,22 @@ class OutlookProvider(abc.ABC):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected return self._connected
@abc.abstractmethod @abc.abstractmethod
def connect(self) -> bool: def connect(self) -> bool:
"""
连接到服务
Returns:
是否连接成功
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def disconnect(self): def disconnect(self):
"""断开连接"""
pass pass
@abc.abstractmethod @abc.abstractmethod
@@ -71,44 +99,81 @@ class OutlookProvider(abc.ABC):
count: int = 20, count: int = 20,
only_unseen: bool = True, only_unseen: bool = True,
) -> List[EmailMessage]: ) -> List[EmailMessage]:
"""
获取最近的邮件
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def test_connection(self) -> bool: def test_connection(self) -> bool:
"""
测试连接是否正常
Returns:
连接是否正常
"""
pass pass
def wait_for_new_email_idle(self, timeout: int = 25) -> bool:
"""IMAP IDLE默认不支持子类可覆盖"""
return False
def record_success(self): def record_success(self):
"""记录成功操作"""
self._health.record_success() self._health.record_success()
self._last_error = None self._last_error = None
logger.debug(f"[{self.account.email}] {self.provider_type.value} 操作成功")
def record_failure(self, error: str): def record_failure(self, error: str):
"""记录失败操作"""
self._health.record_failure(error) self._health.record_failure(error)
self._last_error = error self._last_error = error
# 检查是否需要禁用
if self._health.should_disable(self.config.health_failure_threshold): if self._health.should_disable(self.config.health_failure_threshold):
self._health.disable(self.config.health_disable_duration) self._health.disable(self.config.health_disable_duration)
logger.warning( logger.warning(
f"[{self.account.email}] IMAP_NEW 已禁用 " f"[{self.account.email}] {self.provider_type.value} 已禁用 "
f"{self.config.health_disable_duration}s,原因: {error}" f"{self.config.health_disable_duration},原因: {error}"
)
else:
logger.warning(
f"[{self.account.email}] {self.provider_type.value} 操作失败 "
f"({self._health.failure_count}/{self.config.health_failure_threshold}): {error}"
) )
def check_health(self) -> bool: def check_health(self) -> bool:
"""
检查健康状态
Returns:
是否健康可用
"""
# 检查是否被禁用
if self._health.is_disabled(): if self._health.is_disabled():
logger.debug(
f"[{self.account.email}] {self.provider_type.value} 已被禁用,"
f"将在 {self._health.disabled_until} 后恢复"
)
return False return False
return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED) return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED)
def __enter__(self): def __enter__(self):
"""上下文管理器入口"""
self.connect() self.connect()
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
self.disconnect() self.disconnect()
return False return False
def __str__(self) -> str: def __str__(self) -> str:
"""字符串表示"""
return f"{self.__class__.__name__}({self.account.email})" return f"{self.__class__.__name__}({self.account.email})"
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -0,0 +1,250 @@
"""
Graph API 提供者
使用 Microsoft Graph REST API
"""
import json
import logging
from typing import List, Optional
from datetime import datetime
from curl_cffi import requests as _requests
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
logger = logging.getLogger(__name__)
class GraphAPIProvider(OutlookProvider):
"""
Graph API 提供者
使用 Microsoft Graph REST API 获取邮件
需要 graph.microsoft.com/.default scope
"""
# Graph API 端点
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
MESSAGES_ENDPOINT = "/me/mailFolders/inbox/messages"
@property
def provider_type(self) -> ProviderType:
return ProviderType.GRAPH_API
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
super().__init__(account, config)
# Token 管理器
self._token_manager: Optional[TokenManager] = None
# 注意Graph API 必须使用 OAuth2
if not account.has_oauth():
logger.warning(
f"[{self.account.email}] Graph API 提供者需要 OAuth2 配置 "
f"(client_id + refresh_token)"
)
def connect(self) -> bool:
"""
验证连接(获取 Token
Returns:
是否连接成功
"""
if not self.account.has_oauth():
error = "Graph API 需要 OAuth2 配置"
self.record_failure(error)
logger.error(f"[{self.account.email}] {error}")
return False
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.GRAPH_API,
self.config.proxy_url,
self.config.timeout,
)
# 尝试获取 Token
token = self._token_manager.get_access_token()
if token:
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] Graph API 连接成功")
return True
return False
def disconnect(self):
"""断开连接(清除状态)"""
self._connected = False
def get_recent_emails(
self,
count: int = 20,
only_unseen: bool = True,
) -> List[EmailMessage]:
"""
获取最近的邮件
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
if not self._connected:
if not self.connect():
return []
try:
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
self.record_failure("无法获取 Access Token")
return []
# 构建 API 请求
url = f"{self.GRAPH_API_BASE}{self.MESSAGES_ENDPOINT}"
params = {
"$top": count,
"$select": "id,subject,from,toRecipients,receivedDateTime,isRead,hasAttachments,bodyPreview,body",
"$orderby": "receivedDateTime desc",
}
# 只获取未读邮件
if only_unseen:
params["$filter"] = "isRead eq false"
# 构建代理配置
proxies = None
if self.config.proxy_url:
proxies = {"http": self.config.proxy_url, "https": self.config.proxy_url}
# 发送请求curl_cffi 自动对 params 进行 URL 编码)
resp = _requests.get(
url,
params=params,
headers={
"Authorization": f"Bearer {token}",
"Accept": "application/json",
"Prefer": "outlook.body-content-type='text'",
},
proxies=proxies,
timeout=self.config.timeout,
impersonate="chrome110",
)
if resp.status_code == 401:
# Token 无 Graph 权限client_id 未授权),清除缓存但不记录健康失败
# 避免因权限不足导致健康检查器禁用该提供者,影响其他账户
if self._token_manager:
self._token_manager.clear_cache()
self._connected = False
logger.warning(f"[{self.account.email}] Graph API 返回 401client_id 可能无 Graph 权限,跳过")
return []
if resp.status_code != 200:
error_body = resp.text[:200]
self.record_failure(f"HTTP {resp.status_code}: {error_body}")
logger.error(f"[{self.account.email}] Graph API 请求失败: HTTP {resp.status_code}")
return []
data = resp.json()
# 解析邮件
messages = data.get("value", [])
emails = []
for msg in messages:
try:
email_msg = self._parse_graph_message(msg)
if email_msg:
emails.append(email_msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析 Graph API 邮件失败: {e}")
self.record_success()
return emails
except Exception as e:
self.record_failure(str(e))
logger.error(f"[{self.account.email}] Graph API 获取邮件失败: {e}")
return []
def _parse_graph_message(self, msg: dict) -> Optional[EmailMessage]:
"""
解析 Graph API 消息
Args:
msg: Graph API 消息对象
Returns:
EmailMessage 对象
"""
# 解析发件人
from_info = msg.get("from", {})
sender_info = from_info.get("emailAddress", {})
sender = sender_info.get("address", "")
# 解析收件人
recipients = []
for recipient in msg.get("toRecipients", []):
addr_info = recipient.get("emailAddress", {})
addr = addr_info.get("address", "")
if addr:
recipients.append(addr)
# 解析日期
received_at = None
received_timestamp = 0
try:
date_str = msg.get("receivedDateTime", "")
if date_str:
# ISO 8601 格式
received_at = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
received_timestamp = int(received_at.timestamp())
except Exception:
pass
# 获取正文
body_info = msg.get("body", {})
body = body_info.get("content", "")
body_preview = msg.get("bodyPreview", "")
return EmailMessage(
id=msg.get("id", ""),
subject=msg.get("subject", ""),
sender=sender,
recipients=recipients,
body=body,
body_preview=body_preview,
received_at=received_at,
received_timestamp=received_timestamp,
is_read=msg.get("isRead", False),
has_attachments=msg.get("hasAttachments", False),
)
def test_connection(self) -> bool:
"""
测试 Graph API 连接
Returns:
连接是否正常
"""
try:
# 尝试获取一封邮件来测试连接
emails = self.get_recent_emails(count=1, only_unseen=False)
return True
except Exception as e:
logger.warning(f"[{self.account.email}] Graph API 连接测试失败: {e}")
return False

View File

@@ -1,261 +1,200 @@
""" """
新版 IMAP 提供者 新版 IMAP 提供者
使用 outlook.live.com:993 + consumers Token 端点 使用 outlook.live.com 服务器和 login.microsoftonline.com/consumers Token 端点
引入进程级 IMAPConnectionPool 连接复用和 IMAP IDLE
""" """
import email import email
import imaplib import imaplib
import logging import logging
import select
import time
import threading
from datetime import datetime, timedelta, timezone
from email.header import decode_header from email.header import decode_header
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
from typing import Dict, List, Optional from typing import List, Optional
from ..base import EmailMessage from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount from ..account import OutlookAccount
from ..token_manager import TokenManager from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig from .base import OutlookProvider, ProviderConfig
from .imap_old import IMAPOldProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class IMAPConnectionPool:
"""进程级 IMAP 连接池,按 email 复用 IMAP4_SSL 连接"""
IMAP_HOST = "outlook.live.com"
IMAP_PORT = 993
def __init__(self):
self._connections: Dict[str, imaplib.IMAP4_SSL] = {}
self._lock = threading.Lock()
def get_connection(
self,
email_addr: str,
token: str,
timeout: int = 30,
) -> imaplib.IMAP4_SSL:
"""获取或新建 IMAP 连接"""
# 先在锁内检查现有连接
with self._lock:
conn = self._connections.get(email_addr)
if conn:
try:
conn.noop()
return conn
except Exception:
self._close_one(email_addr)
# 标记为「建连中」,防止重复建连
self._connections[email_addr] = None
# 锁外建立新连接(耗时操作不持锁)
try:
new_conn = imaplib.IMAP4_SSL(self.IMAP_HOST, self.IMAP_PORT, timeout=timeout)
auth_str = f"user={email_addr}\x01auth=Bearer {token}\x01\x01"
new_conn.authenticate("XOAUTH2", lambda _: auth_str.encode("utf-8"))
except Exception:
with self._lock:
# 建连失败,清除占位
if self._connections.get(email_addr) is None:
del self._connections[email_addr]
raise
with self._lock:
self._connections[email_addr] = new_conn
logger.debug(f"[{email_addr}] IMAP 新连接已建立")
return new_conn
def invalidate(self, email_addr: str):
"""废弃连接(认证失败或连接异常时调用)"""
with self._lock:
self._close_one(email_addr)
def _close_one(self, email_addr: str):
conn = self._connections.pop(email_addr, None)
if conn:
try:
conn.logout()
except Exception:
pass
# 模块级单例连接池
_imap_pool = IMAPConnectionPool()
class IMAPNewProvider(OutlookProvider): class IMAPNewProvider(OutlookProvider):
""" """
新版 IMAP 提供者 新版 IMAP 提供者
通过连接池复用连接,支持 IMAP IDLE 使用 outlook.live.com:993 和 login.microsoftonline.com/consumers Token 端点
需要 IMAP.AccessAsUser.All scope
""" """
# IMAP 服务器配置
IMAP_HOST = "outlook.live.com" IMAP_HOST = "outlook.live.com"
IMAP_PORT = 993 IMAP_PORT = 993
@property
def provider_type(self) -> ProviderType:
return ProviderType.IMAP_NEW
def __init__( def __init__(
self, self,
account: OutlookAccount, account: OutlookAccount,
config: Optional[ProviderConfig] = None, config: Optional[ProviderConfig] = None,
): ):
super().__init__(account, config) super().__init__(account, config)
self._conn: Optional[imaplib.IMAP4_SSL] = None
self._token_manager: Optional[TokenManager] = None
self._idle_tag_counter = 0
self._idle_tag_lock = threading.Lock()
# IMAP 连接
self._conn: Optional[imaplib.IMAP4_SSL] = None
# Token 管理器
self._token_manager: Optional[TokenManager] = None
# 注意:新版 IMAP 必须使用 OAuth2
if not account.has_oauth(): if not account.has_oauth():
logger.warning( logger.warning(
f"[{self.account.email}] IMAP_NEW 需要 OAuth2 配置 (client_id + refresh_token)" f"[{self.account.email}] 新版 IMAP 提供者需要 OAuth2 配置 "
f"(client_id + refresh_token)"
) )
def _get_token_manager(self) -> TokenManager:
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
proxy_url=self.config.proxy_url,
timeout=self.config.timeout,
service_id=self.config.service_id,
)
return self._token_manager
def _next_idle_tag(self) -> str:
"""生成唯一 IDLE tag避免使用私有 _new_tag"""
with self._idle_tag_lock:
self._idle_tag_counter += 1
return f"IDLE{self._idle_tag_counter:04d}"
def connect(self) -> bool: def connect(self) -> bool:
"""从连接池获取连接""" """
连接到 IMAP 服务器
Returns:
是否连接成功
"""
if self._connected and self._conn:
try:
self._conn.noop()
return True
except Exception:
self.disconnect()
# 新版 IMAP 必须使用 OAuth2无 OAuth 时静默跳过,不记录健康失败
if not self.account.has_oauth(): if not self.account.has_oauth():
logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW无 OAuth") logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW无 OAuth")
return False return False
try: try:
tm = self._get_token_manager() logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
token = tm.get_access_token()
if not token:
logger.error(f"[{self.account.email}] 获取 IMAP Token 失败")
return False
self._conn = _imap_pool.get_connection( # 创建连接
self.account.email, token, self.config.timeout self._conn = imaplib.IMAP4_SSL(
self.IMAP_HOST,
self.IMAP_PORT,
timeout=self.config.timeout,
) )
self._connected = True
self.record_success()
logger.debug(f"[{self.account.email}] IMAP 连接就绪(连接池)")
return True
except imaplib.IMAP4.error as e: # XOAUTH2 认证
err = str(e) if self._authenticate_xoauth2():
# Token 失效时强制刷新并重试一次 self._connected = True
if "AUTHENTICATE" in err or "invalid" in err.lower(): self.record_success()
logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试刷新 Token") logger.info(f"[{self.account.email}] 新版 IMAP 连接成功 (XOAUTH2)")
_imap_pool.invalidate(self.account.email) return True
try:
tm = self._get_token_manager()
token = tm.get_access_token(force_refresh=True)
if token:
self._conn = _imap_pool.get_connection(
self.account.email, token, self.config.timeout
)
self._connected = True
self.record_success()
return True
except Exception as retry_e:
self.record_failure(str(retry_e))
logger.error(f"[{self.account.email}] Token 刷新后重连失败: {retry_e}")
else:
self.record_failure(err)
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
self._connected = False
self._conn = None
return False return False
except Exception as e: except Exception as e:
self.disconnect()
self.record_failure(str(e)) self.record_failure(str(e))
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}") logger.error(f"[{self.account.email}] 新版 IMAP 连接失败: {e}")
self._connected = False return False
self._conn = None
def _authenticate_xoauth2(self) -> bool:
"""
使用 XOAUTH2 认证
Returns:
是否认证成功
"""
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.IMAP_NEW,
self.config.proxy_url,
self.config.timeout,
)
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
logger.error(f"[{self.account.email}] 获取 IMAP Token 失败")
return False
try:
# 构建 XOAUTH2 认证字符串
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
return True
except Exception as e:
logger.error(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
# 清除缓存的 Token
self._token_manager.clear_cache()
return False return False
def disconnect(self): def disconnect(self):
"""归还连接池(不 logout保持复用""" """断开 IMAP 连接"""
if self._conn:
try:
self._conn.close()
except Exception:
pass
try:
self._conn.logout()
except Exception:
pass
self._conn = None
self._connected = False self._connected = False
self._conn = None
def get_recent_emails( def get_recent_emails(
self, self,
count: int = 20, count: int = 20,
only_unseen: bool = True, only_unseen: bool = True,
since_minutes: Optional[int] = None,
folders: Optional[List[str]] = None,
) -> List[EmailMessage]: ) -> List[EmailMessage]:
""" """
获取最近的邮件,支持多文件夹搜索(合并去重)。 获取最近的邮件
搜索策略: Args:
- since_minutes 指定时:用 SINCE 日期 + ALL 搜索最近N分钟内的邮件不受已读/未读限制) count: 获取数量
- only_unseen=True 且未指定 since_minutes搜索 UNSEEN only_unseen: 是否只获取未读
- only_unseen=False 且未指定 since_minutes搜索全部取最近 count 封)
- folders 默认为 ["INBOX"],可传入多个文件夹(如 ["INBOX", "Junk Email"] Returns:
邮件列表
""" """
if not self._connected: if not self._connected:
if not self.connect(): if not self.connect():
return [] return []
if folders is None: try:
folders = ["INBOX"] # 选择收件箱
self._conn.select("INBOX", readonly=True)
all_emails: List[EmailMessage] = [] # 搜索邮件
seen_ids: set = set() flag = "UNSEEN" if only_unseen else "ALL"
status, data = self._conn.search(None, flag)
for folder in folders: if status != "OK" or not data or not data[0]:
try: return []
status, _ = self._conn.select(folder, readonly=True)
if status != "OK":
logger.debug(f"[{self.account.email}] 文件夹 {folder} 不存在或无法访问,跳过")
continue
if since_minutes is not None: # 获取最新的邮件 ID
since_dt = datetime.now(timezone.utc) - timedelta(minutes=since_minutes) ids = data[0].split()
since_str = since_dt.strftime("%d-%b-%Y") recent_ids = ids[-count:][::-1]
status, data = self._conn.search(None, f"SINCE {since_str}")
elif only_unseen:
status, data = self._conn.search(None, "UNSEEN")
else:
status, data = self._conn.search(None, "ALL")
if status != "OK" or not data or not data[0]: emails = []
continue for msg_id in recent_ids:
try:
email_msg = self._fetch_email(msg_id)
if email_msg:
emails.append(email_msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
ids = data[0].split() return emails
recent_ids = ids[-count:][::-1] # 取最新的 count 封,倒序(最新在前)
for msg_id in recent_ids: except Exception as e:
try: self.record_failure(str(e))
msg = self._fetch_email(msg_id) logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
if msg and msg.id not in seen_ids: return []
seen_ids.add(msg.id)
all_emails.append(msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}, folder: {folder}): {e}")
except Exception as e:
self.record_failure(str(e))
logger.warning(f"[{self.account.email}] 搜索文件夹 {folder} 失败: {e}")
_imap_pool.invalidate(self.account.email)
self._connected = False
self._conn = None
break
# 按收信时间降序排列,截取 count 封
all_emails.sort(key=lambda m: m.received_timestamp, reverse=True)
return all_emails[:count]
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]: def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
"""获取并解析单封邮件""" """获取并解析单封邮件"""
@@ -272,193 +211,21 @@ class IMAPNewProvider(OutlookProvider):
if not raw: if not raw:
return None return None
return _parse_email(raw) return self._parse_email(raw)
def wait_for_new_email_idle(self, timeout: int = 25) -> bool: @staticmethod
""" def _parse_email(raw: bytes) -> EmailMessage:
RFC 2177 IMAP IDLE 实现。 """解析原始邮件"""
发送 IDLE 命令,等待服务器推送 EXISTS/RECENT然后发送 DONE。 # 使用旧版提供者的解析方法
Returns True 表示有新邮件推送False 表示超时或异常(调用方降级轮询)。 return IMAPOldProvider._parse_email(raw)
"""
if not self._connected:
if not self.connect():
return False
try:
self._conn.select("INBOX", readonly=True)
except Exception as e:
logger.warning(f"[{self.account.email}] IDLE 前 SELECT 失败: {e}")
return False
tag = self._next_idle_tag()
sock = self._conn.socket()
logger.info(f"[{self.account.email}] 进入 IMAP IDLE 等待模式(超时 {timeout}stag={tag}")
try:
# 发送 IDLE 命令
self._conn.send(f"{tag} IDLE\r\n".encode())
# 等待 "+" 延续响应(服务端确认进入 IDLE
deadline = time.time() + min(10.0, timeout)
buf = b""
got_continuation = False
while time.time() < deadline:
ready = select.select([sock], [], [], min(2.0, deadline - time.time()))
if ready[0]:
chunk = sock.recv(4096)
if not chunk:
break
buf += chunk
if b"+ " in buf or b"+\r\n" in buf:
got_continuation = True
break
if not got_continuation:
logger.warning(f"[{self.account.email}] 未收到 IDLE 延续响应,放弃")
return False
# 等待 EXISTS / RECENT 推送
got_new = False
buf = b""
deadline = time.time() + timeout
while time.time() < deadline:
remaining = deadline - time.time()
if remaining <= 0:
break
ready = select.select([sock], [], [], min(2.0, remaining))
if ready[0]:
chunk = sock.recv(4096)
if not chunk:
break
buf += chunk
if b"EXISTS" in buf or b"RECENT" in buf:
logger.debug(f"[{self.account.email}] IDLE 收到新邮件推送")
got_new = True
break
return got_new
except Exception as e:
logger.warning(f"[{self.account.email}] IMAP IDLE 异常: {e}")
return False
finally:
# 发送 DONE 结束 IDLE并排空服务端响应
try:
self._conn.send(b"DONE\r\n")
drain_deadline = time.time() + 5
drain_buf = b""
tag_end = f"{tag} OK".encode()
tag_no = f"{tag} NO".encode()
tag_bad = f"{tag} BAD".encode()
while time.time() < drain_deadline:
ready = select.select([sock], [], [], 1.0)
if not ready[0]:
break
chunk = sock.recv(4096)
if not chunk:
break
drain_buf += chunk
if any(t in drain_buf for t in (tag_end, tag_no, tag_bad)):
break
except Exception:
_imap_pool.invalidate(self.account.email)
self._connected = False
self._conn = None
def test_connection(self) -> bool: def test_connection(self) -> bool:
"""测试 IMAP 连接""" """测试 IMAP 连接"""
try: try:
with self: with self:
self._conn.select("INBOX", readonly=True) self._conn.select("INBOX", readonly=True)
self._conn.search(None, "ALL")
return True return True
except Exception as e: except Exception as e:
logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}") logger.warning(f"[{self.account.email}] 新版 IMAP 连接测试失败: {e}")
return False return False
def _parse_email(raw: bytes) -> EmailMessage:
"""解析原始邮件为 EmailMessage优先 text/plain次选 text/html"""
msg = email.message_from_bytes(raw)
def _decode(val):
if not val:
return ""
parts = decode_header(str(val))
result = ""
for part, charset in parts:
if isinstance(part, bytes):
try:
result += part.decode(charset or "utf-8", errors="replace")
except (LookupError, UnicodeDecodeError):
result += part.decode("utf-8", errors="replace")
else:
result += str(part)
return result
subject = _decode(msg.get("Subject", ""))
sender = _decode(msg.get("From", ""))
recipients = [_decode(msg.get("To", ""))]
received_at = None
received_ts = 0
date_str = msg.get("Date", "")
if date_str:
try:
received_at = parsedate_to_datetime(date_str)
received_ts = int(received_at.timestamp())
except Exception:
pass
# 提取正文:优先 text/plain次选 text/html
plain_body = ""
html_body = ""
if msg.is_multipart():
for part in msg.walk():
ct = part.get_content_type()
cd = str(part.get("Content-Disposition", ""))
if "attachment" in cd.lower():
continue
try:
charset = part.get_content_charset() or "utf-8"
payload = part.get_payload(decode=True)
if not payload:
continue
decoded = payload.decode(charset, errors="replace")
if ct == "text/plain" and not plain_body:
plain_body = decoded
elif ct == "text/html" and not html_body:
html_body = decoded
except Exception:
pass
else:
try:
charset = msg.get_content_charset() or "utf-8"
payload = msg.get_payload(decode=True)
if payload:
ct = msg.get_content_type()
decoded = payload.decode(charset, errors="replace")
if ct == "text/plain":
plain_body = decoded
else:
html_body = decoded
except Exception:
pass
body = plain_body or html_body
body_preview = body[:200].strip()
msg_id = msg.get("Message-ID", "").strip("<>")
if not msg_id:
msg_id = f"{sender}_{received_ts}"
return EmailMessage(
id=msg_id,
subject=subject,
sender=sender,
recipients=recipients,
body=body,
body_preview=body_preview,
received_at=received_at,
received_timestamp=received_ts,
)

View File

@@ -0,0 +1,345 @@
"""
旧版 IMAP 提供者
使用 outlook.office365.com 服务器和 login.live.com Token 端点
"""
import email
import imaplib
import logging
from email.header import decode_header
from email.utils import parsedate_to_datetime
from typing import List, Optional
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
logger = logging.getLogger(__name__)
class IMAPOldProvider(OutlookProvider):
"""
旧版 IMAP 提供者
使用 outlook.office365.com:993 和 login.live.com Token 端点
"""
# IMAP 服务器配置
IMAP_HOST = "outlook.office365.com"
IMAP_PORT = 993
@property
def provider_type(self) -> ProviderType:
return ProviderType.IMAP_OLD
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
super().__init__(account, config)
# IMAP 连接
self._conn: Optional[imaplib.IMAP4_SSL] = None
# Token 管理器
self._token_manager: Optional[TokenManager] = None
def connect(self) -> bool:
"""
连接到 IMAP 服务器
Returns:
是否连接成功
"""
if self._connected and self._conn:
# 检查现有连接
try:
self._conn.noop()
return True
except Exception:
self.disconnect()
try:
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
# 创建连接
self._conn = imaplib.IMAP4_SSL(
self.IMAP_HOST,
self.IMAP_PORT,
timeout=self.config.timeout,
)
# 尝试 XOAUTH2 认证
if self.account.has_oauth():
if self._authenticate_xoauth2():
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] IMAP 连接成功 (XOAUTH2)")
return True
else:
logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试密码认证")
# 密码认证
if self.account.password:
self._conn.login(self.account.email, self.account.password)
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] IMAP 连接成功 (密码认证)")
return True
raise ValueError("没有可用的认证方式")
except Exception as e:
self.disconnect()
self.record_failure(str(e))
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
return False
def _authenticate_xoauth2(self) -> bool:
"""
使用 XOAUTH2 认证
Returns:
是否认证成功
"""
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.IMAP_OLD,
self.config.proxy_url,
self.config.timeout,
)
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
return False
try:
# 构建 XOAUTH2 认证字符串
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
return True
except Exception as e:
logger.debug(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
# 清除缓存的 Token
self._token_manager.clear_cache()
return False
def disconnect(self):
"""断开 IMAP 连接"""
if self._conn:
try:
self._conn.close()
except Exception:
pass
try:
self._conn.logout()
except Exception:
pass
self._conn = None
self._connected = False
def get_recent_emails(
self,
count: int = 20,
only_unseen: bool = True,
) -> List[EmailMessage]:
"""
获取最近的邮件
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
if not self._connected:
if not self.connect():
return []
try:
# 选择收件箱
self._conn.select("INBOX", readonly=True)
# 搜索邮件
flag = "UNSEEN" if only_unseen else "ALL"
status, data = self._conn.search(None, flag)
if status != "OK" or not data or not data[0]:
return []
# 获取最新的邮件 ID
ids = data[0].split()
recent_ids = ids[-count:][::-1] # 倒序,最新的在前
emails = []
for msg_id in recent_ids:
try:
email_msg = self._fetch_email(msg_id)
if email_msg:
emails.append(email_msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
return emails
except Exception as e:
self.record_failure(str(e))
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
return []
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
"""
获取并解析单封邮件
Args:
msg_id: 邮件 ID
Returns:
EmailMessage 对象,失败返回 None
"""
status, data = self._conn.fetch(msg_id, "(RFC822)")
if status != "OK" or not data or not data[0]:
return None
# 获取原始邮件内容
raw = b""
for part in data:
if isinstance(part, tuple) and len(part) > 1:
raw = part[1]
break
if not raw:
return None
return self._parse_email(raw)
@staticmethod
def _parse_email(raw: bytes) -> EmailMessage:
"""
解析原始邮件
Args:
raw: 原始邮件数据
Returns:
EmailMessage 对象
"""
# 移除 BOM
if raw.startswith(b"\xef\xbb\xbf"):
raw = raw[3:]
msg = email.message_from_bytes(raw)
# 解析邮件头
subject = IMAPOldProvider._decode_header(msg.get("Subject", ""))
sender = IMAPOldProvider._decode_header(msg.get("From", ""))
to = IMAPOldProvider._decode_header(msg.get("To", ""))
delivered_to = IMAPOldProvider._decode_header(msg.get("Delivered-To", ""))
x_original_to = IMAPOldProvider._decode_header(msg.get("X-Original-To", ""))
date_str = IMAPOldProvider._decode_header(msg.get("Date", ""))
# 提取正文
body = IMAPOldProvider._extract_body(msg)
# 解析日期
received_timestamp = 0
received_at = None
try:
if date_str:
received_at = parsedate_to_datetime(date_str)
received_timestamp = int(received_at.timestamp())
except Exception:
pass
# 构建收件人列表
recipients = [r for r in [to, delivered_to, x_original_to] if r]
return EmailMessage(
id=msg.get("Message-ID", ""),
subject=subject,
sender=sender,
recipients=recipients,
body=body,
received_at=received_at,
received_timestamp=received_timestamp,
is_read=False, # 搜索的是未读邮件
raw_data=raw[:500] if len(raw) > 500 else raw,
)
@staticmethod
def _decode_header(header: str) -> str:
"""解码邮件头"""
if not header:
return ""
parts = []
for chunk, encoding in decode_header(header):
if isinstance(chunk, bytes):
try:
decoded = chunk.decode(encoding or "utf-8", errors="replace")
parts.append(decoded)
except Exception:
parts.append(chunk.decode("utf-8", errors="replace"))
else:
parts.append(str(chunk))
return "".join(parts).strip()
@staticmethod
def _extract_body(msg) -> str:
"""提取邮件正文"""
import html as html_module
import re
texts = []
parts = msg.walk() if msg.is_multipart() else [msg]
for part in parts:
content_type = part.get_content_type()
if content_type not in ("text/plain", "text/html"):
continue
payload = part.get_payload(decode=True)
if not payload:
continue
charset = part.get_content_charset() or "utf-8"
try:
text = payload.decode(charset, errors="replace")
except LookupError:
text = payload.decode("utf-8", errors="replace")
# 如果是 HTML移除标签
if "<html" in text.lower():
text = re.sub(r"<[^>]+>", " ", text)
texts.append(text)
# 合并并清理文本
combined = " ".join(texts)
combined = html_module.unescape(combined)
combined = re.sub(r"\s+", " ", combined).strip()
return combined
def test_connection(self) -> bool:
"""
测试 IMAP 连接
Returns:
连接是否正常
"""
try:
with self:
self._conn.select("INBOX", readonly=True)
self._conn.search(None, "ALL")
return True
except Exception as e:
logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}")
return False

View File

@@ -1,6 +1,6 @@
""" """
Outlook 邮箱服务主类(简化版) Outlook 邮箱服务主类
单一 IMAP_NEW Provider + 邮件缓存 + IMAP IDLE 支持 支持多种 IMAP/API 连接方式,自动故障切换
""" """
import logging import logging
@@ -8,24 +8,34 @@ import threading
import time import time
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from ..base import BaseEmailService, EmailServiceError, EmailServiceType from ..base import BaseEmailService, EmailServiceError, EmailServiceStatus, EmailServiceType
from ...config.constants import EmailServiceType as ServiceType from ...config.constants import EmailServiceType as ServiceType
from ...config.settings import get_settings from ...config.settings import get_settings
from .account import OutlookAccount from .account import OutlookAccount
from .base import EmailMessage from .base import ProviderType, EmailMessage
from .email_parser import get_email_parser from .email_parser import EmailParser, get_email_parser
from .health_checker import HealthChecker from .health_checker import HealthChecker, FailoverManager
from .providers.base import ProviderConfig from .providers.base import OutlookProvider, ProviderConfig
from .providers.imap_old import IMAPOldProvider
from .providers.imap_new import IMAPNewProvider from .providers.imap_new import IMAPNewProvider
from .providers.graph_api import GraphAPIProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 验证码搜索的文件夹列表(同时搜索收件箱和垃圾箱)
_OUTLOOK_SEARCH_FOLDERS = ["INBOX", "Junk Email"] # 默认提供者优先级
# IMAP_OLD 最兼容(只需 login.live.com tokenIMAP_NEW 次之Graph API 最后
# 原因:部分 client_id 没有 Graph API 权限,但有 IMAP 权限
DEFAULT_PROVIDER_PRIORITY = [
ProviderType.IMAP_OLD,
ProviderType.IMAP_NEW,
ProviderType.GRAPH_API,
]
def _get_code_settings() -> dict: def get_email_code_settings() -> dict:
"""获取验证码等待配置"""
settings = get_settings() settings = get_settings()
return { return {
"timeout": settings.email_code_timeout, "timeout": settings.email_code_timeout,
@@ -33,58 +43,56 @@ def _get_code_settings() -> dict:
} }
class _EmailCache:
"""轻量级邮件内存缓存TTL=60s减少重复 IMAP 请求)"""
TTL = 60
def __init__(self):
self._cache: Dict[str, tuple] = {} # email -> (timestamp, List[EmailMessage])
self._lock = threading.Lock()
def get(self, email: str) -> Optional[List[EmailMessage]]:
with self._lock:
entry = self._cache.get(email)
if entry and time.time() - entry[0] < self.TTL:
return entry[1]
return None
def set(self, email: str, messages: List[EmailMessage]):
with self._lock:
self._cache[email] = (time.time(), messages)
def invalidate(self, email: str):
with self._lock:
self._cache.pop(email, None)
class OutlookService(BaseEmailService): class OutlookService(BaseEmailService):
""" """
Outlook 邮箱服务 Outlook 邮箱服务
使用单一 IMAP_NEW Provider支持连接池复用和 IMAP IDLE 支持多种 IMAP/API 连接方式,自动故障切换
""" """
def __init__(self, config: Dict[str, Any] = None, name: str = None): def __init__(self, config: Dict[str, Any] = None, name: str = None):
"""
初始化 Outlook 服务
Args:
config: 配置字典,支持以下键:
- accounts: Outlook 账户列表
- provider_priority: 提供者优先级列表
- health_failure_threshold: 连续失败次数阈值
- health_disable_duration: 禁用时长(秒)
- timeout: 请求超时时间
- proxy_url: 代理 URL
name: 服务名称
"""
super().__init__(ServiceType.OUTLOOK, name) super().__init__(ServiceType.OUTLOOK, name)
# 默认配置
default_config = { default_config = {
"accounts": [], "accounts": [],
"provider_priority": [p.value for p in DEFAULT_PROVIDER_PRIORITY],
"health_failure_threshold": 5, "health_failure_threshold": 5,
"health_disable_duration": 60, "health_disable_duration": 60,
"timeout": 30, "timeout": 30,
"proxy_url": None, "proxy_url": None,
} }
self.config = {**default_config, **(config or {})} self.config = {**default_config, **(config or {})}
# 解析提供者优先级
self.provider_priority = [
ProviderType(p) for p in self.config.get("provider_priority", [])
]
if not self.provider_priority:
self.provider_priority = DEFAULT_PROVIDER_PRIORITY
# 提供者配置
self.provider_config = ProviderConfig( self.provider_config = ProviderConfig(
timeout=self.config.get("timeout", 30), timeout=self.config.get("timeout", 30),
proxy_url=self.config.get("proxy_url"), proxy_url=self.config.get("proxy_url"),
service_id=self.config.get("service_id"),
health_failure_threshold=self.config.get("health_failure_threshold", 3), health_failure_threshold=self.config.get("health_failure_threshold", 3),
health_disable_duration=self.config.get("health_disable_duration", 300), health_disable_duration=self.config.get("health_disable_duration", 300),
) )
# 获取默认 client_id # 获取默认 client_id(供无 client_id 的账户使用)
try: try:
_default_client_id = get_settings().outlook_default_client_id _default_client_id = get_settings().outlook_default_client_id
except Exception: except Exception:
@@ -95,121 +103,194 @@ class OutlookService(BaseEmailService):
self._current_account_index = 0 self._current_account_index = 0
self._account_lock = threading.Lock() self._account_lock = threading.Lock()
# 支持两种配置格式
if "email" in self.config and "password" in self.config: if "email" in self.config and "password" in self.config:
account = OutlookAccount.from_config(self.config) account = OutlookAccount.from_config(self.config)
if not account.client_id and _default_client_id: if not account.client_id and _default_client_id:
account.client_id = _default_client_id account.client_id = _default_client_id
if account.validate(): if account.validate():
if not account.has_oauth(): self.accounts.append(account)
logger.warning(
f"[{account.email}] 跳过IMAP_NEW 仅支持 OAuth2"
f"请配置 client_id 和 refresh_token"
)
else:
self.accounts.append(account)
else: else:
for ac in self.config.get("accounts", []): for account_config in self.config.get("accounts", []):
account = OutlookAccount.from_config(ac) account = OutlookAccount.from_config(account_config)
if not account.client_id and _default_client_id: if not account.client_id and _default_client_id:
account.client_id = _default_client_id account.client_id = _default_client_id
if account.validate(): if account.validate():
if not account.has_oauth(): self.accounts.append(account)
logger.warning(
f"[{account.email}] 跳过IMAP_NEW 仅支持 OAuth2"
f"请配置 client_id 和 refresh_token"
)
else:
self.accounts.append(account)
if not self.accounts: if not self.accounts:
logger.warning("未配置有效的 Outlook 账户(需要 client_id + refresh_token") logger.warning("未配置有效的 Outlook 账户")
# 健康检查器 # 健康检查器和故障切换管理器
self.health_checker = HealthChecker( self.health_checker = HealthChecker(
failure_threshold=self.provider_config.health_failure_threshold, failure_threshold=self.provider_config.health_failure_threshold,
disable_duration=self.provider_config.health_disable_duration, disable_duration=self.provider_config.health_disable_duration,
) )
self.failover_manager = FailoverManager(
health_checker=self.health_checker,
priority_order=self.provider_priority,
)
# 邮件解析器 # 邮件解析器
self.email_parser = get_email_parser() self.email_parser = get_email_parser()
# Provider 实例缓存: email -> IMAPNewProvider # 提供者实例缓存: (email, provider_type) -> OutlookProvider
self._providers: Dict[str, IMAPNewProvider] = {} self._providers: Dict[tuple, OutlookProvider] = {}
self._provider_lock = threading.Lock() self._provider_lock = threading.Lock()
# IMAP 并发限制(最多 5 个并发 # IMAP 连接限制(防止限流
self._imap_semaphore = threading.Semaphore(5) self._imap_semaphore = threading.Semaphore(5)
# 邮件缓存 # 验证码去重机制
self._email_cache = _EmailCache()
# 验证码去重
self._used_codes: Dict[str, set] = {} self._used_codes: Dict[str, set] = {}
def _get_provider(self, account: OutlookAccount) -> IMAPNewProvider: def _get_provider(
key = account.email.lower()
with self._provider_lock:
if key not in self._providers:
self._providers[key] = IMAPNewProvider(account, self.provider_config)
return self._providers[key]
def _fetch_emails(
self, self,
account: OutlookAccount, account: OutlookAccount,
count: int = 15, provider_type: ProviderType,
) -> OutlookProvider:
"""
获取或创建提供者实例
Args:
account: Outlook 账户
provider_type: 提供者类型
Returns:
提供者实例
"""
cache_key = (account.email.lower(), provider_type)
with self._provider_lock:
if cache_key not in self._providers:
provider = self._create_provider(account, provider_type)
self._providers[cache_key] = provider
return self._providers[cache_key]
def _create_provider(
self,
account: OutlookAccount,
provider_type: ProviderType,
) -> OutlookProvider:
"""
创建提供者实例
Args:
account: Outlook 账户
provider_type: 提供者类型
Returns:
提供者实例
"""
if provider_type == ProviderType.IMAP_OLD:
return IMAPOldProvider(account, self.provider_config)
elif provider_type == ProviderType.IMAP_NEW:
return IMAPNewProvider(account, self.provider_config)
elif provider_type == ProviderType.GRAPH_API:
return GraphAPIProvider(account, self.provider_config)
else:
raise ValueError(f"未知的提供者类型: {provider_type}")
def _get_provider_priority_for_account(self, account: OutlookAccount) -> List[ProviderType]:
"""根据账户是否有 OAuth返回适合的提供者优先级列表"""
if account.has_oauth():
return self.provider_priority
else:
# 无 OAuth直接走旧版 IMAP密码认证跳过需要 OAuth 的提供者
return [ProviderType.IMAP_OLD]
def _try_providers_for_emails(
self,
account: OutlookAccount,
count: int = 20,
only_unseen: bool = True, only_unseen: bool = True,
since_minutes: Optional[int] = None,
use_cache: bool = False,
folders: Optional[List[str]] = None,
) -> List[EmailMessage]: ) -> List[EmailMessage]:
"""通过 IMAP_NEW Provider 获取邮件,可选使用内存缓存""" """
if use_cache: 尝试多个提供者获取邮件
cached = self._email_cache.get(account.email)
if cached is not None:
return cached
if not self.health_checker.is_available(): Args:
logger.debug(f"[{account.email}] IMAP_NEW 不可用,跳过") account: Outlook 账户
return [] count: 获取数量
only_unseen: 是否只获取未读
try: Returns:
provider = self._get_provider(account) 邮件列表
with self._imap_semaphore: """
with provider: errors = []
emails = provider.get_recent_emails(
count, only_unseen, since_minutes=since_minutes, folders=folders
)
if emails: # 根据账户类型选择合适的提供者优先级
self.health_checker.record_success() priority = self._get_provider_priority_for_account(account)
if use_cache:
self._email_cache.set(account.email, emails)
return emails
except Exception as e: # 按优先级尝试各提供者
err = str(e) for provider_type in priority:
self.health_checker.record_failure(err) # 检查提供者是否可用
logger.warning(f"[{account.email}] 获取邮件失败: {e}") if not self.health_checker.is_available(provider_type):
return [] logger.debug(
f"[{account.email}] {provider_type.value} 不可用,跳过"
)
continue
try:
provider = self._get_provider(account, provider_type)
with self._imap_semaphore:
with provider:
emails = provider.get_recent_emails(count, only_unseen)
if emails:
# 成功获取邮件
self.health_checker.record_success(provider_type)
logger.debug(
f"[{account.email}] {provider_type.value} 获取到 {len(emails)} 封邮件"
)
return emails
except Exception as e:
error_msg = str(e)
errors.append(f"{provider_type.value}: {error_msg}")
self.health_checker.record_failure(provider_type, error_msg)
logger.warning(
f"[{account.email}] {provider_type.value} 获取邮件失败: {e}"
)
logger.error(
f"[{account.email}] 所有提供者都失败: {'; '.join(errors)}"
)
return []
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]: def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
"""轮询选择可用的 Outlook 账户""" """
选择可用的 Outlook 账户
Args:
config: 配置参数(未使用)
Returns:
包含邮箱信息的字典
"""
if not self.accounts: if not self.accounts:
self.update_status(False, EmailServiceError("没有可用的 Outlook 账户")) self.update_status(False, EmailServiceError("没有可用的 Outlook 账户"))
raise EmailServiceError("没有可用的 Outlook 账户") raise EmailServiceError("没有可用的 Outlook 账户")
# 轮询选择账户
with self._account_lock: with self._account_lock:
account = self.accounts[self._current_account_index] account = self.accounts[self._current_account_index]
self._current_account_index = (self._current_account_index + 1) % len(self.accounts) self._current_account_index = (self._current_account_index + 1) % len(self.accounts)
logger.info(f"选择 Outlook 账户: {account.email}") email_info = {
self.update_status(True)
return {
"email": account.email, "email": account.email,
"service_id": account.email, "service_id": account.email,
"account": {"email": account.email, "has_oauth": account.has_oauth()}, "account": {
"email": account.email,
"has_oauth": account.has_oauth()
}
} }
logger.info(f"选择 Outlook 账户: {account.email}")
self.update_status(True)
return email_info
def get_verification_code( def get_verification_code(
self, self,
email: str, email: str,
@@ -218,185 +299,114 @@ class OutlookService(BaseEmailService):
pattern: str = None, pattern: str = None,
otp_sent_at: Optional[float] = None, otp_sent_at: Optional[float] = None,
) -> Optional[str]: ) -> Optional[str]:
"""从 Outlook 邮箱获取验证码""" """
account = next( 从 Outlook 邮箱获取验证码
(a for a in self.accounts if a.email.lower() == email.lower()), None
) Args:
email: 邮箱地址
email_id: 未使用
timeout: 超时时间(秒)
pattern: 验证码正则表达式(未使用)
otp_sent_at: OTP 发送时间戳
Returns:
验证码字符串
"""
# 查找对应的账户
account = None
for acc in self.accounts:
if acc.email.lower() == email.lower():
account = acc
break
if not account: if not account:
self.update_status(False, EmailServiceError(f"未找到邮箱账户: {email}")) self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}"))
return None return None
code_settings = _get_code_settings() # 获取验证码等待配置
code_settings = get_email_code_settings()
actual_timeout = timeout or code_settings["timeout"] actual_timeout = timeout or code_settings["timeout"]
poll_interval = code_settings["poll_interval"] poll_interval = code_settings["poll_interval"]
logger.info(f"[{email}] 开始获取验证码,超时 {actual_timeout}s") logger.info(
f"[{email}] 开始获取验证码,超时 {actual_timeout}s"
f"提供者优先级: {[p.value for p in self.provider_priority]}"
)
# 初始化验证码去重集合
if email not in self._used_codes: if email not in self._used_codes:
self._used_codes[email] = set() self._used_codes[email] = set()
used_codes = self._used_codes[email] used_codes = self._used_codes[email]
# 计算最小时间戳(留出 60 秒时钟偏差)
min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0 min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0
use_idle = True
try:
use_idle = get_settings().outlook_use_idle
except Exception:
pass
if use_idle:
code = self._wait_with_idle(
account, email, actual_timeout, min_timestamp, used_codes, otp_sent_at
)
else:
code = self._wait_with_poll(
account, email, actual_timeout, poll_interval, min_timestamp, used_codes, otp_sent_at
)
if code:
used_codes.add(code)
self.update_status(True)
return code
return None
def _wait_with_poll(
self,
account: OutlookAccount,
email: str,
timeout: int,
poll_interval: int,
min_timestamp: float,
used_codes: set,
otp_sent_at: Optional[float] = None,
) -> Optional[str]:
"""轮询方式等待验证码"""
start_time = time.time() start_time = time.time()
poll_count = 0 poll_count = 0
while time.time() - start_time < timeout: while time.time() - start_time < actual_timeout:
poll_count += 1 poll_count += 1
# 每次动态计算 since_minutes确保时间窗口随轮询推进而更新
if otp_sent_at: # 渐进式邮件检查:前 3 次只检查未读
elapsed_since_send = int((time.time() - otp_sent_at) / 60) + 2 only_unseen = poll_count <= 3
since_minutes: Optional[int] = min(elapsed_since_send, 180)
only_unseen = False
else:
since_minutes = None
only_unseen = poll_count <= 3
try: try:
emails = self._fetch_emails( # 尝试多个提供者获取邮件
account, count=15, only_unseen=only_unseen, emails = self._try_providers_for_emails(
since_minutes=since_minutes, account,
folders=_OUTLOOK_SEARCH_FOLDERS, count=15,
only_unseen=only_unseen,
) )
if emails: if emails:
logger.debug(
f"[{email}] 第 {poll_count} 次轮询获取到 {len(emails)} 封邮件"
)
# 从邮件中查找验证码
code = self.email_parser.find_verification_code_in_emails( code = self.email_parser.find_verification_code_in_emails(
emails, emails,
target_email=email, target_email=email,
min_timestamp=min_timestamp, min_timestamp=min_timestamp,
used_codes=used_codes, used_codes=used_codes,
) )
if code: if code:
used_codes.add(code)
elapsed = int(time.time() - start_time) elapsed = int(time.time() - start_time)
logger.info( logger.info(
f"[{email}] 找到验证码: {code}耗时 {elapsed}s轮询 {poll_count}" f"[{email}] 找到验证码: {code}"
f"总耗时 {elapsed}s轮询 {poll_count}"
) )
self.update_status(True)
return code return code
except Exception as e:
logger.warning(f"[{email}] 轮询出错: {e}")
except Exception as e:
logger.warning(f"[{email}] 检查出错: {e}")
# 等待下次轮询
time.sleep(poll_interval) time.sleep(poll_interval)
logger.warning(f"[{email}] 验证码超时 ({timeout}s),共轮询 {poll_count}") elapsed = int(time.time() - start_time)
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count}")
return None return None
def _wait_with_idle( def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
self, """列出所有可用的 Outlook 账户"""
account: OutlookAccount, return [
email: str, {
timeout: int, "email": account.email,
min_timestamp: float, "id": account.email,
used_codes: set, "has_oauth": account.has_oauth(),
otp_sent_at: Optional[float] = None, "type": "outlook"
) -> Optional[str]: }
"""IMAP IDLE 方式等待验证码,失败时自动降级为轮询""" for account in self.accounts
if not self.health_checker.is_available(): ]
logger.warning(f"[{email}] IMAP_NEW 不可用,降级为轮询")
return self._wait_with_poll(
account, email, timeout, 3, min_timestamp, used_codes, otp_sent_at
)
# 计算 since_minutes从发送时间前2分钟开始最多180分钟 def delete_email(self, email_id: str) -> bool:
since_minutes: Optional[int] = None """删除邮箱Outlook 不支持删除账户)"""
if otp_sent_at: logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
elapsed_since_send = int((time.time() - otp_sent_at) / 60) + 2 return False
since_minutes = min(elapsed_since_send, 180)
start_time = time.time()
try:
provider = self._get_provider(account)
with self._imap_semaphore:
with provider:
# 先做一次即时检查
emails = provider.get_recent_emails(
15, only_unseen=(since_minutes is None), since_minutes=since_minutes,
folders=_OUTLOOK_SEARCH_FOLDERS,
)
code = self.email_parser.find_verification_code_in_emails(
emails,
target_email=email,
min_timestamp=min_timestamp,
used_codes=used_codes,
)
if code:
elapsed = int(time.time() - start_time)
logger.info(f"[{email}] 找到验证码: {code},耗时 {elapsed}s即时检查")
return code
# IDLE 等待循环
while time.time() - start_time < timeout:
remaining = int(timeout - (time.time() - start_time))
if remaining <= 0:
break
arrived = provider.wait_for_new_email_idle(timeout=min(remaining, 25))
# 无效化缓存,强制重新拉取
self._email_cache.invalidate(email)
# IDLE 触发后用 since_minutes 搜索,覆盖已读邮件
fetch_since = since_minutes
if fetch_since is None:
# 没有 otp_sent_at 时用距当前时间2分钟内的邮件
fetch_since = 2
emails = provider.get_recent_emails(
15, only_unseen=False, since_minutes=fetch_since,
folders=_OUTLOOK_SEARCH_FOLDERS,
)
code = self.email_parser.find_verification_code_in_emails(
emails,
target_email=email,
min_timestamp=min_timestamp,
used_codes=used_codes,
)
if code:
elapsed = int(time.time() - start_time)
logger.info(
f"[{email}] 找到验证码: {code},耗时 {elapsed}s"
f"IDLE {'推送' if arrived else '超时检查'}"
)
return code
except Exception as e:
logger.warning(f"[{email}] IDLE 失败,降级为轮询: {e}")
elapsed = int(time.time() - start_time)
remaining = max(0, timeout - elapsed)
if remaining > 0:
code_settings = _get_code_settings()
return self._wait_with_poll(
account, email, remaining,
code_settings["poll_interval"], min_timestamp, used_codes, otp_sent_at
)
logger.warning(f"[{email}] IDLE 等待验证码超时 ({timeout}s)")
return None
def check_health(self) -> bool: def check_health(self) -> bool:
"""检查 Outlook 服务是否可用""" """检查 Outlook 服务是否可用"""
@@ -404,48 +414,48 @@ class OutlookService(BaseEmailService):
self.update_status(False, EmailServiceError("没有配置的账户")) self.update_status(False, EmailServiceError("没有配置的账户"))
return False return False
try: # 测试第一个账户的连接
provider = self._get_provider(self.accounts[0]) test_account = self.accounts[0]
if provider.test_connection():
self.update_status(True) # 尝试任一提供者连接
return True for provider_type in self.provider_priority:
except Exception as e: try:
logger.warning(f"Outlook 健康检查失败: {e}") provider = self._get_provider(test_account, provider_type)
if provider.test_connection():
self.update_status(True)
return True
except Exception as e:
logger.warning(
f"Outlook 健康检查失败 ({test_account.email}, {provider_type.value}): {e}"
)
self.update_status(False, EmailServiceError("健康检查失败")) self.update_status(False, EmailServiceError("健康检查失败"))
return False return False
def list_emails(self, **kwargs) -> List[Dict[str, Any]]: def get_provider_status(self) -> Dict[str, Any]:
return [ """获取提供者状态"""
{ return self.failover_manager.get_status()
"email": a.email,
"id": a.email,
"has_oauth": a.has_oauth(),
"type": "outlook",
}
for a in self.accounts
]
def delete_email(self, email_id: str) -> bool:
logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
return False
def get_account_stats(self) -> Dict[str, Any]: def get_account_stats(self) -> Dict[str, Any]:
"""获取账户统计信息"""
total = len(self.accounts) total = len(self.accounts)
oauth_count = sum(1 for a in self.accounts if a.has_oauth()) oauth_count = sum(1 for acc in self.accounts if acc.has_oauth())
return { return {
"total_accounts": total, "total_accounts": total,
"oauth_accounts": oauth_count, "oauth_accounts": oauth_count,
"password_accounts": total - oauth_count, "password_accounts": total - oauth_count,
"accounts": [a.to_dict() for a in self.accounts], "accounts": [acc.to_dict() for acc in self.accounts],
"health_status": self.health_checker.get_status(), "provider_status": self.get_provider_status(),
} }
def add_account(self, account_config: Dict[str, Any]) -> bool: def add_account(self, account_config: Dict[str, Any]) -> bool:
"""添加新的 Outlook 账户"""
try: try:
account = OutlookAccount.from_config(account_config) account = OutlookAccount.from_config(account_config)
if not account.validate(): if not account.validate():
return False return False
self.accounts.append(account) self.accounts.append(account)
logger.info(f"添加 Outlook 账户: {account.email}") logger.info(f"添加 Outlook 账户: {account.email}")
return True return True
@@ -454,13 +464,24 @@ class OutlookService(BaseEmailService):
return False return False
def remove_account(self, email: str) -> bool: def remove_account(self, email: str) -> bool:
for i, a in enumerate(self.accounts): """移除 Outlook 账户"""
if a.email.lower() == email.lower(): for i, acc in enumerate(self.accounts):
if acc.email.lower() == email.lower():
self.accounts.pop(i) self.accounts.pop(i)
logger.info(f"移除 Outlook 账户: {email}") logger.info(f"移除 Outlook 账户: {email}")
return True return True
return False return False
def reset_health(self): def reset_provider_health(self):
self.health_checker.reset() """重置所有提供者的健康状态"""
logger.info("已重置 IMAP_NEW 健康状态") self.health_checker.reset_all()
logger.info("已重置所有提供者的健康状态")
def force_provider(self, provider_type: ProviderType):
"""强制使用指定的提供者"""
self.health_checker.force_enable(provider_type)
# 禁用其他提供者
for pt in ProviderType:
if pt != provider_type:
self.health_checker.force_disable(pt, 60)
logger.info(f"已强制使用提供者: {provider_type.value}")

View File

@@ -1,6 +1,6 @@
""" """
Token 管理器(简化版) Token 管理器
固定使用 consumers 端点 + IMAP scope 支持多个 Microsoft Token 端点,自动选择合适的端点
""" """
import json import json
@@ -11,98 +11,153 @@ from typing import Dict, Optional, Any
from curl_cffi import requests as _requests from curl_cffi import requests as _requests
from .base import TokenInfo from .base import ProviderType, TokenEndpoint, TokenInfo
from .account import OutlookAccount from .account import OutlookAccount
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TOKEN_URL = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
IMAP_SCOPE = "https://outlook.office.com/IMAP.AccessAsUser.All offline_access" # 各提供者的 Scope 配置
PROVIDER_SCOPES = {
ProviderType.IMAP_OLD: "", # 旧版 IMAP 不需要特定 scope
ProviderType.IMAP_NEW: "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
ProviderType.GRAPH_API: "https://graph.microsoft.com/.default",
}
# 各提供者的 Token 端点
PROVIDER_TOKEN_URLS = {
ProviderType.IMAP_OLD: TokenEndpoint.LIVE.value,
ProviderType.IMAP_NEW: TokenEndpoint.CONSUMERS.value,
ProviderType.GRAPH_API: TokenEndpoint.COMMON.value,
}
class TokenManager: class TokenManager:
""" """
Token 管理器 Token 管理器
固定 consumers 端点,缓存 key = email 支持多端点 Token 获取和缓存
""" """
_token_cache: Dict[str, TokenInfo] = {} # Token 缓存: key = (email, provider_type) -> TokenInfo
_token_cache: Dict[tuple, TokenInfo] = {}
_cache_lock = threading.Lock() _cache_lock = threading.Lock()
# 默认超时时间
DEFAULT_TIMEOUT = 30 DEFAULT_TIMEOUT = 30
# Token 刷新提前时间(秒)
REFRESH_BUFFER = 120 REFRESH_BUFFER = 120
def __init__( def __init__(
self, self,
account: OutlookAccount, account: OutlookAccount,
provider_type: ProviderType,
proxy_url: Optional[str] = None, proxy_url: Optional[str] = None,
timeout: int = DEFAULT_TIMEOUT, timeout: int = DEFAULT_TIMEOUT,
service_id: Optional[int] = None,
): ):
"""
初始化 Token 管理器
Args:
account: Outlook 账户
provider_type: 提供者类型
proxy_url: 代理 URL可选
timeout: 请求超时时间
"""
self.account = account self.account = account
self.provider_type = provider_type
self.proxy_url = proxy_url self.proxy_url = proxy_url
self.timeout = timeout self.timeout = timeout
self.service_id = service_id
def _cache_key(self) -> str: # 获取端点和 Scope
return self.account.email.lower() self.token_url = PROVIDER_TOKEN_URLS.get(provider_type, TokenEndpoint.LIVE.value)
self.scope = PROVIDER_SCOPES.get(provider_type, "")
def get_cached_token(self) -> Optional[TokenInfo]: def get_cached_token(self) -> Optional[TokenInfo]:
"""获取缓存的 Token"""
cache_key = (self.account.email.lower(), self.provider_type)
with self._cache_lock: with self._cache_lock:
token = self._token_cache.get(self._cache_key()) token = self._token_cache.get(cache_key)
if token and not token.is_expired(self.REFRESH_BUFFER): if token and not token.is_expired(self.REFRESH_BUFFER):
return token return token
return None return None
def set_cached_token(self, token: TokenInfo): def set_cached_token(self, token: TokenInfo):
"""缓存 Token"""
cache_key = (self.account.email.lower(), self.provider_type)
with self._cache_lock: with self._cache_lock:
self._token_cache[self._cache_key()] = token self._token_cache[cache_key] = token
def clear_cache(self): def clear_cache(self):
"""清除缓存"""
cache_key = (self.account.email.lower(), self.provider_type)
with self._cache_lock: with self._cache_lock:
self._token_cache.pop(self._cache_key(), None) self._token_cache.pop(cache_key, None)
def get_access_token(self, force_refresh: bool = False) -> Optional[str]: def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
"""
获取 Access Token
Args:
force_refresh: 是否强制刷新
Returns:
Access Token 字符串,失败返回 None
"""
# 检查缓存
if not force_refresh: if not force_refresh:
cached = self.get_cached_token() cached = self.get_cached_token()
if cached: if cached:
logger.debug(f"[{self.account.email}] 使用缓存 Token") logger.debug(f"[{self.account.email}] 使用缓存 Token ({self.provider_type.value})")
return cached.access_token return cached.access_token
# 刷新 Token
try: try:
token = self._refresh_token() token = self._refresh_token()
if token: if token:
self.set_cached_token(token) self.set_cached_token(token)
return token.access_token return token.access_token
except Exception as e: except Exception as e:
logger.error(f"[{self.account.email}] 获取 Token 失败: {e}") logger.error(f"[{self.account.email}] 获取 Token 失败 ({self.provider_type.value}): {e}")
return None return None
def _refresh_token(self) -> Optional[TokenInfo]: def _refresh_token(self) -> Optional[TokenInfo]:
"""
刷新 Token
Returns:
TokenInfo 对象,失败返回 None
"""
if not self.account.client_id or not self.account.refresh_token: if not self.account.client_id or not self.account.refresh_token:
raise ValueError("缺少 client_id 或 refresh_token") raise ValueError("缺少 client_id 或 refresh_token")
logger.debug(f"[{self.account.email}] 正在刷新 Token...") logger.debug(f"[{self.account.email}] 正在刷新 Token ({self.provider_type.value})...")
logger.debug(f"[{self.account.email}] Token URL: {self.token_url}")
# 构建请求体
data = { data = {
"client_id": self.account.client_id, "client_id": self.account.client_id,
"refresh_token": self.account.refresh_token, "refresh_token": self.account.refresh_token,
"grant_type": "refresh_token", "grant_type": "refresh_token",
"scope": IMAP_SCOPE,
} }
# 添加 Scope如果需要
if self.scope:
data["scope"] = self.scope
headers = { headers = {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json", "Accept": "application/json",
} }
proxies = None proxies = None
if self.proxy_url: if self.proxy_url:
proxies = {"http": self.proxy_url, "https": self.proxy_url} proxies = {"http": self.proxy_url, "https": self.proxy_url}
try: try:
resp = _requests.post( resp = _requests.post(
TOKEN_URL, self.token_url,
data=data, data=data,
headers=headers, headers=headers,
proxies=proxies, proxies=proxies,
@@ -111,56 +166,74 @@ class TokenManager:
) )
if resp.status_code != 200: if resp.status_code != 200:
body = resp.text error_body = resp.text
logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}") logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}")
if "service abuse" in body.lower(): logger.debug(f"[{self.account.email}] 错误响应: {error_body[:500]}")
if "service abuse" in error_body.lower():
logger.warning(f"[{self.account.email}] 账号可能被封禁") logger.warning(f"[{self.account.email}] 账号可能被封禁")
elif "invalid_grant" in body.lower(): elif "invalid_grant" in error_body.lower():
logger.warning(f"[{self.account.email}] Refresh Token 已失效") logger.warning(f"[{self.account.email}] Refresh Token 已失效")
return None return None
response_data = resp.json() response_data = resp.json()
token = TokenInfo.from_response(response_data, IMAP_SCOPE)
# 解析响应
token = TokenInfo.from_response(response_data, self.scope)
logger.info( logger.info(
f"[{self.account.email}] Token 刷新成功" f"[{self.account.email}] Token 刷新成功 ({self.provider_type.value}), "
f"有效期 {int(token.expires_at - time.time())}" f"有效期 {int(token.expires_at - time.time())}"
) )
# 若响应含新 refresh_token → 写回内存 + 持久化数据库
new_rt = response_data.get("refresh_token", "")
if new_rt and new_rt != self.account.refresh_token:
self.account.refresh_token = new_rt
if self.service_id:
try:
from ...database.session import get_session_manager
from ...database.crud import update_outlook_refresh_token
with get_session_manager().session_scope() as db:
update_outlook_refresh_token(
db, self.service_id, self.account.email, new_rt
)
logger.info(f"[{self.account.email}] refresh_token 已写回数据库")
except Exception as e:
logger.warning(f"[{self.account.email}] 写回 refresh_token 失败: {e}")
return token return token
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"[{self.account.email}] JSON 解析错误: {e}") logger.error(f"[{self.account.email}] JSON 解析错误: {e}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"[{self.account.email}] 未知错误: {e}") logger.error(f"[{self.account.email}] 未知错误: {e}")
return None return None
@classmethod @classmethod
def clear_all_cache(cls): def clear_all_cache(cls):
"""清除所有 Token 缓存"""
with cls._cache_lock: with cls._cache_lock:
cls._token_cache.clear() cls._token_cache.clear()
logger.info("已清除所有 Token 缓存") logger.info("已清除所有 Token 缓存")
@classmethod @classmethod
def get_cache_stats(cls) -> Dict[str, Any]: def get_cache_stats(cls) -> Dict[str, Any]:
"""获取缓存统计"""
with cls._cache_lock: with cls._cache_lock:
return { return {
"cache_size": len(cls._token_cache), "cache_size": len(cls._token_cache),
"entries": list(cls._token_cache.keys()), "entries": [
{
"email": key[0],
"provider": key[1].value,
}
for key in cls._token_cache.keys()
],
} }
def create_token_manager(
account: OutlookAccount,
provider_type: ProviderType,
proxy_url: Optional[str] = None,
timeout: int = TokenManager.DEFAULT_TIMEOUT,
) -> TokenManager:
"""
创建 Token 管理器的工厂函数
Args:
account: Outlook 账户
provider_type: 提供者类型
proxy_url: 代理 URL
timeout: 超时时间
Returns:
TokenManager 实例
"""
return TokenManager(account, provider_type, proxy_url, timeout)

View File

@@ -15,7 +15,7 @@ from email.policy import default as email_policy
from html import unescape from html import unescape
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from .base import BaseEmailService, EmailServiceError, EmailServiceType from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN from ..config.constants import OTP_CODE_PATTERN
@@ -200,8 +200,19 @@ class TempMailService(BaseEmailService):
error_msg = f"{error_msg} - {error_data}" error_msg = f"{error_msg} - {error_data}"
except Exception: except Exception:
error_msg = f"{error_msg} - {response.text[:200]}" error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg)) retry_after = None
raise EmailServiceError(error_msg) if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
try: try:
return response.json() return response.json()

View File

@@ -6,6 +6,7 @@ import re
import time import time
import logging import logging
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from datetime import datetime, timezone
from .base import BaseEmailService, EmailServiceError, EmailServiceType from .base import BaseEmailService, EmailServiceError, EmailServiceType
from ..core.http_client import HTTPClient, RequestConfig from ..core.http_client import HTTPClient, RequestConfig
@@ -59,6 +60,35 @@ class TempmailService(BaseEmailService):
self._email_cache: Dict[str, Dict[str, Any]] = {} self._email_cache: Dict[str, Dict[str, Any]] = {}
self._last_check_time: float = 0 self._last_check_time: float = 0
def _parse_message_time(self, value: Any) -> Optional[float]:
"""解析 Tempmail 邮件时间,兼容 Unix 时间戳与 ISO 8601。"""
if value is None or value == "":
return None
if isinstance(value, (int, float)):
timestamp = float(value)
else:
text = str(value).strip()
if not text:
return None
try:
timestamp = float(text)
except ValueError:
try:
normalized = text.replace("Z", "+00:00")
timestamp = datetime.fromisoformat(normalized).astimezone(timezone.utc).timestamp()
except Exception:
return None
while timestamp > 1e11:
timestamp /= 1000.0
return timestamp if timestamp > 0 else None
def _get_received_timestamp(self, message: Dict[str, Any]) -> Optional[float]:
"""返回 Tempmail 邮件的接收时间戳。"""
return self._parse_message_time(message.get("received_at"))
def _save_token_to_db(self, email: str, token: str) -> None: def _save_token_to_db(self, email: str, token: str) -> None:
"""将邮箱 token 持久化到 Setting 表key=tempmail_token:{email}""" """将邮箱 token 持久化到 Setting 表key=tempmail_token:{email}"""
try: try:
@@ -154,7 +184,7 @@ class TempmailService(BaseEmailService):
email_id: 邮箱 token如果不提供从缓存中查找 email_id: 邮箱 token如果不提供从缓存中查找
timeout: 超时时间(秒) timeout: 超时时间(秒)
pattern: 验证码正则表达式 pattern: 验证码正则表达式
otp_sent_at: OTP 发送时间戳Tempmail 服务暂不使用此参数) otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点的邮件
Returns: Returns:
验证码字符串,如果超时或未找到返回 None 验证码字符串,如果超时或未找到返回 None
@@ -209,11 +239,20 @@ class TempmailService(BaseEmailService):
if not isinstance(msg, dict): if not isinstance(msg, dict):
continue continue
# 使用 date 作为唯一标识 msg_timestamp = self._get_received_timestamp(msg)
msg_date = msg.get("date", 0) if otp_sent_at is not None:
if not msg_date or msg_date in seen_ids: if msg_timestamp is None or msg_timestamp <= otp_sent_at:
continue
message_id = str(
msg.get("id")
or msg.get("date")
or msg.get("createdAt")
or f"{msg.get('from', '')}:{msg.get('subject', '')}:{msg_timestamp}"
).strip()
if not message_id or message_id in seen_ids:
continue continue
seen_ids.add(msg_date) seen_ids.add(message_id)
sender = str(msg.get("from", "")).lower() sender = str(msg.get("from", "")).lower()
subject = str(msg.get("subject", "")) subject = str(msg.get("subject", ""))
@@ -419,4 +458,4 @@ class TempmailService(BaseEmailService):
"email": email, "email": email,
"message": "等待验证码超时" "message": "等待验证码超时"
}) })
return None return None

View File

@@ -15,9 +15,11 @@ from fastapi import FastAPI, Request, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from ..config.settings import get_settings from ..config.settings import get_settings
from ..database import crud
from ..database.session import get_db
from .routes import api_router from .routes import api_router
from .routes.websocket import router as ws_router from .routes.websocket import router as ws_router
from .task_manager import task_manager from .task_manager import task_manager
@@ -108,9 +110,9 @@ def create_app() -> FastAPI:
async def login_page(request: Request, next: Optional[str] = "/"): async def login_page(request: Request, next: Optional[str] = "/"):
"""登录页面""" """登录页面"""
return templates.TemplateResponse( return templates.TemplateResponse(
request, request=request,
"login.html", name="login.html",
{"error": "", "next": next or "/"} context={"request": request, "error": "", "next": next or "/"}
) )
@app.post("/login") @app.post("/login")
@@ -119,9 +121,9 @@ def create_app() -> FastAPI:
expected = get_settings().webui_access_password.get_secret_value() expected = get_settings().webui_access_password.get_secret_value()
if not secrets.compare_digest(password, expected): if not secrets.compare_digest(password, expected):
return templates.TemplateResponse( return templates.TemplateResponse(
request, request=request,
"login.html", name="login.html",
{"error": "密码错误", "next": next or "/"}, context={"request": request, "error": "密码错误", "next": next or "/"},
status_code=401 status_code=401
) )
@@ -136,38 +138,48 @@ def create_app() -> FastAPI:
response.delete_cookie("webui_auth") response.delete_cookie("webui_auth")
return response return response
@app.get("/favicon.ico", include_in_schema=False)
async def favicon_ico():
"""兼容浏览器对根路径 favicon 的默认请求。"""
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
@app.get("/favicon.svg", include_in_schema=False)
async def favicon_svg():
"""提供统一的站点图标资源。"""
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
async def index(request: Request): async def index(request: Request):
"""首页 - 注册页面""" """首页 - 注册页面"""
if not _is_authenticated(request): if not _is_authenticated(request):
return _redirect_to_login(request) return _redirect_to_login(request)
return templates.TemplateResponse(request, "index.html") return templates.TemplateResponse(request=request, name="index.html", context={"request": request})
@app.get("/accounts", response_class=HTMLResponse) @app.get("/accounts", response_class=HTMLResponse)
async def accounts_page(request: Request): async def accounts_page(request: Request):
"""账号管理页面""" """账号管理页面"""
if not _is_authenticated(request): if not _is_authenticated(request):
return _redirect_to_login(request) return _redirect_to_login(request)
return templates.TemplateResponse(request, "accounts.html") return templates.TemplateResponse(request=request, name="accounts.html", context={"request": request})
@app.get("/email-services", response_class=HTMLResponse) @app.get("/email-services", response_class=HTMLResponse)
async def email_services_page(request: Request): async def email_services_page(request: Request):
"""邮箱服务管理页面""" """邮箱服务管理页面"""
if not _is_authenticated(request): if not _is_authenticated(request):
return _redirect_to_login(request) return _redirect_to_login(request)
return templates.TemplateResponse(request, "email_services.html") return templates.TemplateResponse(request=request, name="email_services.html", context={"request": request})
@app.get("/settings", response_class=HTMLResponse) @app.get("/settings", response_class=HTMLResponse)
async def settings_page(request: Request): async def settings_page(request: Request):
"""设置页面""" """设置页面"""
if not _is_authenticated(request): if not _is_authenticated(request):
return _redirect_to_login(request) return _redirect_to_login(request)
return templates.TemplateResponse(request, "settings.html") return templates.TemplateResponse(request=request, name="settings.html", context={"request": request})
@app.get("/payment", response_class=HTMLResponse) @app.get("/payment", response_class=HTMLResponse)
async def payment_page(request: Request): async def payment_page(request: Request):
"""支付页面""" """支付页面"""
return templates.TemplateResponse(request, "payment.html") return templates.TemplateResponse(request=request, name="payment.html", context={"request": request})
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
@@ -185,6 +197,12 @@ def create_app() -> FastAPI:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
task_manager.set_loop(loop) task_manager.set_loop(loop)
stale_error = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
with get_db() as db:
stale_tasks = crud.fail_incomplete_registration_tasks(db, stale_error)
if stale_tasks:
logger.warning("已收敛 %s 个僵尸任务: %s", len(stale_tasks), ", ".join(task[:8] for task in stale_tasks))
logger.info("=" * 50) logger.info("=" * 50)
logger.info(f"{settings.app_name} v{settings.app_version} 启动中...") logger.info(f"{settings.app_name} v{settings.app_version} 启动中...")
logger.info(f"调试模式: {settings.debug}") logger.info(f"调试模式: {settings.debug}")

View File

@@ -1028,11 +1028,9 @@ def _build_inbox_config(db, service_type, email: str) -> dict:
EmailServiceModel.enabled == True EmailServiceModel.enabled == True
) )
if service_type == EST.OUTLOOK: if service_type == EST.OUTLOOK:
# 按 config.email 精确匹配,不受 enabled 限制(收件箱是账号自己的邮箱) # 按 config.email 匹配账号 email
all_outlook = db.query(EmailServiceModel).filter( services = query.all()
EmailServiceModel.service_type == db_type svc = next((s for s in services if (s.config or {}).get("email") == email), None)
).all()
svc = next((s for s in all_outlook if (s.config or {}).get("email", "").lower() == email.lower()), None)
else: else:
svc = query.order_by(EmailServiceModel.priority.asc()).first() svc = query.order_by(EmailServiceModel.priority.asc()).first()

View File

@@ -67,7 +67,7 @@ class ServiceTestResult(BaseModel):
class OutlookBatchImportRequest(BaseModel): class OutlookBatchImportRequest(BaseModel):
"""Outlook 批量导入请求""" """Outlook 批量导入请求"""
data: str # 多行数据,每行格式: 邮箱----密码----client_id----refresh_token data: str # 多行数据,每行格式: 邮箱----密码 或 邮箱----密码----client_id----refresh_token
enabled: bool = True enabled: bool = True
priority: int = 0 priority: int = 0
@@ -461,8 +461,11 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
""" """
批量导入 Outlook 邮箱账户 批量导入 Outlook 邮箱账户
格式(每行):邮箱----密码----client_id----refresh_token 支持两种格式:
使用四个连字符(----)分隔字段 - 格式一(密码认证):邮箱----密码
- 格式二XOAUTH2 认证):邮箱----密码----client_id----refresh_token
每行一个账户,使用四个连字符(----)分隔字段
""" """
lines = request.data.strip().split("\n") lines = request.data.strip().split("\n")
total = len(lines) total = len(lines)
@@ -481,18 +484,14 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
parts = line.split("----") parts = line.split("----")
# 必须是四字段格式 # 验证格式
if len(parts) < 4: if len(parts) < 2:
failed += 1 failed += 1
errors.append( errors.append(f"{i+1}: 格式错误,至少需要邮箱和密码")
f"{i+1}: 格式错误,必须为 邮箱----密码----client_id----refresh_token"
)
continue continue
email = parts[0].strip() email = parts[0].strip()
password = parts[1].strip() password = parts[1].strip()
client_id = parts[2].strip()
refresh_token = parts[3].strip()
# 验证邮箱格式 # 验证邮箱格式
if "@" not in email: if "@" not in email:
@@ -500,12 +499,6 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
errors.append(f"{i+1}: 无效的邮箱地址: {email}") errors.append(f"{i+1}: 无效的邮箱地址: {email}")
continue continue
# 验证 OAuth 字段非空
if not client_id or not refresh_token:
failed += 1
errors.append(f"{i+1}: [{email}] client_id 或 refresh_token 不能为空")
continue
# 检查是否已存在 # 检查是否已存在
existing = db.query(EmailServiceModel).filter( existing = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook", EmailServiceModel.service_type == "outlook",
@@ -520,11 +513,17 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
# 构建配置 # 构建配置
config = { config = {
"email": email, "email": email,
"password": password, "password": password
"client_id": client_id,
"refresh_token": refresh_token,
} }
# 检查是否有 OAuth 信息(格式二)
if len(parts) >= 4:
client_id = parts[2].strip()
refresh_token = parts[3].strip()
if client_id and refresh_token:
config["client_id"] = client_id
config["refresh_token"] = refresh_token
# 创建服务记录 # 创建服务记录
try: try:
service = EmailServiceModel( service = EmailServiceModel(
@@ -609,86 +608,3 @@ async def test_tempmail_service(request: TempmailTestRequest):
except Exception as e: except Exception as e:
logger.error(f"测试临时邮箱失败: {e}") logger.error(f"测试临时邮箱失败: {e}")
return {"success": False, "message": f"测试失败: {str(e)}"} return {"success": False, "message": f"测试失败: {str(e)}"}
# ============== 收件箱 ==============
@router.get("/{service_id}/inbox")
async def get_outlook_inbox(
service_id: int,
count: int = Query(30, ge=1, le=100),
only_unseen: bool = Query(False),
):
"""获取 Outlook 收件箱邮件列表"""
with get_db() as db:
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
if service.service_type != "outlook":
raise HTTPException(status_code=400, detail="仅支持 Outlook 类型服务")
config = service.config or {}
email_addr = config.get("email", "")
client_id = config.get("client_id", "")
refresh_token = config.get("refresh_token", "")
# client_id 为空时尝试使用全局默认值
if not client_id:
from ...config.settings import get_settings
client_id = get_settings().outlook_default_client_id or ""
if not client_id or not refresh_token:
raise HTTPException(status_code=400, detail="该账户缺少 OAuth 配置client_id / refresh_token无法读取收件箱")
try:
from ...services.outlook.account import OutlookAccount
from ...services.outlook.token_manager import TokenManager
from ...services.outlook.providers.imap_new import IMAPNewProvider
from ...services.outlook.providers.base import ProviderConfig
account = OutlookAccount(
email=email_addr,
password=config.get("password", ""),
client_id=client_id,
refresh_token=refresh_token,
)
provider_config = ProviderConfig(
proxy_url=None,
timeout=30,
service_id=service_id,
)
provider = IMAPNewProvider(account, provider_config)
connected = provider.connect()
if not connected:
raise HTTPException(status_code=502, detail="IMAP 连接失败,请检查 OAuth 配置")
try:
messages = provider.get_recent_emails(count=count, only_unseen=only_unseen)
finally:
provider.disconnect()
emails = []
for m in messages:
received_str = m.received_at.isoformat() if m.received_at else None
emails.append({
"id": m.id or "",
"subject": m.subject or "",
"sender": m.sender or "",
"received_at": received_str,
"body_preview": m.body_preview or (m.body or "")[:200],
"body": m.body or "",
"is_read": m.is_read,
})
return {
"email": email_addr,
"total": len(emails),
"emails": emails,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取收件箱失败 service_id={service_id}: {e}")
raise HTTPException(status_code=500, detail=f"获取收件箱失败: {str(e)}")

View File

@@ -7,8 +7,9 @@ import logging
import uuid import uuid
import random import random
import re import re
import time
from datetime import datetime from datetime import datetime
from typing import List, Optional, Dict, Tuple, Any from typing import List, Optional, Dict, Tuple
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -16,9 +17,13 @@ from pydantic import BaseModel, Field
from ...database import crud from ...database import crud
from ...database.session import get_db from ...database.session import get_db
from ...database.models import RegistrationTask, Proxy from ...database.models import RegistrationTask, Proxy
from ...core.login import LoginEngine from ...core.register import (
from ...core.register import RegistrationResult ERROR_OTP_TIMEOUT_SECONDARY,
RegistrationEngine,
RegistrationResult,
)
from ...services import EmailServiceFactory, EmailServiceType from ...services import EmailServiceFactory, EmailServiceType
from ...services.base import EmailProviderBackoffState, OTPTimeoutEmailServiceError
from ...config.settings import get_settings from ...config.settings import get_settings
from ..task_manager import task_manager from ..task_manager import task_manager
@@ -29,6 +34,7 @@ router = APIRouter()
running_tasks: dict = {} running_tasks: dict = {}
# 批量任务存储 # 批量任务存储
batch_tasks: Dict[str, dict] = {} batch_tasks: Dict[str, dict] = {}
email_service_circuit_breakers: Dict[int, EmailProviderBackoffState] = {}
# ============== Proxy Helper Functions ============== # ============== Proxy Helper Functions ==============
@@ -253,6 +259,176 @@ def _normalize_email_service_config(
return normalized return normalized
def _get_email_service_backoff_state(service_id: Optional[int]) -> EmailProviderBackoffState:
if service_id is None:
return EmailProviderBackoffState()
return email_service_circuit_breakers.get(service_id, EmailProviderBackoffState())
def _store_email_service_backoff_state(
service_id: Optional[int],
backoff_state: Optional[EmailProviderBackoffState],
) -> Optional[EmailProviderBackoffState]:
if service_id is None or backoff_state is None:
return None
if backoff_state.failures == 0 and backoff_state.delay_seconds == 0:
email_service_circuit_breakers.pop(service_id, None)
return backoff_state
email_service_circuit_breakers[service_id] = backoff_state
return backoff_state
def _get_phase_result(phase_history, phase_name: str):
for phase_result in phase_history or []:
if getattr(phase_result, "phase", None) == phase_name:
return phase_result
return None
def _is_email_service_circuit_open(service_id: Optional[int], now: Optional[float] = None) -> bool:
if service_id is None:
return False
return _get_email_service_backoff_state(service_id).is_open(now)
def _trip_email_service_circuit(
service_id: Optional[int],
backoff_state: Optional[EmailProviderBackoffState],
) -> int:
if service_id is None or backoff_state is None:
return 0
_store_email_service_backoff_state(service_id, backoff_state)
return backoff_state.delay_seconds
def _record_email_service_timeout_backoff(
service_id: Optional[int],
email_service,
previous_backoff_state: EmailProviderBackoffState,
error_code: str,
error_message: str,
) -> Optional[EmailProviderBackoffState]:
if service_id is None:
return None
timeout_error = OTPTimeoutEmailServiceError(
error_message or "等待验证码超时",
error_code=error_code,
)
if hasattr(email_service, "apply_provider_backoff_state"):
email_service.apply_provider_backoff_state(previous_backoff_state)
if hasattr(email_service, "update_status"):
email_service.update_status(False, timeout_error)
backoff_state = getattr(email_service, "provider_backoff_state", None)
return _store_email_service_backoff_state(service_id, backoff_state)
def _build_email_service_candidates(
db,
service_type: EmailServiceType,
actual_proxy_url: Optional[str],
email_service_id: Optional[int],
email_service_config: Optional[dict],
) -> List[Dict[str, object]]:
from ...database.models import EmailService as EmailServiceModel, Account
settings = get_settings()
candidates: List[Dict[str, object]] = []
def append_candidate(candidate_type: EmailServiceType, config: dict, db_service=None) -> None:
candidates.append({
"service_type": candidate_type,
"config": config,
"db_service": db_service,
})
def append_database_candidates(db_service_type: str) -> None:
services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == db_service_type,
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc(), EmailServiceModel.id.asc()).all()
for db_service in services:
if _is_email_service_circuit_open(db_service.id):
continue
candidate_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(candidate_type, db_service.config, actual_proxy_url)
append_candidate(candidate_type, config, db_service=db_service)
if email_service_id:
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == email_service_id,
EmailServiceModel.enabled == True
).first()
if not db_service:
raise ValueError(f"邮箱服务不存在或已禁用: {email_service_id}")
if _is_email_service_circuit_open(db_service.id):
raise ValueError(f"邮箱服务处于熔断状态: {db_service.name}")
candidate_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(candidate_type, db_service.config, actual_proxy_url)
append_candidate(candidate_type, config, db_service=db_service)
return candidates
if service_type == EmailServiceType.TEMPMAIL:
append_candidate(service_type, {
"base_url": settings.tempmail_base_url,
"timeout": settings.tempmail_timeout,
"max_retries": settings.tempmail_max_retries,
"proxy_url": actual_proxy_url,
})
elif service_type == EmailServiceType.MOE_MAIL:
append_database_candidates("moe_mail")
if not candidates:
if settings.custom_domain_base_url and settings.custom_domain_api_key:
append_candidate(service_type, {
"base_url": settings.custom_domain_base_url,
"api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "",
"proxy_url": actual_proxy_url,
})
else:
raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置")
elif service_type == EmailServiceType.OUTLOOK:
services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc(), EmailServiceModel.id.asc()).all()
if not services:
raise ValueError("没有可用的 Outlook 账户,请先在设置中导入账户")
for db_service in services:
if _is_email_service_circuit_open(db_service.id):
continue
email = db_service.config.get("email") if db_service.config else None
if not email:
continue
existing = db.query(Account).filter(Account.email == email).first()
if existing:
logger.info(f"跳过已注册的 Outlook 账户: {email}")
continue
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
append_candidate(service_type, config, db_service=db_service)
if not candidates:
raise ValueError("所有 Outlook 账户都已注册过,或当前均处于熔断状态")
elif service_type == EmailServiceType.DUCK_MAIL:
append_database_candidates("duck_mail")
if not candidates:
raise ValueError("没有可用的 DuckMail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.FREEMAIL:
append_database_candidates("freemail")
if not candidates:
raise ValueError("没有可用的 Freemail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.IMAP_MAIL:
append_database_candidates("imap_mail")
if not candidates:
raise ValueError("没有可用的 IMAP 邮箱服务,请先在邮箱服务中添加")
else:
append_candidate(service_type, email_service_config or {})
return candidates
def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = "", auto_upload_cpa: bool = False, cpa_service_ids: List[int] = None, auto_upload_sub2api: bool = False, sub2api_service_ids: List[int] = None, auto_upload_tm: bool = False, tm_service_ids: List[int] = None): def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = "", auto_upload_cpa: bool = False, cpa_service_ids: List[int] = None, auto_upload_sub2api: bool = False, sub2api_service_ids: List[int] = None, auto_upload_tm: bool = False, tm_service_ids: List[int] = None):
""" """
在线程池中执行的同步注册任务 在线程池中执行的同步注册任务
@@ -261,165 +437,31 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
""" """
with get_db() as db: with get_db() as db:
try: try:
# 检查是否已取消
if task_manager.is_cancelled(task_uuid): if task_manager.is_cancelled(task_uuid):
logger.info(f"任务 {task_uuid} 已取消,跳过执行") logger.info(f"任务 {task_uuid} 已取消,跳过执行")
return return
# 更新任务状态为运行中
task = crud.update_registration_task( task = crud.update_registration_task(
db, task_uuid, db, task_uuid,
status="running", status="running",
started_at=datetime.utcnow() started_at=datetime.utcnow()
) )
if not task: if not task:
logger.error(f"任务不存在: {task_uuid}") logger.error(f"任务不存在: {task_uuid}")
return return
resolved_email_service_id = email_service_id or task.email_service_id
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "running") task_manager.update_status(task_uuid, "running")
settings = get_settings()
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id) log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
requested_service_type = EmailServiceType(email_service_type)
def build_email_service(active_proxy_url: Optional[str]):
requested_service_type = EmailServiceType(email_service_type)
if resolved_email_service_id:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == resolved_email_service_id,
EmailServiceModel.enabled == True
).first()
if db_service:
selected_service_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(selected_service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(
f"使用数据库邮箱服务: {db_service.name} "
f"(ID: {db_service.id}, 类型: {selected_service_type.value})"
)
email_service = EmailServiceFactory.create(selected_service_type, config)
return email_service, selected_service_type
raise ValueError(f"邮箱服务不存在或已禁用: {resolved_email_service_id}")
service_type = requested_service_type
if service_type == EmailServiceType.TEMPMAIL:
config = {
"base_url": settings.tempmail_base_url,
"timeout": settings.tempmail_timeout,
"max_retries": settings.tempmail_max_retries,
"proxy_url": active_proxy_url,
}
elif service_type == EmailServiceType.MOE_MAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "moe_mail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库自定义域名服务: {db_service.name}")
elif settings.custom_domain_base_url and settings.custom_domain_api_key:
config = {
"base_url": settings.custom_domain_base_url,
"api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "",
"proxy_url": active_proxy_url,
}
else:
raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置")
elif service_type == EmailServiceType.OUTLOOK:
from ...database.models import EmailService as EmailServiceModel, Account
outlook_services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).all()
if not outlook_services:
raise ValueError("没有可用的 Outlook 账户,请先在设置中导入账户")
# 找到一个未注册的 Outlook 账户
selected_service = None
for svc in outlook_services:
email = svc.config.get("email") if svc.config else None
if not email:
continue
existing = db.query(Account).filter(Account.email == email).first()
if not existing:
selected_service = svc
logger.info(f"选择未注册的 Outlook 账户: {email}")
break
logger.info(f"跳过已注册的 Outlook 账户: {email}")
if selected_service and selected_service.config:
config = selected_service.config.copy()
config['service_id'] = selected_service.id
crud.update_registration_task(db, task_uuid, email_service_id=selected_service.id)
logger.info(f"使用数据库 Outlook 账户: {selected_service.name}")
else:
raise ValueError("所有 Outlook 账户都已注册过 OpenAI 账号,请添加新的 Outlook 账户")
elif service_type == EmailServiceType.DUCK_MAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "duck_mail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 DuckMail 服务: {db_service.name}")
else:
raise ValueError("没有可用的 DuckMail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.FREEMAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "freemail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 Freemail 服务: {db_service.name}")
else:
raise ValueError("没有可用的 Freemail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.IMAP_MAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "imap_mail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 IMAP 邮箱服务: {db_service.name}")
else:
raise ValueError("没有可用的 IMAP 邮箱服务,请先在邮箱服务中添加")
else:
config = email_service_config or {}
email_service = EmailServiceFactory.create(service_type, config)
return email_service, service_type
requested_proxy = proxy requested_proxy = proxy
exhausted_proxy_ids = set() exhausted_proxy_ids = set()
result = None result = RegistrationResult(success=False, logs=[])
active_service_type = EmailServiceType(email_service_type) active_service_type = requested_service_type
proxy_id = None
while True: while True:
actual_proxy_url = requested_proxy actual_proxy_url = requested_proxy
proxy_id = None proxy_id = None
if not actual_proxy_url: if not actual_proxy_url:
actual_proxy_url, proxy_id = get_proxy_for_registration( actual_proxy_url, proxy_id = get_proxy_for_registration(
db, db,
@@ -429,20 +471,126 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...") logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...")
crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url) crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url)
email_service, active_service_type = build_email_service(actual_proxy_url) service_candidates = _build_email_service_candidates(
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value) db,
engine = LoginEngine( requested_service_type,
email_service=email_service, actual_proxy_url,
proxy_url=actual_proxy_url, email_service_id,
callback_logger=log_callback, email_service_config,
task_uuid=task_uuid
) )
result = engine.run() should_retry_with_new_proxy = False
for attempt_index, candidate in enumerate(service_candidates, start=1):
selected_service_type = candidate["service_type"]
candidate_config = candidate["config"]
db_service = candidate.get("db_service")
active_service_type = selected_service_type
if db_service is not None:
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(
f"任务 {task_uuid} 使用数据库邮箱服务: {db_service.name} "
f"(ID: {db_service.id}, 类型: {selected_service_type.value}, 尝试: {attempt_index}/{len(service_candidates)})"
)
log_callback(
f"[系统] 使用邮箱服务: {db_service.name} "
f"({selected_service_type.value}, 尝试 {attempt_index}/{len(service_candidates)})"
)
else:
crud.update_registration_task(db, task_uuid, email_service_id=None)
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value)
email_service = EmailServiceFactory.create(
selected_service_type,
candidate_config,
name=db_service.name if db_service is not None else None,
)
provider_backoff_before_run = EmailProviderBackoffState()
if db_service is not None:
provider_backoff_before_run = _get_email_service_backoff_state(db_service.id)
if db_service is not None and hasattr(email_service, "apply_provider_backoff_state"):
email_service.apply_provider_backoff_state(provider_backoff_before_run)
engine = RegistrationEngine(
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
task_uuid=task_uuid
)
try:
result = engine.run()
finally:
close_engine = getattr(engine, "close", None)
if callable(close_engine):
close_engine()
email_prepare_phase = _get_phase_result(
getattr(engine, "phase_history", []),
"email_prepare",
)
if db_service is not None and email_prepare_phase is not None:
_store_email_service_backoff_state(
db_service.id,
getattr(email_prepare_phase, "provider_backoff", None),
)
if result.success:
break
if is_retryable_proxy_error(result.error_message):
should_retry_with_new_proxy = True
break
can_failover = (
db_service is not None
and attempt_index < len(service_candidates)
and email_prepare_phase is not None
and not email_prepare_phase.success
and email_prepare_phase.error_code == "EMAIL_PROVIDER_RATE_LIMITED"
and email_prepare_phase.provider_backoff is not None
)
if not can_failover:
if (
db_service is not None
and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
):
timeout_backoff = _record_email_service_timeout_backoff(
db_service.id,
email_service,
provider_backoff_before_run,
result.error_code,
result.error_message,
)
if timeout_backoff is not None:
logger.warning(
f"邮箱服务 OTP 超时,已退避 {db_service.name} "
f"{timeout_backoff.delay_seconds} 秒,连续失败 "
f"{timeout_backoff.failures}"
)
log_callback(
f"[系统] 邮箱服务 OTP 超时,退避 "
f"{timeout_backoff.delay_seconds} 秒: {db_service.name} "
f"(连续失败 {timeout_backoff.failures} 次)"
)
break
backoff_state = email_prepare_phase.provider_backoff
cooldown = _trip_email_service_circuit(db_service.id, backoff_state)
logger.warning(
f"邮箱服务限流,已退避 {db_service.name} {cooldown} 秒,"
f"连续失败 {backoff_state.failures} 次,"
f"任务 {task_uuid} 将切换到下一个服务"
)
log_callback(
f"[系统] 邮箱服务限流,退避 {cooldown} 秒并切换: "
f"{db_service.name} (连续失败 {backoff_state.failures} 次)"
)
if result.success: if result.success:
break break
if is_retryable_proxy_error(result.error_message): if should_retry_with_new_proxy:
log_callback(f"[代理] 检测到可重试网络错误: {result.error_message}") log_callback(f"[代理] 检测到可重试网络错误: {result.error_message}")
if proxy_id and disable_proxy_for_network_error(db, proxy_id, result.error_message): if proxy_id and disable_proxy_for_network_error(db, proxy_id, result.error_message):
exhausted_proxy_ids.add(proxy_id) exhausted_proxy_ids.add(proxy_id)
@@ -652,53 +800,34 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
def _init_batch_state(batch_id: str, task_uuids: List[str]): def _init_batch_state(batch_id: str, task_uuids: List[str]):
"""初始化批量任务内存状态""" """初始化批量任务内存状态"""
task_manager.init_batch(batch_id, len(task_uuids)) task_manager.init_batch(batch_id, len(task_uuids))
metadata = batch_tasks.get(batch_id, {}).copy() batch_tasks[batch_id] = {
metadata["task_uuids"] = task_uuids "total": len(task_uuids),
batch_tasks[batch_id] = metadata "completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0,
"logs": [],
"finished": False
}
def _make_batch_helpers(batch_id: str): def _make_batch_helpers(batch_id: str):
"""返回 add_batch_log 和 update_batch_status 辅助函数""" """返回 add_batch_log 和 update_batch_status 辅助函数"""
def add_batch_log(msg: str): def add_batch_log(msg: str):
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg) task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs): def update_batch_status(**kwargs):
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs) task_manager.update_batch_status(batch_id, **kwargs)
return add_batch_log, update_batch_status return add_batch_log, update_batch_status
def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]:
"""聚合批量任务元数据与实时状态。"""
metadata = batch_tasks.get(batch_id)
if metadata is None:
return None
status = task_manager.get_batch_status(batch_id) or {}
initial_skipped = int(metadata.get("initial_skipped", 0) or 0)
total = status.get("total", metadata.get("total", 0))
completed = status.get("completed", 0)
success = status.get("success", 0)
failed = status.get("failed", 0)
skipped = status.get("skipped", 0) + initial_skipped
return {
"batch_id": batch_id,
"total": total,
"completed": completed,
"success": success,
"failed": failed,
"skipped": skipped,
"current_index": status.get("current_index", 0),
"cancelled": status.get("cancelled", False),
"finished": status.get("finished", False),
"status": status.get("status", metadata.get("status", "pending")),
"logs": task_manager.get_batch_logs(batch_id),
"task_uuids": metadata.get("task_uuids", []),
"service_ids": metadata.get("service_ids", []),
}
async def run_batch_parallel( async def run_batch_parallel(
batch_id: str, batch_id: str,
task_uuids: List[str], task_uuids: List[str],
@@ -737,19 +866,21 @@ async def run_batch_parallel(
t = crud.get_registration_task(db, uuid) t = crud.get_registration_task(db, uuid)
if t: if t:
async with counter_lock: async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed": if t.status == "completed":
new_success += 1
add_batch_log(f"{prefix} [成功] 注册成功") add_batch_log(f"{prefix} [成功] 注册成功")
elif t.status == "failed": elif t.status == "failed":
new_failed += 1
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}") add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
task_manager.record_batch_task_result(batch_id, t.status) update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
try: try:
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True) await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id): if not task_manager.is_batch_cancelled(batch_id):
snapshot = task_manager.get_batch_status(batch_id) or {} add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed") update_batch_status(finished=True, status="completed")
else: else:
update_batch_status(finished=True, status="cancelled") update_batch_status(finished=True, status="cancelled")
@@ -757,6 +888,8 @@ async def run_batch_parallel(
logger.error(f"批量任务 {batch_id} 异常: {e}") logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}") add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed") update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_pipeline( async def run_batch_pipeline(
@@ -799,17 +932,22 @@ async def run_batch_pipeline(
t = crud.get_registration_task(db, uuid) t = crud.get_registration_task(db, uuid)
if t: if t:
async with counter_lock: async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed": if t.status == "completed":
new_success += 1
add_batch_log(f"{pfx} [成功] 注册成功") add_batch_log(f"{pfx} [成功] 注册成功")
elif t.status == "failed": elif t.status == "failed":
new_failed += 1
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}") add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
task_manager.record_batch_task_result(batch_id, t.status) update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
finally: finally:
semaphore.release() semaphore.release()
try: try:
for i, task_uuid in enumerate(task_uuids): for i, task_uuid in enumerate(task_uuids):
if task_manager.is_batch_cancelled(batch_id): if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
with get_db() as db: with get_db() as db:
for remaining_uuid in task_uuids[i:]: for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled") crud.update_registration_task(db, remaining_uuid, status="cancelled")
@@ -833,15 +971,14 @@ async def run_batch_pipeline(
await asyncio.gather(*running_tasks_list, return_exceptions=True) await asyncio.gather(*running_tasks_list, return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id): if not task_manager.is_batch_cancelled(batch_id):
snapshot = task_manager.get_batch_status(batch_id) or {} add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
update_batch_status(finished=True, status="completed") update_batch_status(finished=True, status="completed")
except Exception as e: except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}") logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}") add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed") update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_registration( async def run_batch_registration(
@@ -974,7 +1111,6 @@ async def start_batch_registration(
# 创建批量任务 # 创建批量任务
batch_id = str(uuid.uuid4()) batch_id = str(uuid.uuid4())
task_uuids = [] task_uuids = []
batch_tasks[batch_id] = {"total": request.count}
with get_db() as db: with get_db() as db:
for _ in range(request.count): for _ in range(request.count):
@@ -1021,33 +1157,34 @@ async def start_batch_registration(
@router.get("/batch/{batch_id}") @router.get("/batch/{batch_id}")
async def get_batch_status(batch_id: str): async def get_batch_status(batch_id: str):
"""获取批量任务状态""" """获取批量任务状态"""
snapshot = _get_batch_snapshot(batch_id) if batch_id not in batch_tasks:
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在") raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return { return {
"batch_id": batch_id, "batch_id": batch_id,
"total": snapshot["total"], "total": batch["total"],
"completed": snapshot["completed"], "completed": batch["completed"],
"success": snapshot["success"], "success": batch["success"],
"failed": snapshot["failed"], "failed": batch["failed"],
"current_index": snapshot["current_index"], "current_index": batch["current_index"],
"cancelled": snapshot["cancelled"], "cancelled": batch["cancelled"],
"finished": snapshot["finished"], "finished": batch.get("finished", False),
"progress": f"{snapshot['completed']}/{snapshot['total']}" "progress": f"{batch['completed']}/{batch['total']}"
} }
@router.post("/batch/{batch_id}/cancel") @router.post("/batch/{batch_id}/cancel")
async def cancel_batch(batch_id: str): async def cancel_batch(batch_id: str):
"""取消批量任务""" """取消批量任务"""
snapshot = _get_batch_snapshot(batch_id) if batch_id not in batch_tasks:
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在") raise HTTPException(status_code=404, detail="批量任务不存在")
if snapshot.get("finished"): batch = batch_tasks[batch_id]
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成") raise HTTPException(status_code=400, detail="批量任务已完成")
batch["cancelled"] = True
task_manager.cancel_batch(batch_id) task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"} return {"success": True, "message": "批量任务取消请求已提交"}
@@ -1528,11 +1665,18 @@ async def start_outlook_batch_registration(
# 创建批量任务 # 创建批量任务
batch_id = str(uuid.uuid4()) batch_id = str(uuid.uuid4())
# 记录额外元数据,由 task_manager 维护实时状态 # 初始化批量任务状态
batch_tasks[batch_id] = { batch_tasks[batch_id] = {
"total": len(actual_service_ids), "total": len(actual_service_ids),
"initial_skipped": skipped_count, "completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"cancelled": False,
"service_ids": actual_service_ids, "service_ids": actual_service_ids,
"current_index": 0,
"logs": [],
"finished": False
} }
# 在后台运行批量注册 # 在后台运行批量注册
@@ -1566,35 +1710,37 @@ async def start_outlook_batch_registration(
@router.get("/outlook-batch/{batch_id}") @router.get("/outlook-batch/{batch_id}")
async def get_outlook_batch_status(batch_id: str): async def get_outlook_batch_status(batch_id: str):
"""获取 Outlook 批量任务状态""" """获取 Outlook 批量任务状态"""
snapshot = _get_batch_snapshot(batch_id) if batch_id not in batch_tasks:
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在") raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return { return {
"batch_id": batch_id, "batch_id": batch_id,
"total": snapshot["total"], "total": batch["total"],
"completed": snapshot["completed"], "completed": batch["completed"],
"success": snapshot["success"], "success": batch["success"],
"failed": snapshot["failed"], "failed": batch["failed"],
"skipped": snapshot["skipped"], "skipped": batch.get("skipped", 0),
"current_index": snapshot["current_index"], "current_index": batch["current_index"],
"cancelled": snapshot["cancelled"], "cancelled": batch["cancelled"],
"finished": snapshot["finished"], "finished": batch.get("finished", False),
"logs": snapshot["logs"], "logs": batch.get("logs", []),
"progress": f"{snapshot['completed']}/{snapshot['total']}" "progress": f"{batch['completed']}/{batch['total']}"
} }
@router.post("/outlook-batch/{batch_id}/cancel") @router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str): async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务""" """取消 Outlook 批量任务"""
snapshot = _get_batch_snapshot(batch_id) if batch_id not in batch_tasks:
if snapshot is None:
raise HTTPException(status_code=404, detail="批量任务不存在") raise HTTPException(status_code=404, detail="批量任务不存在")
if snapshot.get("finished"): batch = batch_tasks[batch_id]
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成") raise HTTPException(status_code=400, detail="批量任务已完成")
# 同时更新两个系统的取消状态
batch["cancelled"] = True
task_manager.cancel_batch(batch_id) task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"} return {"success": True, "message": "批量任务取消请求已提交"}

View File

@@ -469,60 +469,6 @@ class ProxyUpdateRequest(BaseModel):
priority: Optional[int] = None priority: Optional[int] = None
def _test_proxy_connectivity(proxy_url: str) -> dict:
"""测试代理连通性并返回统一结果。"""
import time
from curl_cffi import requests as cffi_requests
test_url = "https://api.ipify.org?format=json"
start_time = time.time()
proxies_dict = {
"http": proxy_url,
"https": proxy_url
}
response = cffi_requests.get(
test_url,
proxies=proxies_dict,
timeout=3,
impersonate="chrome110"
)
elapsed_time = time.time() - start_time
if response.status_code == 200:
ip_info = response.json()
return {
"success": True,
"ip": ip_info.get("ip", ""),
"response_time": round(elapsed_time * 1000),
"message": f"代理连接成功,出口 IP: {ip_info.get('ip', 'unknown')}"
}
return {
"success": False,
"message": f"状态码: {response.status_code}"
}
def _auto_disable_proxy_on_failure(db, proxy, message: str) -> dict:
"""代理测试失败时自动禁用,并返回统一提示。"""
auto_disabled = False
if proxy.enabled:
crud.update_proxy(db, proxy.id, enabled=False)
auto_disabled = True
final_message = message
if auto_disabled:
final_message = f"{message},已自动禁用"
return {
"success": False,
"auto_disabled": auto_disabled,
"message": final_message,
}
@router.get("/proxies") @router.get("/proxies")
async def get_proxies_list(enabled: Optional[bool] = None): async def get_proxies_list(enabled: Optional[bool] = None):
"""获取代理列表""" """获取代理列表"""
@@ -613,59 +559,107 @@ async def set_proxy_default(proxy_id: int):
@router.post("/proxies/{proxy_id}/test") @router.post("/proxies/{proxy_id}/test")
async def test_proxy_item(proxy_id: int): async def test_proxy_item(proxy_id: int):
"""测试单个代理""" """测试单个代理"""
import time
from curl_cffi import requests as cffi_requests
with get_db() as db: with get_db() as db:
proxy = crud.get_proxy_by_id(db, proxy_id) proxy = crud.get_proxy_by_id(db, proxy_id)
if not proxy: if not proxy:
raise HTTPException(status_code=404, detail="代理不存在") raise HTTPException(status_code=404, detail="代理不存在")
proxy_url = proxy.proxy_url
test_url = "https://api.ipify.org?format=json"
start_time = time.time()
try: try:
result = _test_proxy_connectivity(proxy.proxy_url) proxies = {
if result["success"]: "http": proxy_url,
return result "https": proxy_url
return _auto_disable_proxy_on_failure(db, proxy, f"代理返回错误状态码: {result['message'].removeprefix('状态码: ')}") }
response = cffi_requests.get(
test_url,
proxies=proxies,
timeout=3,
impersonate="chrome110"
)
elapsed_time = time.time() - start_time
if response.status_code == 200:
ip_info = response.json()
return {
"success": True,
"ip": ip_info.get("ip", ""),
"response_time": round(elapsed_time * 1000),
"message": f"代理连接成功,出口 IP: {ip_info.get('ip', 'unknown')}"
}
else:
return {
"success": False,
"message": f"代理返回错误状态码: {response.status_code}"
}
except Exception as e: except Exception as e:
return _auto_disable_proxy_on_failure(db, proxy, f"代理连接失败: {str(e)}") return {
"success": False,
"message": f"代理连接失败: {str(e)}"
}
@router.post("/proxies/test-all") @router.post("/proxies/test-all")
async def test_all_proxies(): async def test_all_proxies():
"""测试所有启用的代理""" """测试所有启用的代理"""
import time
from curl_cffi import requests as cffi_requests
with get_db() as db: with get_db() as db:
proxies = crud.get_enabled_proxies(db) proxies = crud.get_enabled_proxies(db)
results = [] results = []
auto_disabled_count = 0
for proxy in proxies: for proxy in proxies:
proxy_url = proxy.proxy_url
test_url = "https://api.ipify.org?format=json"
start_time = time.time()
try: try:
result = _test_proxy_connectivity(proxy.proxy_url) proxies_dict = {
if result["success"]: "http": proxy_url,
"https": proxy_url
}
response = cffi_requests.get(
test_url,
proxies=proxies_dict,
timeout=3,
impersonate="chrome110"
)
elapsed_time = time.time() - start_time
if response.status_code == 200:
ip_info = response.json()
results.append({ results.append({
"id": proxy.id, "id": proxy.id,
"name": proxy.name, "name": proxy.name,
"success": True, "success": True,
"ip": result.get("ip", ""), "ip": ip_info.get("ip", ""),
"response_time": result.get("response_time"), "response_time": round(elapsed_time * 1000)
"auto_disabled": False,
}) })
else: else:
failure_result = _auto_disable_proxy_on_failure(
db,
proxy,
f"代理返回错误状态码: {result['message'].removeprefix('状态码: ')}"
)
auto_disabled_count += 1 if failure_result["auto_disabled"] else 0
results.append({ results.append({
"id": proxy.id, "id": proxy.id,
"name": proxy.name, "name": proxy.name,
**failure_result, "success": False,
"message": f"状态码: {response.status_code}"
}) })
except Exception as e: except Exception as e:
failure_result = _auto_disable_proxy_on_failure(db, proxy, f"代理连接失败: {str(e)}")
auto_disabled_count += 1 if failure_result["auto_disabled"] else 0
results.append({ results.append({
"id": proxy.id, "id": proxy.id,
"name": proxy.name, "name": proxy.name,
**failure_result, "success": False,
"message": str(e)
}) })
success_count = sum(1 for r in results if r["success"]) success_count = sum(1 for r in results if r["success"])
@@ -673,7 +667,6 @@ async def test_all_proxies():
"total": len(proxies), "total": len(proxies),
"success": success_count, "success": success_count,
"failed": len(proxies) - success_count, "failed": len(proxies) - success_count,
"auto_disabled": auto_disabled_count,
"results": results "results": results
} }
@@ -698,14 +691,6 @@ async def disable_proxy(proxy_id: int):
return {"success": True, "message": "代理已禁用"} return {"success": True, "message": "代理已禁用"}
@router.delete("/proxies/disabled/batch-delete")
async def delete_disabled_proxies():
"""批量删除所有已禁用代理"""
with get_db() as db:
deleted = crud.delete_disabled_proxies(db)
return {"success": True, "deleted": deleted, "message": f"已删除 {deleted} 个禁用代理"}
# ============== Outlook 设置 ============== # ============== Outlook 设置 ==============
class OutlookSettings(BaseModel): class OutlookSettings(BaseModel):
@@ -720,6 +705,7 @@ async def get_outlook_settings():
return { return {
"default_client_id": settings.outlook_default_client_id, "default_client_id": settings.outlook_default_client_id,
"provider_priority": settings.outlook_provider_priority,
"health_failure_threshold": settings.outlook_health_failure_threshold, "health_failure_threshold": settings.outlook_health_failure_threshold,
"health_disable_duration": settings.outlook_health_disable_duration, "health_disable_duration": settings.outlook_health_disable_duration,
} }

View File

@@ -7,12 +7,33 @@ import asyncio
import logging import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ...database import crud
from ...database.session import get_db
from ..task_manager import task_manager from ..task_manager import task_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _restore_task_snapshot(task_uuid: str) -> tuple[dict, list[str]]:
"""从数据库恢复任务状态和历史日志,解决服务重启后的监控空白。"""
with get_db() as db:
task = crud.get_registration_task(db, task_uuid)
if not task:
return {}, []
status = {"status": task.status}
if task.result and task.result.get("email"):
status["email"] = task.result["email"]
if task.error_message:
status["error"] = task.error_message
logs = task.logs.splitlines() if task.logs else []
task_manager.sync_task_state(task_uuid, status=status, logs=logs)
return status, logs
@router.websocket("/ws/task/{task_uuid}") @router.websocket("/ws/task/{task_uuid}")
async def task_websocket(websocket: WebSocket, task_uuid: str): async def task_websocket(websocket: WebSocket, task_uuid: str):
""" """
@@ -25,14 +46,15 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
- 客户端发送: {"type": "cancel"} - 取消任务 - 客户端发送: {"type": "cancel"} - 取消任务
""" """
await websocket.accept() await websocket.accept()
restored_status, restored_logs = _restore_task_snapshot(task_uuid)
# 注册连接(会记录当前日志数量,避免重复发送历史日志) # 注册连接,并取得注册时刻的历史日志快照,避免与后续实时推送串扰
task_manager.register_websocket(task_uuid, websocket) history_logs = task_manager.register_websocket(task_uuid, websocket)
logger.info(f"WebSocket 连接已建立: {task_uuid}") logger.info(f"WebSocket 连接已建立: {task_uuid}")
try: try:
# 发送当前状态 # 发送当前状态
status = task_manager.get_status(task_uuid) status = task_manager.get_status(task_uuid) or restored_status
if status: if status:
await websocket.send_json({ await websocket.send_json({
"type": "status", "type": "status",
@@ -40,9 +62,8 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
**status **status
}) })
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复) # 发送历史日志。服务重启后 _restore_task_snapshot 会先把数据库快照回填到内存。
history_logs = task_manager.get_unsent_logs(task_uuid, websocket) for log in history_logs or restored_logs:
for log in history_logs:
await websocket.send_json({ await websocket.send_json({
"type": "log", "type": "log",
"task_uuid": task_uuid, "task_uuid": task_uuid,
@@ -107,8 +128,8 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
""" """
await websocket.accept() await websocket.accept()
# 注册连接(会记录当前日志数量,避免重复发送历史日志) # 注册连接,并取得注册时刻的历史日志快照,避免漏发/重复发送
task_manager.register_batch_websocket(batch_id, websocket) history_logs = task_manager.register_batch_websocket(batch_id, websocket)
logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}") logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}")
try: try:
@@ -121,8 +142,6 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
**status **status
}) })
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
for log in history_logs: for log in history_logs:
await websocket.send_json({ await websocket.send_json({
"type": "log", "type": "log",

View File

@@ -144,20 +144,22 @@ class TaskManager:
except Exception as e: except Exception as e:
logger.warning(f"WebSocket 发送状态失败: {e}") logger.warning(f"WebSocket 发送状态失败: {e}")
def register_websocket(self, task_uuid: str, websocket): def register_websocket(self, task_uuid: str, websocket) -> List[str]:
"""注册 WebSocket 连接""" """注册 WebSocket 连接,并返回注册时刻的历史日志快照"""
history_logs: List[str] = []
with _ws_lock: with _ws_lock:
if task_uuid not in _ws_connections: if task_uuid not in _ws_connections:
_ws_connections[task_uuid] = [] _ws_connections[task_uuid] = []
# 避免重复注册同一个连接 # 避免重复注册同一个连接
if websocket not in _ws_connections[task_uuid]: if websocket not in _ws_connections[task_uuid]:
_ws_connections[task_uuid].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _get_log_lock(task_uuid): with _get_log_lock(task_uuid):
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, [])) history_logs = _log_queues.get(task_uuid, []).copy()
_ws_sent_index[task_uuid][id(websocket)] = len(history_logs)
_ws_connections[task_uuid].append(websocket)
logger.info(f"WebSocket 连接已注册: {task_uuid}") logger.info(f"WebSocket 连接已注册: {task_uuid}")
else: else:
logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}") logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
return history_logs
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]: def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
"""获取未发送给该 WebSocket 的日志""" """获取未发送给该 WebSocket 的日志"""
@@ -190,15 +192,32 @@ class TaskManager:
with _get_log_lock(task_uuid): with _get_log_lock(task_uuid):
return _log_queues.get(task_uuid, []).copy() return _log_queues.get(task_uuid, []).copy()
def sync_task_state(
self,
task_uuid: str,
status: Optional[dict] = None,
logs: Optional[List[str]] = None
):
"""将数据库中的任务快照回填到内存态,便于重连恢复。"""
if status:
current_status = _task_status.get(task_uuid, {}).copy()
current_status.update(status)
_task_status[task_uuid] = current_status
if logs is not None:
with _get_log_lock(task_uuid):
cached_logs = _log_queues.get(task_uuid, [])
if len(logs) >= len(cached_logs):
_log_queues[task_uuid] = list(logs)
def update_status(self, task_uuid: str, status: str, **kwargs): def update_status(self, task_uuid: str, status: str, **kwargs):
"""更新任务状态并推送到 WebSocket""" """更新任务状态"""
if task_uuid not in _task_status: if task_uuid not in _task_status:
_task_status[task_uuid] = {} _task_status[task_uuid] = {}
_task_status[task_uuid]["status"] = status _task_status[task_uuid]["status"] = status
_task_status[task_uuid].update(kwargs) _task_status[task_uuid].update(kwargs)
# 推送状态变更到 WebSocket线程安全兼容同步线程调用
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
try: try:
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
@@ -206,7 +225,7 @@ class TaskManager:
self._loop self._loop
) )
except Exception as e: except Exception as e:
logger.warning(f"推送状态到 WebSocket 失败: {e}") logger.warning(f"广播任务状态失败: {e}")
def get_status(self, task_uuid: str) -> Optional[dict]: def get_status(self, task_uuid: str) -> Optional[dict]:
"""获取任务状态""" """获取任务状态"""
@@ -223,18 +242,16 @@ class TaskManager:
def init_batch(self, batch_id: str, total: int): def init_batch(self, batch_id: str, total: int):
"""初始化批量任务""" """初始化批量任务"""
with _get_batch_lock(batch_id): _batch_status[batch_id] = {
_batch_status[batch_id] = { "status": "running",
"status": "running", "total": total,
"total": total, "completed": 0,
"completed": 0, "success": 0,
"success": 0, "failed": 0,
"failed": 0, "skipped": 0,
"skipped": 0, "current_index": 0,
"current_index": 0, "finished": False
"finished": False, }
"cancelled": False,
}
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}") logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
def add_batch_log(self, batch_id: str, log_message: str): def add_batch_log(self, batch_id: str, log_message: str):
@@ -278,11 +295,11 @@ class TaskManager:
def update_batch_status(self, batch_id: str, **kwargs): def update_batch_status(self, batch_id: str, **kwargs):
"""更新批量任务状态""" """更新批量任务状态"""
with _get_batch_lock(batch_id): if batch_id not in _batch_status:
if batch_id not in _batch_status: logger.warning(f"批量任务 {batch_id} 不存在")
logger.warning(f"批量任务 {batch_id} 不存在") return
return
_batch_status[batch_id].update(kwargs) _batch_status[batch_id].update(kwargs)
# 异步广播状态更新 # 异步广播状态更新
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
@@ -294,35 +311,6 @@ class TaskManager:
except Exception as e: except Exception as e:
logger.warning(f"广播批量状态失败: {e}") logger.warning(f"广播批量状态失败: {e}")
def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]:
"""原子记录单个子任务的终态并返回快照。"""
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
if status is None:
logger.warning(f"批量任务 {batch_id} 不存在")
return None
status["completed"] += 1
if task_status == "completed":
status["success"] += 1
elif task_status == "failed":
status["failed"] += 1
elif task_status == "cancelled":
status["skipped"] += 1
snapshot = status.copy()
if self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
return snapshot
async def _broadcast_batch_status(self, batch_id: str): async def _broadcast_batch_status(self, batch_id: str):
"""广播批量任务状态""" """广播批量任务状态"""
with _ws_lock: with _ws_lock:
@@ -343,9 +331,7 @@ class TaskManager:
def get_batch_status(self, batch_id: str) -> Optional[dict]: def get_batch_status(self, batch_id: str) -> Optional[dict]:
"""获取批量任务状态""" """获取批量任务状态"""
with _get_batch_lock(batch_id): return _batch_status.get(batch_id)
status = _batch_status.get(batch_id)
return status.copy() if status else None
def get_batch_logs(self, batch_id: str) -> List[str]: def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志""" """获取批量任务日志"""
@@ -354,44 +340,33 @@ class TaskManager:
def is_batch_cancelled(self, batch_id: str) -> bool: def is_batch_cancelled(self, batch_id: str) -> bool:
"""检查批量任务是否已取消""" """检查批量任务是否已取消"""
with _get_batch_lock(batch_id): status = _batch_status.get(batch_id, {})
status = _batch_status.get(batch_id, {}) return status.get("cancelled", False)
return status.get("cancelled", False)
def cancel_batch(self, batch_id: str): def cancel_batch(self, batch_id: str):
"""取消批量任务""" """取消批量任务"""
changed = False if batch_id in _batch_status:
with _get_batch_lock(batch_id): _batch_status[batch_id]["cancelled"] = True
if batch_id in _batch_status: _batch_status[batch_id]["status"] = "cancelling"
_batch_status[batch_id]["cancelled"] = True logger.info(f"批量任务 {batch_id} 已标记为取消")
_batch_status[batch_id]["status"] = "cancelling"
changed = True
logger.info(f"批量任务 {batch_id} 已标记为取消")
if changed and self._loop and self._loop.is_running(): def register_batch_websocket(self, batch_id: str, websocket) -> List[str]:
try: """注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照"""
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def register_batch_websocket(self, batch_id: str, websocket):
"""注册批量任务 WebSocket 连接"""
key = f"batch_{batch_id}" key = f"batch_{batch_id}"
history_logs: List[str] = []
with _ws_lock: with _ws_lock:
if key not in _ws_connections: if key not in _ws_connections:
_ws_connections[key] = [] _ws_connections[key] = []
# 避免重复注册同一个连接 # 避免重复注册同一个连接
if websocket not in _ws_connections[key]: if websocket not in _ws_connections[key]:
_ws_connections[key].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _get_batch_lock(batch_id): with _get_batch_lock(batch_id):
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, [])) history_logs = _batch_logs.get(batch_id, []).copy()
_ws_sent_index[key][id(websocket)] = len(history_logs)
_ws_connections[key].append(websocket)
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}") logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
else: else:
logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}") logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
return history_logs
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]: def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
"""获取未发送给该 WebSocket 的批量任务日志""" """获取未发送给该 WebSocket 的批量任务日志"""

5
static/favicon.svg Normal file
View File

@@ -0,0 +1,5 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">
<rect width="64" height="64" rx="14" fill="#111827"/>
<path d="M18 20h28v6H18zm0 10h20v6H18zm0 10h28v6H18z" fill="#f9fafb"/>
<circle cx="46" cy="33" r="6" fill="#22c55e"/>
</svg>

After

Width:  |  Height:  |  Size: 246 B

View File

@@ -1055,6 +1055,7 @@ function resetButtons() {
elements.cancelBtn.disabled = true; elements.cancelBtn.disabled = true;
currentTask = null; currentTask = null;
currentBatch = null; currentBatch = null;
isBatchMode = false;
// 重置完成标志 // 重置完成标志
taskCompleted = false; taskCompleted = false;
batchCompleted = false; batchCompleted = false;
@@ -1280,13 +1281,12 @@ function connectBatchWebSocket(batchId) {
if (!toastShown) { if (!toastShown) {
toastShown = true; toastShown = true;
if (data.status === 'completed') { if (data.status === 'completed') {
const batchLabel = isOutlookBatchMode ? 'Outlook 批量' : '批量'; addLog('success', `[完成] Outlook 批量任务完成!成功: ${data.success}, 失败: ${data.failed}, 跳过: ${data.skipped || 0}`);
addLog('success', `[完成] ${batchLabel}任务完成!成功: ${data.success}, 失败: ${data.failed}, 跳过: ${data.skipped || 0}`);
if (data.success > 0) { if (data.success > 0) {
toast.success(`${batchLabel}注册完成,成功 ${data.success}`); toast.success(`Outlook 批量注册完成,成功 ${data.success}`);
loadRecentAccounts(); loadRecentAccounts();
} else { } else {
toast.warning(`${batchLabel}注册完成,但没有成功注册任何账号`); toast.warning('Outlook 批量注册完成,但没有成功注册任何账号');
} }
} else if (data.status === 'failed') { } else if (data.status === 'failed') {
addLog('error', '[错误] 批量任务执行失败'); addLog('error', '[错误] 批量任务执行失败');

View File

@@ -72,25 +72,6 @@ const elements = {
editOutlookForm: document.getElementById('edit-outlook-form'), editOutlookForm: document.getElementById('edit-outlook-form'),
closeEditOutlookModal: document.getElementById('close-edit-outlook-modal'), closeEditOutlookModal: document.getElementById('close-edit-outlook-modal'),
cancelEditOutlook: document.getElementById('cancel-edit-outlook'), cancelEditOutlook: document.getElementById('cancel-edit-outlook'),
// 收件箱模态框
inboxModal: document.getElementById('inbox-modal'),
closeInboxModal: document.getElementById('close-inbox-modal'),
inboxRefreshBtn: document.getElementById('inbox-refresh-btn'),
inboxOnlyUnseen: document.getElementById('inbox-only-unseen'),
inboxLoading: document.getElementById('inbox-loading'),
inboxTable: document.getElementById('inbox-table'),
inboxTbody: document.getElementById('inbox-tbody'),
inboxEmpty: document.getElementById('inbox-empty'),
inboxModalEmail: document.getElementById('inbox-modal-email'),
// 邮件正文模态框
emailDetailModal: document.getElementById('email-detail-modal'),
closeEmailDetailModal: document.getElementById('close-email-detail-modal'),
emailDetailSubject: document.getElementById('email-detail-subject'),
emailDetailSender: document.getElementById('email-detail-sender'),
emailDetailDate: document.getElementById('email-detail-date'),
emailDetailBody: document.getElementById('email-detail-body'),
}; };
const CUSTOM_SUBTYPE_LABELS = { const CUSTOM_SUBTYPE_LABELS = {
@@ -183,12 +164,6 @@ function initEventListeners() {
document.addEventListener('click', () => { document.addEventListener('click', () => {
document.querySelectorAll('.dropdown-menu.active').forEach(m => m.classList.remove('active')); document.querySelectorAll('.dropdown-menu.active').forEach(m => m.classList.remove('active'));
}); });
// 收件箱模态框事件
elements.closeInboxModal.addEventListener('click', () => elements.inboxModal.classList.remove('active'));
elements.closeEmailDetailModal.addEventListener('click', () => elements.emailDetailModal.classList.remove('active'));
elements.inboxRefreshBtn.addEventListener('click', () => loadInbox(currentInboxServiceId, true));
elements.inboxOnlyUnseen.addEventListener('change', () => loadInbox(currentInboxServiceId));
} }
function toggleEmailMoreMenu(btn) { function toggleEmailMoreMenu(btn) {
@@ -272,7 +247,6 @@ async function loadOutlookServices() {
<td>${format.date(service.last_used)}</td> <td>${format.date(service.last_used)}</td>
<td> <td>
<div style="display:flex;gap:4px;align-items:center;white-space:nowrap;"> <div style="display:flex;gap:4px;align-items:center;white-space:nowrap;">
<button class="btn btn-secondary btn-sm" onclick="openInboxModal(${service.id}, '${escapeHtml(service.config?.email || service.name)}')">收件箱</button>
<button class="btn btn-secondary btn-sm" onclick="editOutlookService(${service.id})">编辑</button> <button class="btn btn-secondary btn-sm" onclick="editOutlookService(${service.id})">编辑</button>
<div class="dropdown" style="position:relative;"> <div class="dropdown" style="position:relative;">
<button class="btn btn-secondary btn-sm" onclick="event.stopPropagation();toggleEmailMoreMenu(this)">更多</button> <button class="btn btn-secondary btn-sm" onclick="event.stopPropagation();toggleEmailMoreMenu(this)">更多</button>
@@ -813,57 +787,3 @@ async function handleEditOutlook(e) {
toast.error('更新失败: ' + error.message); toast.error('更新失败: ' + error.message);
} }
} }
// ============== 收件箱 ==============
let currentInboxServiceId = null;
async function openInboxModal(serviceId, email) {
currentInboxServiceId = serviceId;
elements.inboxModalEmail.textContent = email;
elements.inboxOnlyUnseen.checked = false;
elements.inboxModal.classList.add('active');
await loadInbox(serviceId);
}
async function loadInbox(serviceId) {
if (!serviceId) return;
const onlyUnseen = elements.inboxOnlyUnseen.checked;
elements.inboxLoading.style.display = 'block';
elements.inboxTable.style.display = 'none';
elements.inboxEmpty.style.display = 'none';
elements.inboxEmpty.textContent = '暂无邮件';
try {
const params = new URLSearchParams({ count: 50, only_unseen: onlyUnseen });
const data = await api.get(`/email-services/${serviceId}/inbox?${params}`);
const emails = data.emails || [];
elements.inboxLoading.style.display = 'none';
if (emails.length === 0) {
elements.inboxEmpty.style.display = 'block';
return;
}
elements.inboxTbody.innerHTML = emails.map(m => {
const dataAttr = escapeHtml(JSON.stringify(m));
return `<tr style="cursor:pointer;" onclick="showEmailDetail(JSON.parse(this.dataset.mail))" data-mail="${dataAttr}">
<td style="text-align:center;">${m.is_read ? '' : '<span style="color:var(--primary);font-size:10px;">●</span>'}</td>
<td style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;" title="${escapeHtml(m.subject)}">${escapeHtml(m.subject) || '(无主题)'}</td>
<td style="max-width:180px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;" title="${escapeHtml(m.sender)}">${escapeHtml(m.sender)}</td>
<td>${format.date(m.received_at)}</td>
</tr>`;
}).join('');
elements.inboxTable.style.display = 'table';
} catch (e) {
elements.inboxLoading.style.display = 'none';
elements.inboxEmpty.style.display = 'block';
elements.inboxEmpty.textContent = '加载失败:' + (e.message || '未知错误');
console.error('加载收件箱失败:', e);
}
}
function showEmailDetail(mail) {
elements.emailDetailSubject.textContent = mail.subject || '(无主题)';
elements.emailDetailSender.textContent = mail.sender || '';
elements.emailDetailDate.textContent = format.date(mail.received_at);
elements.emailDetailBody.textContent = mail.body || mail.body_preview || '(无正文)';
elements.emailDetailModal.classList.add('active');
}

View File

@@ -31,7 +31,6 @@ const elements = {
proxiesTable: document.getElementById('proxies-table'), proxiesTable: document.getElementById('proxies-table'),
addProxyBtn: document.getElementById('add-proxy-btn'), addProxyBtn: document.getElementById('add-proxy-btn'),
testAllProxiesBtn: document.getElementById('test-all-proxies-btn'), testAllProxiesBtn: document.getElementById('test-all-proxies-btn'),
deleteDisabledProxiesBtn: document.getElementById('delete-disabled-proxies-btn'),
addProxyModal: document.getElementById('add-proxy-modal'), addProxyModal: document.getElementById('add-proxy-modal'),
proxyItemForm: document.getElementById('proxy-item-form'), proxyItemForm: document.getElementById('proxy-item-form'),
closeProxyModal: document.getElementById('close-proxy-modal'), closeProxyModal: document.getElementById('close-proxy-modal'),
@@ -207,10 +206,6 @@ function initEventListeners() {
elements.testAllProxiesBtn.addEventListener('click', handleTestAllProxies); elements.testAllProxiesBtn.addEventListener('click', handleTestAllProxies);
} }
if (elements.deleteDisabledProxiesBtn) {
elements.deleteDisabledProxiesBtn.addEventListener('click', handleDeleteDisabledProxies);
}
if (elements.closeProxyModal) { if (elements.closeProxyModal) {
elements.closeProxyModal.addEventListener('click', closeProxyModal); elements.closeProxyModal.addEventListener('click', closeProxyModal);
} }
@@ -678,16 +673,16 @@ async function handleOutlookBatchImport() {
lines.forEach((line, index) => { lines.forEach((line, index) => {
const parts = line.split('----').map(p => p.trim()); const parts = line.split('----').map(p => p.trim());
if (parts.length < 4) { if (parts.length < 2) {
errors.push(`${index + 1} 行格式错误,必须为 邮箱----密码----client_id----refresh_token`); errors.push(`${index + 1} 行格式错误`);
return; return;
} }
const account = { const account = {
email: parts[0], email: parts[0],
password: parts[1], password: parts[1],
client_id: parts[2], client_id: parts[2] || null,
refresh_token: parts[3], refresh_token: parts[3] || null,
enabled: enabled, enabled: enabled,
priority: priority priority: priority
}; };
@@ -697,11 +692,6 @@ async function handleOutlookBatchImport() {
return; return;
} }
if (!account.client_id || !account.refresh_token) {
errors.push(`${index + 1} 行 client_id 或 refresh_token 不能为空`);
return;
}
accounts.push(account); accounts.push(account);
}); });
@@ -777,13 +767,11 @@ async function loadProxies() {
try { try {
const data = await api.get('/settings/proxies'); const data = await api.get('/settings/proxies');
renderProxies(data.proxies); renderProxies(data.proxies);
updateProxyBulkActions(data.proxies || []);
} catch (error) { } catch (error) {
console.error('加载代理列表失败:', error); console.error('加载代理列表失败:', error);
updateProxyBulkActions([]);
elements.proxiesTable.innerHTML = ` elements.proxiesTable.innerHTML = `
<tr> <tr>
<td colspan="8"> <td colspan="7">
<div class="empty-state"> <div class="empty-state">
<div class="empty-state-icon">❌</div> <div class="empty-state-icon">❌</div>
<div class="empty-state-title">加载失败</div> <div class="empty-state-title">加载失败</div>
@@ -799,7 +787,7 @@ function renderProxies(proxies) {
if (!proxies || proxies.length === 0) { if (!proxies || proxies.length === 0) {
elements.proxiesTable.innerHTML = ` elements.proxiesTable.innerHTML = `
<tr> <tr>
<td colspan="8"> <td colspan="7">
<div class="empty-state"> <div class="empty-state">
<div class="empty-state-icon">🌐</div> <div class="empty-state-icon">🌐</div>
<div class="empty-state-title">暂无代理</div> <div class="empty-state-title">暂无代理</div>
@@ -843,17 +831,6 @@ function renderProxies(proxies) {
`).join(''); `).join('');
} }
function updateProxyBulkActions(proxies) {
if (!elements.deleteDisabledProxiesBtn) return;
const disabledCount = (proxies || []).filter(proxy => !proxy.enabled).length;
elements.deleteDisabledProxiesBtn.disabled = disabledCount === 0;
elements.deleteDisabledProxiesBtn.dataset.count = String(disabledCount);
elements.deleteDisabledProxiesBtn.textContent = disabledCount > 0
? `🧹 删除禁用项 (${disabledCount})`
: '🧹 删除禁用项';
}
function toggleSettingsMoreMenu(btn) { function toggleSettingsMoreMenu(btn) {
const menu = btn.nextElementSibling; const menu = btn.nextElementSibling;
const isActive = menu.classList.contains('active'); const isActive = menu.classList.contains('active');
@@ -949,12 +926,7 @@ async function testProxyItem(id) {
if (result.success) { if (result.success) {
toast.success(result.message); toast.success(result.message);
} else { } else {
if (result.auto_disabled) { toast.error(result.message);
toast.warning(result.message);
await loadProxies();
} else {
toast.error(result.message);
}
} }
} catch (error) { } catch (error) {
toast.error('测试失败: ' + error.message); toast.error('测试失败: ' + error.message);
@@ -987,22 +959,6 @@ async function deleteProxyItem(id) {
} }
} }
async function handleDeleteDisabledProxies() {
const count = Number(elements.deleteDisabledProxiesBtn?.dataset.count || 0);
if (!count) return;
const confirmed = await confirm(`确定要删除全部 ${count} 个已禁用代理吗?此操作不可恢复。`);
if (!confirmed) return;
try {
const result = await api.delete('/settings/proxies/disabled/batch-delete');
toast.success(result.message);
await loadProxies();
} catch (error) {
toast.error('批量删除失败: ' + error.message);
}
}
// 测试所有代理 // 测试所有代理
async function handleTestAllProxies() { async function handleTestAllProxies() {
elements.testAllProxiesBtn.disabled = true; elements.testAllProxiesBtn.disabled = true;
@@ -1010,13 +966,8 @@ async function handleTestAllProxies() {
try { try {
const result = await api.post('/settings/proxies/test-all'); const result = await api.post('/settings/proxies/test-all');
const summary = `测试完成: 成功 ${result.success}, 失败 ${result.failed}`; toast.info(`测试完成: 成功 ${result.success}, 失败 ${result.failed}`);
if (result.auto_disabled > 0) { loadProxies();
toast.warning(`${summary},已自动禁用 ${result.auto_disabled}`);
} else {
toast.info(summary);
}
await loadProxies();
} catch (error) { } catch (error) {
toast.error('测试失败: ' + error.message); toast.error('测试失败: ' + error.message);
} finally { } finally {

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>账号管理 - OpenAI 注册系统</title> <title>账号管理 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}"> <link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>📋</text></svg>"> <link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style> <style>
.password-cell { .password-cell {
font-family: var(--font-mono); font-family: var(--font-mono);

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>邮箱服务 - OpenAI 注册系统</title> <title>邮箱服务 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}"> <link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>📧</text></svg>"> <link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
</head> </head>
<body> <body>
<div class="container"> <div class="container">
@@ -62,13 +62,16 @@
</div> </div>
<div class="card-body" id="outlook-import-body" style="display: none;"> <div class="card-body" id="outlook-import-body" style="display: none;">
<div class="import-info"> <div class="import-info">
<p><strong>格式(每行一个账户)</strong></p> <p><strong>支持格式:</strong></p>
<p><code>邮箱----密码----client_id----refresh_token</code></p> <ul>
<p>使用四个连字符(----)分隔字段,以 # 开头的行将被忽略。</p> <li><code>邮箱----密码</code> (密码认证)</li>
<li><code>邮箱----密码----client_id----refresh_token</code> XOAUTH2 认证,推荐)</li>
</ul>
<p>每行一个账户,使用四个连字符(----)分隔字段。以 # 开头的行将被忽略。</p>
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="outlook-import-data">批量导入数据</label> <label for="outlook-import-data">批量导入数据</label>
<textarea id="outlook-import-data" rows="8" placeholder="example@outlook.com----password123----client_id----refresh_token"></textarea> <textarea id="outlook-import-data" rows="8" placeholder="example@outlook.com----password123&#10;test@outlook.com----password456----client_id----refresh_token"></textarea>
</div> </div>
<div class="form-row"> <div class="form-row">
<div class="form-group"> <div class="form-group">
@@ -513,55 +516,6 @@
</div> </div>
<!-- 收件箱模态框 -->
<div class="modal" id="inbox-modal">
<div class="modal-content" style="max-width:800px;width:95%;">
<div class="modal-header">
<h3>📬 收件箱 — <span id="inbox-modal-email"></span></h3>
<div style="display:flex;gap:8px;align-items:center;">
<label style="display:flex;align-items:center;gap:4px;font-size:13px;">
<input type="checkbox" id="inbox-only-unseen"> 仅未读
</label>
<button class="btn btn-secondary btn-sm" id="inbox-refresh-btn">刷新</button>
<button class="modal-close" id="close-inbox-modal">&times;</button>
</div>
</div>
<div class="modal-body" style="padding:0;max-height:70vh;overflow-y:auto;">
<div id="inbox-loading" style="padding:32px;text-align:center;">加载中...</div>
<table class="data-table" id="inbox-table" style="display:none;">
<thead>
<tr>
<th style="width:36px;"></th>
<th>主题</th>
<th style="width:200px;">发件人</th>
<th style="width:150px;">时间</th>
</tr>
</thead>
<tbody id="inbox-tbody"></tbody>
</table>
<div id="inbox-empty" style="display:none;padding:32px;text-align:center;color:var(--text-muted);">暂无邮件</div>
</div>
</div>
</div>
<!-- 邮件正文模态框 -->
<div class="modal" id="email-detail-modal">
<div class="modal-content" style="max-width:700px;width:95%;">
<div class="modal-header">
<div>
<h3 id="email-detail-subject" style="margin:0;"></h3>
<div style="font-size:12px;color:var(--text-muted);margin-top:4px;">
<span id="email-detail-sender"></span> · <span id="email-detail-date"></span>
</div>
</div>
<button class="modal-close" id="close-email-detail-modal">&times;</button>
</div>
<div class="modal-body" style="max-height:65vh;overflow-y:auto;">
<div id="email-detail-body" style="white-space:pre-wrap;word-break:break-word;font-size:13px;"></div>
</div>
</div>
</div>
<script src="/static/js/utils.js?v={{ static_version }}"></script> <script src="/static/js/utils.js?v={{ static_version }}"></script>
<script src="/static/js/email_services.js?v={{ static_version }}"></script> <script src="/static/js/email_services.js?v={{ static_version }}"></script>
</body> </body>

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>注册控制台 - OpenAI 注册系统</title> <title>注册控制台 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}"> <link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>🚀</text></svg>"> <link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style> <style>
/* 两栏布局 */ /* 两栏布局 */
.two-column-layout { .two-column-layout {

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>访问验证 - OpenAI 注册系统</title> <title>访问验证 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}"> <link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>🔒</text></svg>"> <link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style> <style>
.login-wrap { .login-wrap {
max-width: 420px; max-width: 420px;

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>支付升级 - OpenAI 注册系统</title> <title>支付升级 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}"> <link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>💳</text></svg>"> <link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style> <style>
.plan-cards { .plan-cards {
display: grid; display: grid;

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>系统设置 - OpenAI 注册系统</title> <title>系统设置 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}"> <link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>⚙️</text></svg>"> <link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
</head> </head>
<body> <body>
<div class="container"> <div class="container">
@@ -95,7 +95,6 @@
<h3>代理列表</h3> <h3>代理列表</h3>
<div style="display: flex; gap: var(--spacing-sm);"> <div style="display: flex; gap: var(--spacing-sm);">
<button class="btn btn-secondary btn-sm" id="test-all-proxies-btn">🔌 测试全部</button> <button class="btn btn-secondary btn-sm" id="test-all-proxies-btn">🔌 测试全部</button>
<button class="btn btn-danger btn-sm" id="delete-disabled-proxies-btn" disabled>🧹 删除禁用项</button>
<button class="btn btn-primary btn-sm" id="add-proxy-btn"> 添加代理</button> <button class="btn btn-primary btn-sm" id="add-proxy-btn"> 添加代理</button>
</div> </div>
</div> </div>
@@ -116,7 +115,7 @@
</thead> </thead>
<tbody id="proxies-table"> <tbody id="proxies-table">
<tr> <tr>
<td colspan="8"> <td colspan="7">
<div class="empty-state"> <div class="empty-state">
<div class="empty-state-icon">🌐</div> <div class="empty-state-icon">🌐</div>
<div class="empty-state-title">暂无代理</div> <div class="empty-state-title">暂无代理</div>

View File

@@ -1,21 +1,90 @@
from concurrent.futures import ThreadPoolExecutor import asyncio
from contextlib import contextmanager
from types import SimpleNamespace
from src.web.routes import registration as registration_routes
from src.web.task_manager import task_manager from src.web.task_manager import task_manager
def test_record_batch_task_result_is_atomic_under_threads(): def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
batch_id = "batch-atomic-test" batch_id = "batch-sync-init"
task_manager.init_batch(batch_id, 100) task_uuids = ["task-1", "task-2", "task-3"]
statuses = ["completed"] * 60 + ["failed"] * 40 registration_routes.batch_tasks.pop(batch_id, None)
registration_routes._init_batch_state(batch_id, task_uuids)
with ThreadPoolExecutor(max_workers=16) as executor: batch_snapshot = registration_routes.batch_tasks[batch_id]
list(executor.map(lambda status: task_manager.record_batch_task_result(batch_id, status), statuses)) manager_snapshot = task_manager.get_batch_status(batch_id)
snapshot = task_manager.get_batch_status(batch_id) assert manager_snapshot is not None
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
assert batch_snapshot["finished"] is False
assert manager_snapshot["finished"] is False
assert manager_snapshot["status"] == "running"
assert snapshot is not None
assert snapshot["completed"] == 100 def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
assert snapshot["success"] == 60 batch_id = "batch-sync-parallel"
assert snapshot["failed"] == 40 task_uuids = ["task-ok-1", "task-fail-1", "task-ok-2"]
assert snapshot["skipped"] == 0 task_statuses = {
"task-ok-1": "completed",
"task-fail-1": "failed",
"task-ok-2": "completed",
}
async def fake_run_registration_task(
task_uuid,
email_service_type,
proxy,
email_service_config,
email_service_id,
log_prefix="",
batch_id="",
auto_upload_cpa=False,
cpa_service_ids=None,
auto_upload_sub2api=False,
sub2api_service_ids=None,
auto_upload_tm=False,
tm_service_ids=None,
):
assert task_uuid in task_statuses
@contextmanager
def fake_get_db():
yield object()
def fake_get_registration_task(db, task_uuid):
status = task_statuses[task_uuid]
error_message = None if status == "completed" else f"{task_uuid}-error"
return SimpleNamespace(status=status, error_message=error_message)
registration_routes.batch_tasks.pop(batch_id, None)
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
asyncio.run(
registration_routes.run_batch_parallel(
batch_id=batch_id,
task_uuids=task_uuids,
email_service_type="tempmail",
proxy=None,
email_service_config=None,
email_service_id=None,
concurrency=2,
)
)
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
assert batch_snapshot["finished"] is True
assert manager_snapshot["finished"] is True
assert manager_snapshot["status"] == "completed"

View File

@@ -1,14 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2026/3/21 14:48
from src.core.utils import base64_payload_decode, base64_decode
if __name__ == '__main__':
print(base64_payload_decode("eyJzZXNzaW9uX2lkIjoiYXV0aHNlc3NfcUE5eFByY3RaZmtHWXJnSlJGdUpxRXBPIiwiY291bnRyeV9jb2RlX2hpbnQiOiJVUyIsImF1dGhfc2Vzc2lvbl9sb2dnaW5nX2lkIjoiMTk0ZDg5OGQtM2Q0ZC00MzU5LWI1NTQtYmJjMjc1YTJlYjU1IiwicHJvbW8iOiIiLCJzaWdudXBfc291cmNlIjoiIiwib3BlbmFpX2NsaWVudF9pZCI6ImFwcF9FTW9hbUVFWjczZjBDa1hhWHA3aHJhbm4iLCJhcHBfbmFtZV9lbnVtIjoib2FpY2xpIiwiYWFzX2VuYWJsZWQiOmZhbHNlLCJvcmlnaW5hbF9zY3JlZW5faGludCI6ImxvZ2luIiwicGFzc3dvcmRsZXNzX2Rpc2FibGVkIjpmYWxzZSwicGFzc3dvcmRsZXNzX290cF9mcm9tX3Bhc3N3b3JkX3JlZGlyZWN0IjpmYWxzZSwiZW1haWxfdmVyaWZpY2F0aW9uX21vZGUiOiJwYXNzd29yZGxlc3NfbG9naW4iLCJlbWFpbCI6Imxob2xsYW5kNTcwQGdzb2xleWZveWxlLm9yZy51ayIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJuYW1lIjoiZXJyIiwid29ya3NwYWNlcyI6W3siaWQiOiI4NjhmZGNmYi1kNjI3LTRhZTItYTQ4Mi1jMTQxMjA1MGZhYTYiLCJuYW1lIjpudWxsLCJraW5kIjoicGVyc29uYWwiLCJwcm9maWxlX3BpY3R1cmVfYWx0X3RleHQiOiJlcnIifV19"))

View File

@@ -0,0 +1,61 @@
from src.services.base import (
EmailProviderBackoffState,
OTPTimeoutEmailServiceError,
RateLimitedEmailServiceError,
apply_adaptive_backoff,
calculate_adaptive_backoff_delay,
)
def test_calculate_adaptive_backoff_delay_uses_failure_count_progression():
assert calculate_adaptive_backoff_delay(0) == 30
assert calculate_adaptive_backoff_delay(1) == 30
assert calculate_adaptive_backoff_delay(2) == 60
assert calculate_adaptive_backoff_delay(3) == 120
def test_apply_adaptive_backoff_tracks_timeout_failures_to_one_hour():
state = EmailProviderBackoffState()
first = apply_adaptive_backoff(
state,
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
now=1000.0,
)
second = apply_adaptive_backoff(
first,
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
now=1031.0,
)
third = apply_adaptive_backoff(
second,
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
now=1092.0,
)
assert first.failures == 1
assert first.delay_seconds == 30
assert first.opened_until == 1030.0
assert second.failures == 2
assert second.delay_seconds == 60
assert second.opened_until == 1091.0
assert third.failures == 3
assert third.delay_seconds == 3600
assert third.opened_until == 4692.0
def test_apply_adaptive_backoff_keeps_normal_rate_limit_on_exponential_curve():
state = EmailProviderBackoffState(failures=2, delay_seconds=60, opened_until=1060.0)
next_state = apply_adaptive_backoff(
state,
RateLimitedEmailServiceError("请求失败: 429", retry_after=7),
now=1100.0,
)
assert next_state.failures == 3
assert next_state.delay_seconds == 120
assert next_state.opened_until == 1220.0
assert next_state.retry_after == 7

View File

@@ -1,57 +0,0 @@
import base64
import json
from types import SimpleNamespace
from src.core.login import LoginEngine
def _build_auth_cookie(workspace_id: str) -> str:
payload = base64.urlsafe_b64encode(
json.dumps({"workspaces": [{"id": workspace_id}]}).encode("utf-8")
).decode("ascii").rstrip("=")
return f"{payload}.signature"
def test_get_workspace_id_retries_with_exponential_backoff(monkeypatch):
engine = LoginEngine.__new__(LoginEngine)
engine.logs = []
engine._log = lambda message, level="info": engine.logs.append((level, message))
auth_cookie = _build_auth_cookie("ws-123")
cookies = SimpleNamespace()
calls = {"count": 0}
def fake_get(name):
assert name == "oai-client-auth-session"
calls["count"] += 1
if calls["count"] < 4:
return None
return auth_cookie
cookies.get = fake_get
engine.session = SimpleNamespace(cookies=cookies)
sleeps = []
monkeypatch.setattr("src.core.login.time.sleep", lambda seconds: sleeps.append(seconds))
workspace_id = engine._get_workspace_id()
assert workspace_id == "ws-123"
assert calls["count"] == 4
assert sleeps == [1, 2, 4]
def test_run_always_closes_resources_on_early_return():
engine = LoginEngine.__new__(LoginEngine)
engine.logs = []
engine._log = lambda message, level="info": None
engine.close_called = False
engine.close = lambda: setattr(engine, "close_called", True)
engine._check_ip_location = lambda: (False, "blocked")
result = engine.run()
assert result.success is False
assert result.error_message == "IP 地理位置不支持: blocked"
assert engine.close_called is True

View File

@@ -1,439 +0,0 @@
from types import SimpleNamespace
import pytest
from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES
from src.core import register
from src.services.base import BaseEmailService
class DummyEmailService(BaseEmailService):
def __init__(self):
super().__init__(EmailServiceType.TEMPMAIL, name="dummy")
def create_email(self, config=None):
return {"email": "tester@example.com", "service_id": "svc-1"}
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
return "123456"
def list_emails(self, **kwargs):
return []
def delete_email(self, email_id):
return True
def check_health(self):
return True
def refresh_session(self):
return None
class FakeResponse:
def __init__(self, status_code=200, payload=None, text=""):
self.status_code = status_code
self._payload = payload if payload is not None else {}
self.text = text
def json(self):
return self._payload
class BrokenJSONResponse(FakeResponse):
def json(self):
raise ValueError("bad json")
class FakeSession:
def __init__(self, post_handler=None, get_handler=None, cookies=None):
self.post_handler = post_handler
self.get_handler = get_handler
self.cookies = cookies or {}
self.post_calls = []
self.get_calls = []
def post(self, url, **kwargs):
self.post_calls.append({"url": url, "kwargs": kwargs})
if self.post_handler is None:
raise AssertionError("unexpected post call")
return self.post_handler(url, **kwargs)
def get(self, url, **kwargs):
self.get_calls.append({"url": url, "kwargs": kwargs})
if self.get_handler is None:
raise AssertionError("unexpected get call")
return self.get_handler(url, **kwargs)
class DummyHTTPClient:
def __init__(self, proxy_url=None):
self.proxy_url = proxy_url
self.session = FakeSession()
self.closed = False
def close(self):
self.closed = True
def post(self, url, **kwargs):
raise AssertionError("unexpected http client post")
class DummyOAuthManager:
def __init__(self, **kwargs):
self.kwargs = kwargs
def start_oauth(self):
return SimpleNamespace(
auth_url="https://auth.example/start",
state="state-1",
code_verifier="verifier-1",
redirect_uri="http://localhost/callback",
)
def make_engine(monkeypatch, email_service=None):
monkeypatch.setattr(
register,
"get_settings",
lambda: SimpleNamespace(
openai_client_id="client-id",
openai_auth_url="https://auth.example/authorize",
openai_token_url="https://auth.example/token",
openai_redirect_uri="http://localhost/callback",
openai_scope="openid email profile offline_access",
),
)
monkeypatch.setattr(register, "OpenAIHTTPClient", DummyHTTPClient)
monkeypatch.setattr(register, "OAuthManager", DummyOAuthManager)
engine = register.RegistrationEngine(email_service or DummyEmailService())
engine.email = "tester@example.com"
engine.email_info = {"email": "tester@example.com", "service_id": "svc-1"}
return engine
@pytest.mark.parametrize(
"page_type",
[
"login_password",
OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"],
"consent_required",
"some_other_page",
],
)
def test_submit_login_form_accepts_any_http_200_page(monkeypatch, page_type):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(
status_code=200,
payload={"page": {"type": page_type}},
)
)
result = engine._submit_login_form("did-1", "sen-1")
assert result.success is True
assert result.page_type == page_type
assert result.error_message == ""
def test_submit_login_form_accepts_http_200_even_when_json_is_invalid(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
)
result = engine._submit_login_form("did-1", "sen-1")
assert result.success is True
assert result.page_type == ""
assert result.response_data == {}
assert result.error_message == ""
def test_send_passwordless_otp_posts_empty_body(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(status_code=200)
)
success = engine._send_passwordless_otp()
assert success is True
assert len(engine.session.post_calls) == 1
call = engine.session.post_calls[0]
assert call["url"] == OPENAI_API_ENDPOINTS["send_passwordless_otp"]
assert call["kwargs"]["data"] == ""
assert engine._otp_sent_at is not None
def test_send_passwordless_otp_does_not_update_timestamp_on_failure(monkeypatch):
engine = make_engine(monkeypatch)
engine._otp_sent_at = 1234.5
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(status_code=500, text="server error")
)
success = engine._send_passwordless_otp()
assert success is False
assert engine._otp_sent_at == 1234.5
def test_get_verification_code_passes_explicit_otp_timestamp(monkeypatch):
captured = {}
class RecordingEmailService(DummyEmailService):
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
captured["email"] = email
captured["email_id"] = email_id
captured["timeout"] = timeout
captured["pattern"] = pattern
captured["otp_sent_at"] = otp_sent_at
return "654321"
engine = make_engine(monkeypatch, email_service=RecordingEmailService())
code = engine._get_verification_code(otp_sent_at=1234.5)
assert code == "654321"
assert captured["email"] == "tester@example.com"
assert captured["email_id"] == "svc-1"
assert captured["timeout"] == 120
assert captured["otp_sent_at"] == 1234.5
def test_validate_verification_code_accepts_http_200_even_when_json_is_invalid(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
)
result = engine._validate_verification_code("123456")
assert result.success is True
assert result.continue_url == ""
assert result.response_data == {}
def test_run_closes_http_client_on_early_failure(monkeypatch):
engine = make_engine(monkeypatch)
tracking_client = DummyHTTPClient()
engine.http_client = tracking_client
monkeypatch.setattr(engine, "_check_ip_location", lambda: (False, None))
result = engine.run()
assert result.success is False
assert tracking_client.closed is True
assert engine.session is None
def test_fallback_to_login_flow_forces_otp_and_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
steps = []
captured = {}
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: steps.append("submit_login_form")
or register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
steps.append("send_passwordless_otp")
engine._otp_sent_at = 4567.89
return True
def fake_get_verification_code(otp_sent_at=None):
steps.append("get_verification_code")
captured["otp_sent_at"] = otp_sent_at
return "123456"
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: steps.append("validate_verification_code")
or register.OTPValidationResult(success=True, continue_url="https://auth.example/continue"),
)
def fake_try_upgrade(continue_url, stage):
steps.append("get_continue_url_and_parse_workspace")
captured["continue_url"] = continue_url
captured["stage"] = stage
return "ws-123"
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fake_try_upgrade)
workspace_id = engine._fallback_to_login_flow()
assert workspace_id == "ws-123"
assert steps == [
"reset_session",
"get_device_id",
"check_sentinel",
"submit_login_form",
"send_passwordless_otp",
"get_verification_code",
"validate_verification_code",
"get_continue_url_and_parse_workspace",
]
assert captured["continue_url"] == "https://auth.example/continue"
assert captured["stage"] == "降级登录 Continue URL"
assert captured["otp_sent_at"] == 4567.89
def test_fallback_to_login_flow_requires_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: True)
monkeypatch.setattr(engine, "_get_device_id", lambda: "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
engine._otp_sent_at = 9876.5
return True
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(engine, "_get_verification_code", lambda otp_sent_at=None: "123456")
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: register.OTPValidationResult(success=True, continue_url=""),
)
def fail_if_called(*args, **kwargs):
raise AssertionError("continue_url 缺失时不应尝试升级 Cookie")
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
assert engine._fallback_to_login_flow() is None
def test_fallback_to_login_flow_accepts_workspace_without_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
steps = []
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: steps.append("submit_login_form")
or register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
steps.append("send_passwordless_otp")
engine._otp_sent_at = 4567.89
return True
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(
engine,
"_get_verification_code",
lambda otp_sent_at=None: steps.append("get_verification_code") or "123456",
)
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: steps.append("validate_verification_code")
or register.OTPValidationResult(success=True, continue_url=""),
)
monkeypatch.setattr(
engine,
"_get_workspace_id",
lambda log_missing=True: steps.append("get_workspace_id") or "ws-cookie",
)
def fail_if_called(*args, **kwargs):
raise AssertionError("已有 workspace 时不应继续访问 continue_url")
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
workspace_id = engine._fallback_to_login_flow()
assert workspace_id == "ws-cookie"
assert steps == [
"reset_session",
"get_device_id",
"check_sentinel",
"submit_login_form",
"send_passwordless_otp",
"get_verification_code",
"validate_verification_code",
"get_workspace_id",
]
def test_get_verification_code_uses_provider_timeout_and_refreshes_once(monkeypatch):
captured = {"calls": [], "refresh_count": 0}
class RefreshableOutlookService(DummyEmailService):
def __init__(self):
super().__init__()
self.service_type = EmailServiceType.OUTLOOK
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
captured["calls"].append(
{
"timeout": timeout,
"otp_sent_at": otp_sent_at,
"email": email,
"email_id": email_id,
}
)
if len(captured["calls"]) == 1:
return None
return "987654"
def refresh_session(self):
captured["refresh_count"] += 1
engine = make_engine(monkeypatch, email_service=RefreshableOutlookService())
code = engine._get_verification_code(otp_sent_at=2468.0)
assert code == "987654"
assert captured["refresh_count"] == 1
assert len(captured["calls"]) == 2
assert captured["calls"][0]["timeout"] == 180
assert captured["calls"][1]["timeout"] == 180
assert captured["calls"][0]["otp_sent_at"] == 2468.0
def test_try_upgrade_cookie_with_continue_url_retries_with_second_probe(monkeypatch):
engine = make_engine(monkeypatch)
sleep_calls = []
workspace_results = iter([None, None, None, None, None, "ws-delayed"])
engine.session = FakeSession(
get_handler=lambda url, **kwargs: FakeResponse(status_code=302),
cookies={},
)
monkeypatch.setattr(engine, "_log_cookie_state", lambda *args, **kwargs: None)
monkeypatch.setattr(engine, "_get_workspace_id", lambda log_missing=False: next(workspace_results))
monkeypatch.setattr(register.time, "sleep", lambda seconds: sleep_calls.append(seconds))
workspace_id = engine._try_upgrade_cookie_with_continue_url(
"https://auth.example/continue",
"降级登录 Continue URL",
)
assert workspace_id == "ws-delayed"
assert len(engine.session.get_calls) == 3
assert sleep_calls == [1.0, 2.0, 4.0]

View File

@@ -0,0 +1,82 @@
import json
from types import SimpleNamespace
import src.core.register as register_module
from src.config.constants import OPENAI_PAGE_TYPES
from src.core.register import RegistrationEngine
from src.services import EmailServiceType
class DummySettings:
openai_client_id = "client-id"
openai_auth_url = "https://auth.example.test"
openai_token_url = "https://token.example.test"
openai_redirect_uri = "https://callback.example.test"
openai_scope = "openid profile email"
class FakeResponse:
def __init__(self, status_code=200, payload=None, text=""):
self.status_code = status_code
self._payload = payload or {}
self.text = text
def json(self):
return self._payload
class FakeSession:
def __init__(self, response):
self.response = response
self.calls = []
def post(self, url, **kwargs):
self.calls.append({
"url": url,
**kwargs,
})
return self.response
def _build_engine(monkeypatch):
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
email_service = SimpleNamespace(service_type=EmailServiceType.DUCK_MAIL)
return RegistrationEngine(email_service=email_service)
def test_submit_signup_form_uses_stable_protocol_body(monkeypatch):
engine = _build_engine(monkeypatch)
session = FakeSession(FakeResponse(
status_code=200,
payload={"page": {"type": OPENAI_PAGE_TYPES["PASSWORD_REGISTRATION"]}},
))
engine.session = session
engine.email = "tester@example.com"
result = engine._submit_signup_form("did-1", None)
assert result.success is True
assert result.is_existing_account is False
assert (
session.calls[0]["data"]
== '{"username":{"value":"tester@example.com","kind":"email"},"screen_hint":"signup"}'
)
def test_register_password_uses_stable_protocol_body(monkeypatch):
engine = _build_engine(monkeypatch)
session = FakeSession(FakeResponse(status_code=200))
engine.session = session
engine.email = "tester@example.com"
monkeypatch.setattr(engine, "_generate_password", lambda length=0: "Pass12345")
success, password = engine._register_password()
assert success is True
assert password == "Pass12345"
assert session.calls[0]["data"] == json.dumps(
{
"password": "Pass12345",
"username": "tester@example.com",
}
)

View File

@@ -0,0 +1,414 @@
from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
import src.services.base as base_module
from src.core.register import (
ERROR_OTP_TIMEOUT_SECONDARY,
PhaseResult,
RegistrationResult,
)
from src.database.models import Base, EmailService, RegistrationTask
from src.database.session import DatabaseSessionManager
from src.services import EmailServiceType
from src.services.base import BaseEmailService, EmailProviderBackoffState
from src.web.routes import registration as registration_routes
class DummyTaskManager:
def __init__(self):
self.status_updates = []
self.logs = {}
def is_cancelled(self, task_uuid):
return False
def update_status(self, task_uuid, status, email=None, error=None, **kwargs):
self.status_updates.append((task_uuid, status, email, error, kwargs))
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
def callback(message):
self.logs.setdefault(task_uuid, []).append(message)
return callback
class BackoffAwareEmailService(BaseEmailService):
def __init__(self, service_type, config=None, name=None):
super().__init__(service_type=service_type, name=name)
self.config = config or {}
def create_email(self, config=None):
return {"email": "tester@example.com", "service_id": "svc-1"}
def get_verification_code(self, **kwargs):
return None
def list_emails(self, **kwargs):
return []
def delete_email(self, email_id: str) -> bool:
return True
def check_health(self) -> bool:
return True
def test_registration_task_fails_over_after_rate_limit(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_failover.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-rate-limit-failover"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
session.add_all([
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
),
EmailService(
service_type="duck_mail",
name="duck-secondary",
config={
"base_url": "https://mail-2.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=1,
),
])
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
attempts = []
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
def run(self):
attempts.append(self.email_service.name)
if self.email_service.name == "duck-primary":
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=False,
error_message="创建邮箱失败",
error_code="EMAIL_PROVIDER_RATE_LIMITED",
retryable=True,
next_action="switch_provider",
provider_backoff=EmailProviderBackoffState(
failures=1,
delay_seconds=30,
opened_until=1030.0,
retry_after=7,
last_error="请求失败: 429",
),
)
]
return RegistrationResult(
success=False,
error_message="创建邮箱失败: 请求失败: 429",
logs=[],
)
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=True,
provider_backoff=EmailProviderBackoffState(),
)
]
return RegistrationResult(
success=True,
email="tester@example.com",
password="Pass12345",
account_id="acct-1",
workspace_id="ws-1",
access_token="access-token",
refresh_token="refresh-token",
id_token="id-token",
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: SimpleNamespace(
service_type=service_type,
name=name or service_type.value,
config=config,
),
)
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
registration_routes.email_service_circuit_breakers.clear()
registration_routes._run_sync_registration_task(
task_uuid=task_uuid,
email_service_type=EmailServiceType.DUCK_MAIL.value,
proxy=None,
email_service_config=None,
)
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
services = session.query(EmailService).order_by(EmailService.priority.asc()).all()
task_status = task.status
task_email_service_id = task.email_service_id
primary_service_id = services[0].id
secondary_service_id = services[1].id
assert attempts == ["duck-primary", "duck-secondary"]
assert task_status == "completed"
assert task_email_service_id == secondary_service_id
assert registration_routes.email_service_circuit_breakers[primary_service_id].failures == 1
assert registration_routes.email_service_circuit_breakers[primary_service_id].delay_seconds == 30
def test_registration_task_enters_deep_cooldown_after_three_otp_timeouts(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_otp_timeout_backoff.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuids = [
"task-otp-timeout-1",
"task-otp-timeout-2",
"task-otp-timeout-3",
]
with manager.session_scope() as session:
session.add_all([RegistrationTask(task_uuid=task_uuid, status="pending") for task_uuid in task_uuids])
session.add(
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
current_time = {"value": 1000.0}
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
def run(self):
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=True,
provider_backoff=EmailProviderBackoffState(),
)
]
return RegistrationResult(
success=False,
error_message="等待验证码超时",
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: BackoffAwareEmailService(
service_type=service_type,
config=config,
name=name,
),
)
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
monkeypatch.setattr(base_module.time, "time", lambda: current_time["value"])
registration_routes.email_service_circuit_breakers.clear()
with manager.session_scope() as session:
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
expected_delays = [30, 60, 3600]
for attempt_index, task_uuid in enumerate(task_uuids, start=1):
registration_routes._run_sync_registration_task(
task_uuid=task_uuid,
email_service_type=EmailServiceType.DUCK_MAIL.value,
proxy=None,
email_service_config=None,
)
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
assert task.status == "failed"
assert task.error_message == "等待验证码超时"
state = registration_routes.email_service_circuit_breakers[service_id]
assert state.failures == attempt_index
assert state.delay_seconds == expected_delays[attempt_index - 1]
assert state.opened_until == current_time["value"] + expected_delays[attempt_index - 1]
if attempt_index < len(task_uuids):
current_time["value"] = state.opened_until + 1
final_state = registration_routes.email_service_circuit_breakers[service_id]
assert final_state.delay_seconds == 3600
assert final_state.failures == 3
def test_registration_task_success_clears_email_service_backoff(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_success_clears_backoff.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-success-clears-backoff"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
session.add(
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=True,
provider_backoff=EmailProviderBackoffState(),
)
]
def run(self):
return RegistrationResult(
success=True,
email="tester@example.com",
password="Pass12345",
account_id="acct-1",
workspace_id="ws-1",
access_token="access-token",
refresh_token="refresh-token",
id_token="id-token",
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: BackoffAwareEmailService(
service_type=service_type,
config=config,
name=name,
),
)
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
registration_routes.email_service_circuit_breakers.clear()
with manager.session_scope() as session:
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
registration_routes.email_service_circuit_breakers[service_id] = EmailProviderBackoffState(
failures=2,
delay_seconds=60,
opened_until=9999.0,
last_error="等待验证码超时",
)
registration_routes._run_sync_registration_task(
task_uuid=task_uuid,
email_service_type=EmailServiceType.DUCK_MAIL.value,
proxy=None,
email_service_config=None,
)
assert service_id not in registration_routes.email_service_circuit_breakers

View File

@@ -0,0 +1,71 @@
import src.core.register as register_module
from src.core.register import (
ERROR_OTP_TIMEOUT_SECONDARY,
PhaseContext,
RegistrationEngine,
)
from src.services import EmailServiceType
class DummySettings:
openai_client_id = "client-id"
openai_auth_url = "https://auth.example.test"
openai_token_url = "https://token.example.test"
openai_redirect_uri = "https://callback.example.test"
openai_scope = "openid profile email"
class FakeEmailService:
def __init__(self, code):
self.service_type = EmailServiceType.TEMPMAIL
self.code = code
self.calls = []
def get_verification_code(self, **kwargs):
self.calls.append(kwargs)
return self.code
def _build_engine(monkeypatch, email_service):
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
return RegistrationEngine(email_service=email_service)
def test_phase_otp_secondary_uses_remaining_budget_from_start_timestamp(monkeypatch):
email_service = FakeEmailService(code="654321")
engine = _build_engine(monkeypatch, email_service)
engine.email = "tester@example.com"
engine.email_info = {"service_id": "svc-1"}
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
code, phase_result = engine._phase_otp_secondary(
PhaseContext(otp_sent_at=77.0),
started_at=100.0,
)
assert code == "654321"
assert phase_result.success is True
assert email_service.calls[0]["timeout"] == 100
assert email_service.calls[0]["otp_sent_at"] == 77.0
assert email_service.calls[0]["email"] == "tester@example.com"
assert email_service.calls[0]["email_id"] == "svc-1"
def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
email_service = FakeEmailService(code=None)
engine = _build_engine(monkeypatch, email_service)
engine.email = "tester@example.com"
engine.email_info = {"service_id": "svc-1"}
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
code, phase_result = engine._phase_otp_secondary(
PhaseContext(otp_sent_at=80.0),
started_at=100.0,
)
assert code is None
assert phase_result.success is False
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY

View File

@@ -42,7 +42,13 @@ def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch,
monkeypatch.setattr( monkeypatch.setattr(
registration, registration,
"EmailServiceFactory", "EmailServiceFactory",
SimpleNamespace(create=lambda service_type, config: SimpleNamespace(service_type=service_type, config=config)), SimpleNamespace(
create=lambda service_type, config, name=None: SimpleNamespace(
service_type=service_type,
config=config,
name=name or service_type.value,
)
),
) )
attempted_proxies = [] attempted_proxies = []
@@ -73,6 +79,7 @@ def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch,
return True return True
monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine) monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine)
registration.email_service_circuit_breakers.clear()
registration._run_sync_registration_task( registration._run_sync_registration_task(
task_uuid="task-proxy-failover", task_uuid="task-proxy-failover",

View File

@@ -15,6 +15,7 @@ def test_static_asset_version_is_non_empty_string():
def test_email_services_template_uses_versioned_static_assets(): def test_email_services_template_uses_versioned_static_assets():
template = Path("templates/email_services.html").read_text(encoding="utf-8") template = Path("templates/email_services.html").read_text(encoding="utf-8")
assert '/static/favicon.svg?v={{ static_version }}' in template
assert '/static/css/style.css?v={{ static_version }}' in template assert '/static/css/style.css?v={{ static_version }}' in template
assert '/static/js/utils.js?v={{ static_version }}' in template assert '/static/js/utils.js?v={{ static_version }}' in template
assert '/static/js/email_services.js?v={{ static_version }}' in template assert '/static/js/email_services.js?v={{ static_version }}' in template
@@ -23,6 +24,7 @@ def test_email_services_template_uses_versioned_static_assets():
def test_index_template_uses_versioned_static_assets(): def test_index_template_uses_versioned_static_assets():
template = Path("templates/index.html").read_text(encoding="utf-8") template = Path("templates/index.html").read_text(encoding="utf-8")
assert '/static/favicon.svg?v={{ static_version }}' in template
assert '/static/css/style.css?v={{ static_version }}' in template assert '/static/css/style.css?v={{ static_version }}' in template
assert '/static/js/utils.js?v={{ static_version }}' in template assert '/static/js/utils.js?v={{ static_version }}' in template
assert '/static/js/app.js?v={{ static_version }}' in template assert '/static/js/app.js?v={{ static_version }}' in template

143
tests/test_task_recovery.py Normal file
View File

@@ -0,0 +1,143 @@
from contextlib import contextmanager
import asyncio
from fastapi import WebSocketDisconnect
from src.database import crud
from src.database.models import Base, RegistrationTask
from src.database.session import DatabaseSessionManager
from src.web.routes import websocket as websocket_routes
from src.web.task_manager import TaskManager
def test_fail_incomplete_registration_tasks_marks_pending_and_running_failed(tmp_path):
db_path = tmp_path / "recovery.db"
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
with manager.session_scope() as session:
session.add_all([
RegistrationTask(task_uuid="task-pending", status="pending"),
RegistrationTask(task_uuid="task-running", status="running", logs="[01:00:00] still running"),
RegistrationTask(task_uuid="task-done", status="completed"),
])
with manager.session_scope() as session:
cleaned = crud.fail_incomplete_registration_tasks(
session,
"服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
)
assert cleaned == ["task-pending", "task-running"]
with manager.session_scope() as session:
pending_task = crud.get_registration_task_by_uuid(session, "task-pending")
running_task = crud.get_registration_task_by_uuid(session, "task-running")
done_task = crud.get_registration_task_by_uuid(session, "task-done")
assert pending_task.status == "failed"
assert running_task.status == "failed"
assert pending_task.error_message == "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
assert running_task.completed_at is not None
assert "[系统] 服务启动时检测到未完成的历史任务,已标记失败,请重新发起。" in running_task.logs
assert done_task.status == "completed"
def test_restore_task_snapshot_loads_status_and_logs_from_database(monkeypatch, tmp_path):
db_path = tmp_path / "websocket.db"
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
with manager.session_scope() as session:
session.add(
RegistrationTask(
task_uuid="task-websocket",
status="failed",
logs="[01:00:00] step 1\n[01:00:01] step 2",
result={"email": "tester@example.com"},
error_message="boom"
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
monkeypatch.setattr(websocket_routes, "get_db", fake_get_db)
status, logs = websocket_routes._restore_task_snapshot("task-websocket")
assert status == {
"status": "failed",
"email": "tester@example.com",
"error": "boom",
}
assert logs == ["[01:00:00] step 1", "[01:00:01] step 2"]
def test_sync_task_state_prefers_longer_persisted_log_history():
manager = TaskManager()
task_uuid = "task-sync"
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["a", "b"])
manager.sync_task_state(task_uuid, logs=["a"])
assert manager.get_status(task_uuid) == {"status": "running"}
assert manager.get_logs(task_uuid) == ["a", "b"]
def test_register_websocket_returns_snapshot_and_keeps_live_cursor():
manager = TaskManager()
task_uuid = "task-live"
websocket = object()
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["log-1", "log-2"])
history_logs = manager.register_websocket(task_uuid, websocket)
assert history_logs == ["log-1", "log-2"]
assert manager.get_unsent_logs(task_uuid, websocket) == []
manager.add_log(task_uuid, "log-3")
assert manager.get_unsent_logs(task_uuid, websocket) == ["log-3"]
class _FakeWebSocket:
def __init__(self):
self.messages = []
self.accepted = False
async def accept(self):
self.accepted = True
async def send_json(self, payload):
self.messages.append(payload)
async def receive_json(self):
raise WebSocketDisconnect()
def test_batch_websocket_replays_history_logs_from_registration_snapshot(monkeypatch):
manager = TaskManager()
batch_id = "batch-history"
websocket = _FakeWebSocket()
manager.init_batch(batch_id, total=2)
manager.add_batch_log(batch_id, "[01:00:00] first")
manager.add_batch_log(batch_id, "[01:00:01] second")
monkeypatch.setattr(websocket_routes, "task_manager", manager)
asyncio.run(websocket_routes.batch_websocket(websocket, batch_id))
assert websocket.accepted is True
assert websocket.messages[0]["type"] == "status"
assert [msg["message"] for msg in websocket.messages[1:]] == [
"[01:00:00] first",
"[01:00:01] second",
]

View File

@@ -0,0 +1,98 @@
from datetime import datetime, timezone
from src.services.tempmail import TempmailService
class FakeResponse:
def __init__(self, payload, status_code=200):
self._payload = payload
self.status_code = status_code
def json(self):
return self._payload
class FakeHTTPClient:
def __init__(self, responses):
self.responses = list(responses)
self.calls = []
def get(self, url, **kwargs):
self.calls.append({"url": url, "kwargs": kwargs})
if not self.responses:
raise AssertionError(f"未准备响应: GET {url}")
return self.responses.pop(0)
def _to_timestamp(value: str) -> float:
return datetime.fromisoformat(value.replace("Z", "+00:00")).astimezone(timezone.utc).timestamp()
def test_get_verification_code_ignores_messages_received_before_otp_sent_at():
service = TempmailService({"base_url": "https://api.tempmail.test"})
service._email_cache["tester@example.com"] = {"token": "token-1"}
service.http_client = FakeHTTPClient([
FakeResponse(
{
"emails": [
{
"id": "old-mail",
"received_at": "2026-03-23T10:00:00Z",
"from": "noreply@openai.com",
"subject": "Old code",
"body": "111111",
},
{
"id": "new-mail",
"received_at": "2026-03-23T10:00:05Z",
"from": "noreply@openai.com",
"subject": "New code",
"body": "222222",
},
]
}
)
])
code = service.get_verification_code(
email="tester@example.com",
timeout=1,
otp_sent_at=_to_timestamp("2026-03-23T10:00:02Z"),
)
assert code == "222222"
def test_get_verification_code_requires_received_at_when_otp_sent_at_is_present():
service = TempmailService({"base_url": "https://api.tempmail.test"})
service._email_cache["tester@example.com"] = {"token": "token-1"}
service.http_client = FakeHTTPClient([
FakeResponse(
{
"emails": [
{
"id": "legacy-mail",
"date": "2026-03-23T10:00:06Z",
"from": "noreply@openai.com",
"subject": "Legacy code",
"body": "333333",
},
{
"id": "received-mail",
"received_at": "2026-03-23T10:00:07Z",
"from": "noreply@openai.com",
"subject": "Received code",
"body": "444444",
},
]
}
)
])
code = service.get_verification_code(
email="tester@example.com",
timeout=1,
otp_sent_at=_to_timestamp("2026-03-23T10:00:05Z"),
)
assert code == "444444"