fix: restore protocol baseline, resolve 403/400 registration errors, and fully remove deprecated playwright dependency

This commit is contained in:
Mison
2026-03-23 18:49:51 +08:00
parent a7a6391f0d
commit de2c4aa7ab
60 changed files with 4126 additions and 2834 deletions

View File

@@ -21,13 +21,13 @@ jobs:
include:
- os: windows-latest
artifact_name: codex-register.exe
asset_name: codex-register-v2-windows-x64.exe
asset_name: codex-register-windows-x64.exe
- os: ubuntu-latest
artifact_name: codex-register
asset_name: codex-register-v2-linux-x64
asset_name: codex-register-linux-x64
- os: macos-latest
artifact_name: codex-register
asset_name: codex-register-v2-macos-arm64
asset_name: codex-register-macos-arm64
steps:
- name: 检出代码
@@ -74,19 +74,19 @@ jobs:
- name: 整理文件并打包 zip
run: |
mkdir -p release
# download-artifact@v4 将每个 artifact 放在 dist/<asset_name>/ 子目录下
# 遍历子目录,用目录名作为平台标识(即 matrix.asset_name
for artifact_dir in dist/*/; do
platform=$(basename "$artifact_dir")
# 找到该目录下的二进制文件(只有一个)
binary=$(find "$artifact_dir" -maxdepth 1 -type f | head -n1)
if [ -z "$binary" ]; then
echo "警告:$artifact_dir 下没有找到文件,跳过"
continue
fi
# 为每个平台二进制文件打包成 zip
find dist/ -type f | while read f; do
name=$(basename "$f")
# 根据文件名确定平台标识
case "$name" in
*windows*) platform=$(echo "$name" | sed 's/\.[^.]*$//') ;;
*linux*) platform="$name" ;;
*macos*) platform="$name" ;;
*) platform="$name" ;;
esac
tmpdir="tmp_${platform}"
mkdir -p "$tmpdir"
cp "$binary" "$tmpdir/"
cp "$f" "$tmpdir/"
cp README.md "$tmpdir/README.md"
cp .env.example "$tmpdir/.env.example"
[ -f LICENSE ] && cp LICENSE "$tmpdir/LICENSE" || true
@@ -103,14 +103,14 @@ jobs:
files: release/*
generate_release_notes: true
body: |
## OpenAI 账号管理系统 v2
## OpenAI 自动注册系统 v2
### 下载说明
| 平台 | 文件 |
|------|------|
| Windows x64 | `codex-register-v2-windows-x64.exe` |
| Linux x64 | `codex-register-v2-linux-x64` |
| macOS ARM64 | `codex-register-v2-macos-arm64` |
| Windows x64 | `codex-register-windows-x64.exe` |
| Linux x64 | `codex-register-linux-x64` |
| macOS ARM64 | `codex-register-macos-arm64` |
### 使用方法
```bash

View File

@@ -64,6 +64,5 @@ jobs:
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
platforms: linux/amd64,linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@@ -2,6 +2,8 @@
管理 OpenAI 账号的 Web UI 系统,支持多种邮箱服务、并发批量注册、代理管理和账号管理。
# 官方拉闸了,改变了授权流程,各位自行研究吧
> ⚠️ **免责声明**:本工具仅供学习和研究使用,使用本工具产生的一切后果由使用者自行承担。请遵守相关服务的使用条款,不要用于任何违法或不当用途。 如有侵权,请及时联系,会及时删除。
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
@@ -144,9 +146,9 @@ docker-compose up -d
```bash
docker run -d \
-p 15555:15555 \
-p 1455:1455 \
-e WEBUI_HOST=0.0.0.0 \
-e WEBUI_PORT=15555 \
-e WEBUI_PORT=1455 \
-e WEBUI_ACCESS_PASSWORD=your_secure_password \
-v $(pwd)/data:/app/data \
--name codex-register \
@@ -155,7 +157,7 @@ docker run -d \
环境变量说明:
- `WEBUI_HOST`: 监听的主机地址 (默认 `0.0.0.0`)
- `WEBUI_PORT`: 监听的端口 (默认 `15555`)
- `WEBUI_PORT`: 监听的端口 (默认 `1455`)
- `WEBUI_ACCESS_PASSWORD`: 设置 Web UI 的访问密码
- `DEBUG`: 设为 `1``true` 开启调试模式
- `LOG_LEVEL`: 日志级别,如 `info`, `debug`
@@ -371,8 +373,7 @@ docker-compose build --no-cache
- CPA / Sub2API / Team Manager 上传始终直连,不走代理;其中 CPA 可选把账号记录的代理写入 auth file 的 `proxy_url`
- 注册时自动随机生成用户名和生日(年龄范围 18-45 岁)
- 支付链接生成使用账号 access_token 鉴权,走全局代理配置
- 无痕浏览器优先使用 playwright注入 cookie 直达支付页);未安装时降级为系统 Chrome/Edge 无痕模式
- 安装完整支付功能:`pip install ".[payment]" && playwright install chromium`(可选)
- 无痕打开支付页默认调用系统 Chrome/Edge 的隐私模式
- 订阅状态自动检测调用 `chatgpt.com/backend-api/me`,走全局代理
- 批量注册并发数上限为 50线程池大小已相应调整

View File

@@ -1,5 +1,3 @@
version: '3.8'
services:
webui:
build: .

View File

@@ -23,9 +23,6 @@ dev = [
"pytest>=7.0.0",
"httpx>=0.24.0",
]
payment = [
"playwright>=1.40.0",
]
[project.scripts]
codex-webui = "webui:main"

View File

@@ -11,5 +11,3 @@ python-multipart>=0.0.6
sqlalchemy>=2.0.0
aiosqlite>=0.19.0
psycopg[binary]>=3.1.18
# 可选:无痕打开支付页需要 playwrightpip install playwright && playwright install chromium
# playwright>=1.40.0

View File

@@ -56,7 +56,7 @@ APP_DESCRIPTION = "自动注册 OpenAI/Codex CLI 账号的系统"
OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
OAUTH_AUTH_URL = "https://auth.openai.com/oauth/authorize"
OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
OAUTH_REDIRECT_URI = "http://localhost:15555/auth/callback"
OAUTH_REDIRECT_URI = "http://localhost:1455/auth/callback"
OAUTH_SCOPE = "openid email profile offline_access"
# OpenAI API 端点
@@ -65,20 +65,15 @@ OPENAI_API_ENDPOINTS = {
"signup": "https://auth.openai.com/api/accounts/authorize/continue",
"register": "https://auth.openai.com/api/accounts/user/register",
"send_otp": "https://auth.openai.com/api/accounts/email-otp/send",
"passwordless_send_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
"validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate",
"create_account": "https://auth.openai.com/api/accounts/create_account",
"add_phone": "https://auth.openai.com/add-phone",
"select_workspace": "https://auth.openai.com/api/accounts/workspace/select",
"send_passwordless_otp": "https://auth.openai.com/api/accounts/passwordless/send-otp",
"password_verify": "https://auth.openai.com/api/accounts/password/verify",
}
# OpenAI 页面类型(用于判断账号状态)
OPENAI_PAGE_TYPES = {
"LOGIN_PASSWORD": "login_password",
"EMAIL_OTP_VERIFICATION": "email_otp_verification", # 已注册账号,需要 OTP 验证
"PASSWORD_REGISTRATION": "create_account_password", # 新账号,需要设置密码
"PASSWORD_REGISTRATION": "password", # 新账号,需要设置密码
}
# ============================================================================
@@ -272,7 +267,7 @@ DEFAULT_SETTINGS = [
("registration.timeout", "120", "超时时间(秒)", "registration"),
("registration.default_password_length", "12", "默认密码长度", "registration"),
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
("webui.port", "8000", "Web UI 监听端口", "webui"),
("webui.port", "15555", "Web UI 监听端口", "webui"),
("webui.debug", "true", "调试模式", "webui"),
]
@@ -383,8 +378,20 @@ MICROSOFT_TOKEN_ENDPOINTS = {
}
# IMAP 服务器配置
OUTLOOK_IMAP_SERVER = "outlook.live.com"
OUTLOOK_IMAP_PORT = 993
OUTLOOK_IMAP_SERVERS = {
"OLD": "outlook.office365.com", # 旧版 IMAP
"NEW": "outlook.live.com", # 新版 IMAP
}
# Microsoft OAuth2 ScopeIMAP_NEW
OUTLOOK_IMAP_SCOPE = "https://outlook.office.com/IMAP.AccessAsUser.All offline_access"
# Microsoft OAuth2 Scopes
MICROSOFT_SCOPES = {
# 旧版 IMAP 不需要特定 scope
"IMAP_OLD": "",
# 新版 IMAP 需要的 scope
"IMAP_NEW": "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
# Graph API 需要的 scope
"GRAPH_API": "https://graph.microsoft.com/.default",
}
# Outlook 提供者默认优先级
OUTLOOK_PROVIDER_PRIORITY = ["imap_new", "imap_old", "graph_api"]

View File

@@ -76,7 +76,7 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
),
"webui_port": SettingDefinition(
db_key="webui.port",
default_value=8000,
default_value=15555,
category=SettingCategory.WEBUI,
description="Web UI 监听端口"
),
@@ -136,7 +136,7 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
),
"openai_redirect_uri": SettingDefinition(
db_key="openai.redirect_uri",
default_value="http://localhost:15555/auth/callback",
default_value="http://localhost:1455/auth/callback",
category=SettingCategory.OPENAI,
description="OpenAI OAuth 回调 URI"
),
@@ -358,6 +358,12 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
),
# Outlook 配置
"outlook_provider_priority": SettingDefinition(
db_key="outlook.provider_priority",
default_value=["imap_old", "imap_new", "graph_api"],
category=SettingCategory.EMAIL,
description="Outlook 提供者优先级"
),
"outlook_health_failure_threshold": SettingDefinition(
db_key="outlook.health_failure_threshold",
default_value=5,
@@ -376,12 +382,6 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
category=SettingCategory.EMAIL,
description="Outlook OAuth 默认 Client ID"
),
"outlook_use_idle": SettingDefinition(
db_key="outlook.use_idle",
default_value=True,
category=SettingCategory.EMAIL,
description="使用 IMAP IDLE 替代轮询获取验证码(降低延迟,默认开启)"
),
}
# 属性名到数据库键名的映射(用于向后兼容)
@@ -407,9 +407,9 @@ SETTING_TYPES: Dict[str, Type] = {
"cpa_enabled": bool,
"email_code_timeout": int,
"email_code_poll_interval": int,
"outlook_provider_priority": list,
"outlook_health_failure_threshold": int,
"outlook_health_disable_duration": int,
"outlook_use_idle": bool,
}
# 需要作为 SecretStr 处理的字段
@@ -609,7 +609,7 @@ class Settings(BaseModel):
# Web UI 配置
webui_host: str = "0.0.0.0"
webui_port: int = 8000
webui_port: int = 15555
webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
webui_access_password: SecretStr = SecretStr("admin123")
@@ -622,7 +622,7 @@ class Settings(BaseModel):
openai_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann"
openai_auth_url: str = "https://auth.openai.com/oauth/authorize"
openai_token_url: str = "https://auth.openai.com/oauth/token"
openai_redirect_uri: str = "http://localhost:15555/auth/callback"
openai_redirect_uri: str = "http://localhost:1455/auth/callback"
openai_scope: str = "openid email profile offline_access"
# 代理配置
@@ -694,10 +694,10 @@ class Settings(BaseModel):
email_code_poll_interval: int = 3
# Outlook 配置
outlook_provider_priority: List[str] = ["imap_old", "imap_new", "graph_api"]
outlook_health_failure_threshold: int = 5
outlook_health_disable_duration: int = 60
outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a"
outlook_use_idle: bool = True
# 全局配置实例

View File

@@ -12,7 +12,6 @@ from .http_client import (
create_openai_client,
)
from .register import RegistrationEngine, RegistrationResult
from .login import LoginEngine
from .utils import setup_logging, get_data_dir
__all__ = [
@@ -28,7 +27,6 @@ __all__ = [
'create_openai_client',
'RegistrationEngine',
'RegistrationResult',
'LoginEngine',
'setup_logging',
'get_data_dir',
]

View File

@@ -282,13 +282,13 @@ class OpenAIHTTPClient(HTTPClient):
loc = loc_match.group(1) if loc_match else None
# 检查是否支持
if loc in ["CN", "HK", "MO"]:
if loc in ["CN", "HK", "MO", "TW"]:
return False, loc
return True, loc
except Exception as e:
logger.error(f"检查 IP 地理位置失败: {e}")
return False, str(e)
return False, None
def send_openai_request(
self,

View File

@@ -81,7 +81,10 @@ class LoginEngine(RegistrationEngine):
}
if sen_token:
sentinel = f'{{"p": "", "t": "", "c": "{sen_token}", "id": "{did}", "flow": "authorize_continue"}}'
sentinel = (
f'{{"p": "", "t": "", "c": "{sen_token}", '
f'"id": "{did}", "flow": "authorize_continue"}}'
)
headers["openai-sentinel-token"] = sentinel
response = self.session.post(
@@ -101,7 +104,6 @@ class LoginEngine(RegistrationEngine):
def _send_verification_code_passwordless(self) -> bool:
"""发送验证码"""
try:
# 记录发送时间戳
self._otp_sent_at = time.time()
response = self.session.post(
OPENAI_API_ENDPOINTS["passwordless_send_otp"],
@@ -281,7 +283,6 @@ class LoginEngine(RegistrationEngine):
self._log("开始注册流程")
self._log("=" * 60)
# 1. 检查 IP 地理位置
self._log("1. 检查 IP 地理位置...")
ip_ok, location = self._check_ip_location()
if not ip_ok:
@@ -291,7 +292,6 @@ class LoginEngine(RegistrationEngine):
self._log(f"IP 位置: {location}")
# 2. 创建邮箱
self._log("2. 创建邮箱...")
if not self._create_email():
result.error_message = "创建邮箱失败"
@@ -299,26 +299,22 @@ class LoginEngine(RegistrationEngine):
result.email = self.email
# 3. 初始化会话
self._log("3. 初始化会话...")
if not self._init_session():
result.error_message = "初始化会话失败"
return result
# 4. 开始 OAuth 流程
self._log("4. 开始 OAuth 授权流程...")
if not self._start_oauth():
result.error_message = "开始 OAuth 流程失败"
return result
# 5. 获取 Device ID
self._log("5. 获取 Device ID...")
did = self._get_device_id()
if not did:
result.error_message = "获取 Device ID 失败"
return result
# 6. 检查 Sentinel 拦截
self._log("6. 检查 Sentinel 拦截...")
sen_token = self._check_sentinel(did)
if sen_token:
@@ -326,32 +322,28 @@ class LoginEngine(RegistrationEngine):
else:
self._log("Sentinel 检查失败或未启用", "warning")
# 7. 提交注册表单 + 解析响应判断账号状态
self._log("7. 提交注册表单...")
signup_result = self._submit_signup_form(did, sen_token)
if not signup_result.success:
result.error_message = f"提交注册表单失败: {signup_result.error_message}"
return result
# 8. 检测到已注册账号 → 直接终止任务
if self._is_existing_account:
self._log(f"8. 邮箱 {self.email} 在 OpenAI 已注册,跳过注册流程", "warning")
result.error_message = f"邮箱 {self.email} 已在 OpenAI 注册"
return result
else:
self._log("8. 注册密码...")
password_ok, password = self._register_password()
if not password_ok:
result.error_message = "注册密码失败"
return result
# 9. 发送验证码
self._log("9. 发送验证码...")
if not self._send_verification_code():
result.error_message = "发送验证码失败"
return result
# 10. 获取验证码(超时后重发一次)
self._log("10. 等待验证码...")
code = self._get_verification_code()
if not code:
@@ -362,13 +354,11 @@ class LoginEngine(RegistrationEngine):
result.error_message = "获取验证码失败"
return result
# 11. 验证验证码
self._log("11. 验证验证码...")
if not self._validate_verification_code(code):
result.error_message = "验证验证码失败"
return result
# 12. 创建用户账户
self._log("12. 创建用户账户...")
if not self._create_user_account():
result.error_message = "创建用户账户失败"
@@ -404,7 +394,6 @@ class LoginEngine(RegistrationEngine):
result.error_message = "验证验证码失败"
return result
# 13. 获取 Workspace ID
self._log("17. 获取 Workspace ID...")
workspace_id = self._get_workspace_id()
if not workspace_id:
@@ -413,45 +402,37 @@ class LoginEngine(RegistrationEngine):
result.workspace_id = workspace_id
# 14. 选择 Workspace
self._log("18. 选择 Workspace...")
continue_url = self._select_workspace(workspace_id)
if not continue_url:
result.error_message = "选择 Workspace 失败"
return result
# 15. 跟随重定向链
self._log("19. 跟随重定向链...")
callback_url = self._follow_redirects(continue_url)
if not callback_url:
result.error_message = "跟随重定向链失败"
return result
# 16. 处理 OAuth 回调
self._log("20. 处理 OAuth 回调...")
token_info = self._handle_oauth_callback(callback_url)
if not token_info:
result.error_message = "处理 OAuth 回调失败"
return result
# 提取账户信息
result.account_id = token_info.get("account_id", "")
result.access_token = token_info.get("access_token", "")
result.refresh_token = token_info.get("refresh_token", "")
result.id_token = token_info.get("id_token", "")
result.password = self.password or "" # 保存密码(已注册账号为空)
# 设置来源标记
result.password = self.password or ""
result.source = "register"
# 尝试获取 session_token 从 cookie
session_cookie = self.session.cookies.get("__Secure-next-auth.session-token")
if session_cookie:
self.session_token = session_cookie
result.session_token = session_cookie
self._log(f"获取到 Session Token")
self._log("获取到 Session Token")
# 17. 完成
self._log("=" * 60)
self._log("注册成功!")
self._log(f"邮箱: {result.email}")

File diff suppressed because it is too large Load Diff

View File

@@ -15,7 +15,6 @@ import base64
import re
import uuid
from datetime import datetime, timedelta
from html.parser import HTMLParser
from typing import Any, Dict, List, Optional, Union, Callable
from pathlib import Path
@@ -569,49 +568,3 @@ class Timer:
if self.start_time is not None:
return time.time() - self.start_time
return 0.0
class BootstrapExtractor(HTMLParser):
"""内部解析器,专门提取 id="client-bootstrap" 的 script 内容"""
def __init__(self):
super().__init__()
self._in_target = False
self.json_text = None
def handle_starttag(self, tag, attrs):
if tag == 'script':
attrs_dict = dict(attrs)
if attrs_dict.get('id') == 'client-bootstrap':
self._in_target = True
def handle_endtag(self, tag):
if tag == 'script' and self._in_target:
self._in_target = False
def handle_data(self, data):
if self._in_target and self.json_text is None:
self.json_text = data.strip()
def extract_client_bootstrap_json(html: str):
"""
从 HTML 字符串中提取 id="client-bootstrap" 的 script 标签内容并解析为 JSON。
返回 dict 或 None未找到或解析失败
"""
parser = BootstrapExtractor()
parser.feed(html)
if parser.json_text:
try:
return json.loads(parser.json_text)
except json.JSONDecodeError:
return None
return None
def base64_payload_decode(payload_b64):
import base64
import json as json_module
padding = 4 - (len(payload_b64) % 4)
if padding != 4:
payload_b64 += '=' * padding
# 解码Base64URL 使用 - 和 _ 替代 + 和 /
payload_bytes = base64.urlsafe_b64decode(payload_b64)
return json_module.loads(payload_bytes)

View File

@@ -36,6 +36,7 @@ def create_account(
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
id_token: Optional[str] = None,
cookies: Optional[str] = None,
proxy_used: Optional[str] = None,
expires_at: Optional['datetime'] = None,
extra_data: Optional[Dict[str, Any]] = None,
@@ -62,6 +63,7 @@ def create_account(
access_token=access_token,
refresh_token=refresh_token,
id_token=id_token,
cookies=cookies,
proxy_used=proxy_used,
expires_at=expires_at,
extra_data=extra_data or {},
@@ -134,7 +136,6 @@ def update_account(
}
kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
kwargs["token_sync_updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(db_account, key) and value is not None:
setattr(db_account, key, value)
@@ -353,6 +354,34 @@ def delete_registration_task(db: Session, task_uuid: str) -> bool:
return True
def fail_incomplete_registration_tasks(db: Session, error_message: str) -> List[str]:
"""将服务重启后遗留的未完成任务标记为失败"""
tasks = db.query(RegistrationTask).filter(
RegistrationTask.status.in_(("pending", "running"))
).all()
if not tasks:
return []
now = datetime.utcnow()
cleaned_task_ids: List[str] = []
cleanup_log = f"[系统] {error_message}"
for task in tasks:
task.status = "failed"
task.error_message = error_message
task.completed_at = now
if task.logs:
if cleanup_log not in task.logs:
task.logs = f"{task.logs}\n{cleanup_log}"
else:
task.logs = cleanup_log
cleaned_task_ids.append(task.task_uuid)
db.commit()
return cleaned_task_ids
# 为 API 路由添加别名
get_account = get_account_by_id
get_registration_task = get_registration_task_by_uuid
@@ -503,13 +532,6 @@ def delete_proxy(db: Session, proxy_id: int) -> bool:
return True
def delete_disabled_proxies(db: Session) -> int:
"""删除所有已禁用代理"""
deleted = db.query(Proxy).filter(Proxy.enabled == False).delete(synchronize_session=False)
db.commit()
return deleted
def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
"""更新代理最后使用时间"""
db_proxy = get_proxy_by_id(db, proxy_id)

View File

@@ -38,7 +38,9 @@ from .outlook.base import (
from .outlook.account import OutlookAccount
from .outlook.providers import (
OutlookProvider,
IMAPOldProvider,
IMAPNewProvider,
GraphAPIProvider,
)
__all__ = [
@@ -65,5 +67,7 @@ __all__ = [
'ProviderStatus',
'OutlookAccount',
'OutlookProvider',
'IMAPOldProvider',
'IMAPNewProvider',
'GraphAPIProvider',
]

View File

@@ -5,6 +5,8 @@
import abc
import logging
import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
from enum import Enum
@@ -13,12 +15,109 @@ from ..config.constants import EmailServiceType
logger = logging.getLogger(__name__)
EMAIL_PROVIDER_BACKOFF_BASE_SECONDS = 30
EMAIL_PROVIDER_BACKOFF_MAX_SECONDS = 3600
OTP_TIMEOUT_ERROR_PREFIX = "OTP_TIMEOUT"
@dataclass(frozen=True)
class EmailProviderBackoffState:
"""邮箱供应商退避状态"""
failures: int = 0
delay_seconds: int = 0
opened_until: float = 0.0
retry_after: Optional[int] = None
last_error: Optional[str] = None
def is_open(self, now: Optional[float] = None) -> bool:
now_ts = now if now is not None else time.time()
return self.opened_until > now_ts
def to_dict(self) -> Dict[str, Any]:
return {
"failures": self.failures,
"delay_seconds": self.delay_seconds,
"opened_until": self.opened_until,
"retry_after": self.retry_after,
"last_error": self.last_error,
}
def calculate_adaptive_backoff_delay(
failures: int,
base_delay: int = EMAIL_PROVIDER_BACKOFF_BASE_SECONDS,
max_delay: int = EMAIL_PROVIDER_BACKOFF_MAX_SECONDS,
is_timeout: bool = False,
) -> int:
"""根据连续失败次数计算指数退避时长"""
normalized_failures = max(0, failures)
if is_timeout and normalized_failures >= 3:
return max_delay
exponent = max(0, normalized_failures - 1)
return min(base_delay * (2 ** exponent), max_delay)
def is_otp_timeout_error(error: object) -> bool:
"""识别 OTP 超时类错误码。"""
if error is None:
return False
if isinstance(error, OTPTimeoutEmailServiceError):
return True
error_code = getattr(error, "error_code", "")
if isinstance(error_code, str) and error_code.startswith(OTP_TIMEOUT_ERROR_PREFIX):
return True
return False
def apply_adaptive_backoff(
current_state: Optional[EmailProviderBackoffState],
error: "EmailServiceError",
now: Optional[float] = None,
) -> EmailProviderBackoffState:
"""在限流场景下推进邮箱供应商退避状态"""
state = current_state or EmailProviderBackoffState()
now_ts = now if now is not None else time.time()
next_failures = state.failures + 1
delay_seconds = calculate_adaptive_backoff_delay(
next_failures,
is_timeout=is_otp_timeout_error(error),
)
return EmailProviderBackoffState(
failures=next_failures,
delay_seconds=delay_seconds,
opened_until=now_ts + delay_seconds,
retry_after=getattr(error, "retry_after", None),
last_error=str(error),
)
def reset_adaptive_backoff() -> EmailProviderBackoffState:
"""重置邮箱供应商退避状态"""
return EmailProviderBackoffState()
class EmailServiceError(Exception):
"""邮箱服务异常"""
pass
class RateLimitedEmailServiceError(EmailServiceError):
"""邮箱服务被限流"""
def __init__(self, message: str, retry_after: Optional[int] = None):
super().__init__(message)
self.retry_after = retry_after
class OTPTimeoutEmailServiceError(EmailServiceError):
"""OTP 验证码等待超时。"""
def __init__(self, message: str, error_code: str = OTP_TIMEOUT_ERROR_PREFIX):
super().__init__(message)
self.error_code = error_code
class EmailServiceStatus(Enum):
"""邮箱服务状态"""
HEALTHY = "healthy"
@@ -45,6 +144,7 @@ class BaseEmailService(abc.ABC):
self.name = name or f"{service_type.value}_service"
self._status = EmailServiceStatus.HEALTHY
self._last_error = None
self._provider_backoff = reset_adaptive_backoff()
@property
def status(self) -> EmailServiceStatus:
@@ -56,6 +156,15 @@ class BaseEmailService(abc.ABC):
"""获取最后一次错误信息"""
return self._last_error
@property
def provider_backoff_state(self) -> EmailProviderBackoffState:
"""获取当前邮箱供应商退避状态"""
return self._provider_backoff
def apply_provider_backoff_state(self, state: Optional[EmailProviderBackoffState]) -> None:
"""注入外部持久化的邮箱供应商退避状态"""
self._provider_backoff = state or reset_adaptive_backoff()
@abc.abstractmethod
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
@@ -92,7 +201,7 @@ class BaseEmailService(abc.ABC):
email_id: 邮箱服务中的 ID如果需要
timeout: 超时时间(秒)
pattern: 验证码正则表达式
otp_sent_at: OTP 发送时间戳,用于过滤旧邮件
otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点的邮件
Returns:
验证码字符串,如果超时或未找到返回 None
@@ -282,6 +391,14 @@ class BaseEmailService(abc.ABC):
if success:
self._status = EmailServiceStatus.HEALTHY
self._last_error = None
self._provider_backoff = reset_adaptive_backoff()
else:
if isinstance(error, RateLimitedEmailServiceError) or is_otp_timeout_error(error):
self._status = EmailServiceStatus.UNAVAILABLE
self._provider_backoff = apply_adaptive_backoff(
self._provider_backoff,
error,
)
else:
self._status = EmailServiceStatus.DEGRADED
if error:

View File

@@ -12,7 +12,7 @@ from datetime import datetime, timezone
from html import unescape
from typing import Any, Dict, List, Optional
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..config.constants import OTP_CODE_PATTERN
from ..core.http_client import HTTPClient, RequestConfig
@@ -102,7 +102,19 @@ class DuckMailService(BaseEmailService):
error_message = f"{error_message} - {error_payload}"
except Exception:
error_message = f"{error_message} - {response.text[:200]}"
raise EmailServiceError(error_message)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_message, retry_after=retry_after)
else:
error = EmailServiceError(error_message)
self.update_status(False, error)
raise error
try:
return response.json()

View File

@@ -10,7 +10,7 @@ import random
import string
from typing import Optional, Dict, Any, List
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN
@@ -96,8 +96,19 @@ class FreemailService(BaseEmailService):
error_msg = f"{error_msg} - {error_data}"
except Exception:
error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg))
raise EmailServiceError(error_msg)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
try:
return response.json()

View File

@@ -10,7 +10,7 @@ import logging
from typing import Optional, Dict, Any, List
from urllib.parse import urljoin
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN
@@ -148,8 +148,20 @@ class MeoMailEmailService(BaseEmailService):
except:
error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg))
raise EmailServiceError(error_msg)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
# 解析响应
try:

View File

@@ -3,7 +3,7 @@ Outlook 账户数据类
"""
from dataclasses import dataclass
from typing import Dict, Any
from typing import Dict, Any, Optional
@dataclass
@@ -16,6 +16,7 @@ class OutlookAccount:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount":
"""从配置创建账户"""
return cls(
email=config.get("email", ""),
password=config.get("password", ""),
@@ -24,12 +25,15 @@ class OutlookAccount:
)
def has_oauth(self) -> bool:
"""是否支持 OAuth2"""
return bool(self.client_id and self.refresh_token)
def validate(self) -> bool:
"""验证账户信息是否有效"""
return bool(self.email and self.password) or self.has_oauth()
def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
"""转换为字典"""
result = {
"email": self.email,
"has_oauth": self.has_oauth(),
@@ -43,4 +47,5 @@ class OutlookAccount:
return result
def __str__(self) -> str:
"""字符串表示"""
return f"OutlookAccount({self.email})"

View File

@@ -1,5 +1,6 @@
"""
Outlook 邮箱服务基础定义
Outlook 服务基础定义
包含枚举类型和数据类
"""
from dataclasses import dataclass, field
@@ -9,38 +10,49 @@ from typing import Optional, Dict, Any, List
class ProviderType(str, Enum):
"""Outlook 提供者类型(仅 IMAP_NEW"""
IMAP_NEW = "imap_new"
"""Outlook 提供者类型"""
IMAP_OLD = "imap_old" # 旧版 IMAP (outlook.office365.com)
IMAP_NEW = "imap_new" # 新版 IMAP (outlook.live.com)
GRAPH_API = "graph_api" # Microsoft Graph API
class TokenEndpoint(str, Enum):
"""Token 端点"""
LIVE = "https://login.live.com/oauth20_token.srf"
CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
COMMON = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
class IMAPServer(str, Enum):
"""IMAP 服务器"""
OLD = "outlook.office365.com"
NEW = "outlook.live.com"
class ProviderStatus(str, Enum):
"""提供者状态"""
HEALTHY = "healthy"
DEGRADED = "degraded"
DISABLED = "disabled"
HEALTHY = "healthy" # 健康
DEGRADED = "degraded" # 降级
DISABLED = "disabled" # 禁用
@dataclass
class EmailMessage:
"""邮件消息数据类"""
id: str
subject: str
sender: str
recipients: List[str] = field(default_factory=list)
body: str = ""
body_preview: str = ""
received_at: Optional[datetime] = None
received_timestamp: int = 0
is_read: bool = False
has_attachments: bool = False
raw_data: Optional[bytes] = None
id: str # 消息 ID
subject: str # 主题
sender: str # 发件人
recipients: List[str] = field(default_factory=list) # 收件人列表
body: str = "" # 正文内容
body_preview: str = "" # 正文预览
received_at: Optional[datetime] = None # 接收时间
received_timestamp: int = 0 # 接收时间戳
is_read: bool = False # 是否已读
has_attachments: bool = False # 是否有附件
raw_data: Optional[bytes] = None # 原始数据(用于调试)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"id": self.id,
"subject": self.subject,
@@ -59,17 +71,19 @@ class EmailMessage:
class TokenInfo:
"""Token 信息数据类"""
access_token: str
expires_at: float
expires_at: float # 过期时间戳
token_type: str = "Bearer"
scope: str = ""
refresh_token: Optional[str] = None
def is_expired(self, buffer_seconds: int = 120) -> bool:
"""检查 Token 是否已过期"""
import time
return time.time() >= (self.expires_at - buffer_seconds)
@classmethod
def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo":
"""从 API 响应创建"""
import time
return cls(
access_token=data.get("access_token", ""),
@@ -85,42 +99,49 @@ class ProviderHealth:
"""提供者健康状态"""
provider_type: ProviderType
status: ProviderStatus = ProviderStatus.HEALTHY
failure_count: int = 0
last_success: Optional[datetime] = None
last_failure: Optional[datetime] = None
last_error: str = ""
disabled_until: Optional[datetime] = None
failure_count: int = 0 # 连续失败次数
last_success: Optional[datetime] = None # 最后成功时间
last_failure: Optional[datetime] = None # 最后失败时间
last_error: str = "" # 最后错误信息
disabled_until: Optional[datetime] = None # 禁用截止时间
def record_success(self):
"""记录成功"""
self.status = ProviderStatus.HEALTHY
self.failure_count = 0
self.last_success = datetime.now()
self.disabled_until = None
def record_failure(self, error: str):
"""记录失败"""
self.failure_count += 1
self.last_failure = datetime.now()
self.last_error = error
def should_disable(self, threshold: int = 3) -> bool:
"""判断是否应该禁用"""
return self.failure_count >= threshold
def is_disabled(self) -> bool:
"""检查是否被禁用"""
if self.disabled_until and datetime.now() < self.disabled_until:
return True
return False
def disable(self, duration_seconds: int = 300):
"""禁用提供者"""
from datetime import timedelta
self.status = ProviderStatus.DISABLED
self.disabled_until = datetime.now() + timedelta(seconds=duration_seconds)
def enable(self):
"""启用提供者"""
self.status = ProviderStatus.HEALTHY
self.disabled_until = None
self.failure_count = 0
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"provider_type": self.provider_type.value,
"status": self.status.value,

View File

@@ -1,13 +1,15 @@
"""
健康检查管理(简化版,单 Provider
健康检查和故障切换管理
"""
import logging
import threading
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from typing import Dict, List, Optional, Any
from .base import ProviderHealth, ProviderStatus, ProviderType
from .base import ProviderType, ProviderHealth, ProviderStatus
from .providers.base import OutlookProvider
logger = logging.getLogger(__name__)
@@ -15,49 +17,296 @@ logger = logging.getLogger(__name__)
class HealthChecker:
"""
单 Provider 健康检查器
跟踪 IMAP_NEW 的健康状态
健康检查管理
跟踪各提供者的健康状态,管理故障切换
"""
def __init__(
self,
failure_threshold: int = 3,
disable_duration: int = 300,
recovery_check_interval: int = 60,
):
"""
初始化健康检查器
Args:
failure_threshold: 连续失败次数阈值,超过后禁用
disable_duration: 禁用时长(秒)
recovery_check_interval: 恢复检查间隔(秒)
"""
self.failure_threshold = failure_threshold
self.disable_duration = disable_duration
self._health = ProviderHealth(provider_type=ProviderType.IMAP_NEW)
self.recovery_check_interval = recovery_check_interval
# 提供者健康状态: ProviderType -> ProviderHealth
self._health_status: Dict[ProviderType, ProviderHealth] = {}
self._lock = threading.Lock()
def record_success(self):
with self._lock:
self._health.record_success()
# 初始化所有提供者的健康状态
for provider_type in ProviderType:
self._health_status[provider_type] = ProviderHealth(
provider_type=provider_type
)
def record_failure(self, error: str):
def get_health(self, provider_type: ProviderType) -> ProviderHealth:
"""获取提供者的健康状态"""
with self._lock:
self._health.record_failure(error)
if self._health.should_disable(self.failure_threshold):
self._health.disable(self.disable_duration)
return self._health_status.get(provider_type, ProviderHealth(provider_type=provider_type))
def record_success(self, provider_type: ProviderType):
"""记录成功操作"""
with self._lock:
health = self._health_status.get(provider_type)
if health:
health.record_success()
logger.debug(f"{provider_type.value} 记录成功")
def record_failure(self, provider_type: ProviderType, error: str):
"""记录失败操作"""
with self._lock:
health = self._health_status.get(provider_type)
if health:
health.record_failure(error)
# 检查是否需要禁用
if health.should_disable(self.failure_threshold):
health.disable(self.disable_duration)
logger.warning(
f"IMAP_NEW 已禁用 {self.disable_duration}s原因: {error}"
f"{provider_type.value} 已禁用 {self.disable_duration} 秒,"
f"原因: {error}"
)
def is_available(self) -> bool:
with self._lock:
if self._health.is_disabled():
remaining = (
(self._health.disabled_until - datetime.now()).total_seconds()
if self._health.disabled_until
else 0
def is_available(self, provider_type: ProviderType) -> bool:
"""
检查提供者是否可用
Args:
provider_type: 提供者类型
Returns:
是否可用
"""
health = self.get_health(provider_type)
# 检查是否被禁用
if health.is_disabled():
remaining = (health.disabled_until - datetime.now()).total_seconds()
logger.debug(
f"{provider_type.value} 已被禁用,剩余 {int(remaining)}"
)
logger.debug(f"IMAP_NEW 已被禁用,剩余 {int(remaining)}s")
return False
return self._health.status != ProviderStatus.DISABLED
def reset(self):
return health.status != ProviderStatus.DISABLED
def get_available_providers(
self,
priority_order: Optional[List[ProviderType]] = None,
) -> List[ProviderType]:
"""
获取可用的提供者列表
Args:
priority_order: 优先级顺序,默认为 [IMAP_NEW, IMAP_OLD, GRAPH_API]
Returns:
可用的提供者列表
"""
if priority_order is None:
priority_order = [
ProviderType.IMAP_NEW,
ProviderType.IMAP_OLD,
ProviderType.GRAPH_API,
]
available = []
for provider_type in priority_order:
if self.is_available(provider_type):
available.append(provider_type)
return available
def get_next_available_provider(
self,
priority_order: Optional[List[ProviderType]] = None,
) -> Optional[ProviderType]:
"""
获取下一个可用的提供者
Args:
priority_order: 优先级顺序
Returns:
可用的提供者类型,如果没有返回 None
"""
available = self.get_available_providers(priority_order)
return available[0] if available else None
def force_disable(self, provider_type: ProviderType, duration: Optional[int] = None):
"""
强制禁用提供者
Args:
provider_type: 提供者类型
duration: 禁用时长(秒),默认使用配置值
"""
with self._lock:
self._health = ProviderHealth(provider_type=ProviderType.IMAP_NEW)
health = self._health_status.get(provider_type)
if health:
health.disable(duration or self.disable_duration)
logger.warning(f"{provider_type.value} 已强制禁用")
def force_enable(self, provider_type: ProviderType):
"""
强制启用提供者
Args:
provider_type: 提供者类型
"""
with self._lock:
health = self._health_status.get(provider_type)
if health:
health.enable()
logger.info(f"{provider_type.value} 已启用")
def get_all_health_status(self) -> Dict[str, Any]:
"""
获取所有提供者的健康状态
Returns:
健康状态字典
"""
with self._lock:
return {
provider_type.value: health.to_dict()
for provider_type, health in self._health_status.items()
}
def check_and_recover(self):
"""
检查并恢复被禁用的提供者
如果禁用时间已过,自动恢复提供者
"""
with self._lock:
for provider_type, health in self._health_status.items():
if health.is_disabled():
# 检查是否可以恢复
if health.disabled_until and datetime.now() >= health.disabled_until:
health.enable()
logger.info(f"{provider_type.value} 已自动恢复")
def reset_all(self):
"""重置所有提供者的健康状态"""
with self._lock:
for provider_type in ProviderType:
self._health_status[provider_type] = ProviderHealth(
provider_type=provider_type
)
logger.info("已重置所有提供者的健康状态")
class FailoverManager:
"""
故障切换管理器
管理提供者之间的自动切换
"""
def __init__(
self,
health_checker: HealthChecker,
priority_order: Optional[List[ProviderType]] = None,
):
"""
初始化故障切换管理器
Args:
health_checker: 健康检查器
priority_order: 提供者优先级顺序
"""
self.health_checker = health_checker
self.priority_order = priority_order or [
ProviderType.IMAP_NEW,
ProviderType.IMAP_OLD,
ProviderType.GRAPH_API,
]
# 当前使用的提供者索引
self._current_index = 0
self._lock = threading.Lock()
def get_current_provider(self) -> Optional[ProviderType]:
"""
获取当前提供者
Returns:
当前提供者类型,如果没有可用的返回 None
"""
available = self.health_checker.get_available_providers(self.priority_order)
if not available:
return None
with self._lock:
# 尝试使用当前索引
if self._current_index < len(available):
return available[self._current_index]
return available[0]
def switch_to_next(self) -> Optional[ProviderType]:
"""
切换到下一个提供者
Returns:
下一个提供者类型,如果没有可用的返回 None
"""
available = self.health_checker.get_available_providers(self.priority_order)
if not available:
return None
with self._lock:
self._current_index = (self._current_index + 1) % len(available)
next_provider = available[self._current_index]
logger.info(f"切换到提供者: {next_provider.value}")
return next_provider
def on_provider_success(self, provider_type: ProviderType):
"""
提供者成功时调用
Args:
provider_type: 提供者类型
"""
self.health_checker.record_success(provider_type)
# 重置索引到成功的提供者
with self._lock:
available = self.health_checker.get_available_providers(self.priority_order)
if provider_type in available:
self._current_index = available.index(provider_type)
def on_provider_failure(self, provider_type: ProviderType, error: str):
"""
提供者失败时调用
Args:
provider_type: 提供者类型
error: 错误信息
"""
self.health_checker.record_failure(provider_type, error)
def get_status(self) -> Dict[str, Any]:
with self._lock:
return self._health.to_dict()
"""
获取故障切换状态
Returns:
状态字典
"""
current = self.get_current_provider()
return {
"current_provider": current.value if current else None,
"priority_order": [p.value for p in self.priority_order],
"available_providers": [
p.value for p in self.health_checker.get_available_providers(self.priority_order)
],
"health_status": self.health_checker.get_all_health_status(),
}

View File

@@ -3,18 +3,27 @@ Outlook 提供者模块
"""
from .base import OutlookProvider, ProviderConfig
from .imap_old import IMAPOldProvider
from .imap_new import IMAPNewProvider
from .graph_api import GraphAPIProvider
__all__ = [
'OutlookProvider',
'ProviderConfig',
'IMAPOldProvider',
'IMAPNewProvider',
'GraphAPIProvider',
]
# 提供者注册表
PROVIDER_REGISTRY = {
'imap_old': IMAPOldProvider,
'imap_new': IMAPNewProvider,
'graph_api': GraphAPIProvider,
}
def get_provider_class(provider_type: str):
"""获取提供者类"""
return PROVIDER_REGISTRY.get(provider_type)

View File

@@ -5,7 +5,7 @@ Outlook 提供者抽象基类
import abc
import logging
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, Any, List, Optional
from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus
from ..account import OutlookAccount
@@ -18,36 +18,56 @@ logger = logging.getLogger(__name__)
class ProviderConfig:
"""提供者配置"""
timeout: int = 30
max_retries: int = 3
proxy_url: Optional[str] = None
service_id: Optional[int] = None
# 健康检查配置
health_failure_threshold: int = 3
health_disable_duration: int = 300
health_disable_duration: int = 300 # 秒
class OutlookProvider(abc.ABC):
"""Outlook 提供者抽象基类"""
"""
Outlook 提供者抽象基类
定义所有提供者必须实现的接口
"""
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
"""
初始化提供者
Args:
account: Outlook 账户
config: 提供者配置
"""
self.account = account
self.config = config or ProviderConfig()
self._health = ProviderHealth(provider_type=ProviderType.IMAP_NEW)
# 健康状态
self._health = ProviderHealth(provider_type=self.provider_type)
# 连接状态
self._connected = False
self._last_error: Optional[str] = None
@property
@abc.abstractmethod
def provider_type(self) -> ProviderType:
return ProviderType.IMAP_NEW
"""获取提供者类型"""
pass
@property
def health(self) -> ProviderHealth:
"""获取健康状态"""
return self._health
@property
def is_healthy(self) -> bool:
"""检查是否健康"""
return (
self._health.status == ProviderStatus.HEALTHY
and not self._health.is_disabled()
@@ -55,14 +75,22 @@ class OutlookProvider(abc.ABC):
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected
@abc.abstractmethod
def connect(self) -> bool:
"""
连接到服务
Returns:
是否连接成功
"""
pass
@abc.abstractmethod
def disconnect(self):
"""断开连接"""
pass
@abc.abstractmethod
@@ -71,44 +99,81 @@ class OutlookProvider(abc.ABC):
count: int = 20,
only_unseen: bool = True,
) -> List[EmailMessage]:
"""
获取最近的邮件
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
pass
@abc.abstractmethod
def test_connection(self) -> bool:
"""
测试连接是否正常
Returns:
连接是否正常
"""
pass
def wait_for_new_email_idle(self, timeout: int = 25) -> bool:
"""IMAP IDLE默认不支持子类可覆盖"""
return False
def record_success(self):
"""记录成功操作"""
self._health.record_success()
self._last_error = None
logger.debug(f"[{self.account.email}] {self.provider_type.value} 操作成功")
def record_failure(self, error: str):
"""记录失败操作"""
self._health.record_failure(error)
self._last_error = error
# 检查是否需要禁用
if self._health.should_disable(self.config.health_failure_threshold):
self._health.disable(self.config.health_disable_duration)
logger.warning(
f"[{self.account.email}] IMAP_NEW 已禁用 "
f"{self.config.health_disable_duration}s,原因: {error}"
f"[{self.account.email}] {self.provider_type.value} 已禁用 "
f"{self.config.health_disable_duration},原因: {error}"
)
else:
logger.warning(
f"[{self.account.email}] {self.provider_type.value} 操作失败 "
f"({self._health.failure_count}/{self.config.health_failure_threshold}): {error}"
)
def check_health(self) -> bool:
"""
检查健康状态
Returns:
是否健康可用
"""
# 检查是否被禁用
if self._health.is_disabled():
logger.debug(
f"[{self.account.email}] {self.provider_type.value} 已被禁用,"
f"将在 {self._health.disabled_until} 后恢复"
)
return False
return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED)
def __enter__(self):
"""上下文管理器入口"""
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
self.disconnect()
return False
def __str__(self) -> str:
"""字符串表示"""
return f"{self.__class__.__name__}({self.account.email})"
def __repr__(self) -> str:

View File

@@ -0,0 +1,250 @@
"""
Graph API 提供者
使用 Microsoft Graph REST API
"""
import json
import logging
from typing import List, Optional
from datetime import datetime
from curl_cffi import requests as _requests
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
logger = logging.getLogger(__name__)
class GraphAPIProvider(OutlookProvider):
"""
Graph API 提供者
使用 Microsoft Graph REST API 获取邮件
需要 graph.microsoft.com/.default scope
"""
# Graph API 端点
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
MESSAGES_ENDPOINT = "/me/mailFolders/inbox/messages"
@property
def provider_type(self) -> ProviderType:
return ProviderType.GRAPH_API
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
super().__init__(account, config)
# Token 管理器
self._token_manager: Optional[TokenManager] = None
# 注意Graph API 必须使用 OAuth2
if not account.has_oauth():
logger.warning(
f"[{self.account.email}] Graph API 提供者需要 OAuth2 配置 "
f"(client_id + refresh_token)"
)
def connect(self) -> bool:
"""
验证连接(获取 Token
Returns:
是否连接成功
"""
if not self.account.has_oauth():
error = "Graph API 需要 OAuth2 配置"
self.record_failure(error)
logger.error(f"[{self.account.email}] {error}")
return False
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.GRAPH_API,
self.config.proxy_url,
self.config.timeout,
)
# 尝试获取 Token
token = self._token_manager.get_access_token()
if token:
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] Graph API 连接成功")
return True
return False
def disconnect(self):
"""断开连接(清除状态)"""
self._connected = False
def get_recent_emails(
self,
count: int = 20,
only_unseen: bool = True,
) -> List[EmailMessage]:
"""
获取最近的邮件
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
if not self._connected:
if not self.connect():
return []
try:
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
self.record_failure("无法获取 Access Token")
return []
# 构建 API 请求
url = f"{self.GRAPH_API_BASE}{self.MESSAGES_ENDPOINT}"
params = {
"$top": count,
"$select": "id,subject,from,toRecipients,receivedDateTime,isRead,hasAttachments,bodyPreview,body",
"$orderby": "receivedDateTime desc",
}
# 只获取未读邮件
if only_unseen:
params["$filter"] = "isRead eq false"
# 构建代理配置
proxies = None
if self.config.proxy_url:
proxies = {"http": self.config.proxy_url, "https": self.config.proxy_url}
# 发送请求curl_cffi 自动对 params 进行 URL 编码)
resp = _requests.get(
url,
params=params,
headers={
"Authorization": f"Bearer {token}",
"Accept": "application/json",
"Prefer": "outlook.body-content-type='text'",
},
proxies=proxies,
timeout=self.config.timeout,
impersonate="chrome110",
)
if resp.status_code == 401:
# Token 无 Graph 权限client_id 未授权),清除缓存但不记录健康失败
# 避免因权限不足导致健康检查器禁用该提供者,影响其他账户
if self._token_manager:
self._token_manager.clear_cache()
self._connected = False
logger.warning(f"[{self.account.email}] Graph API 返回 401client_id 可能无 Graph 权限,跳过")
return []
if resp.status_code != 200:
error_body = resp.text[:200]
self.record_failure(f"HTTP {resp.status_code}: {error_body}")
logger.error(f"[{self.account.email}] Graph API 请求失败: HTTP {resp.status_code}")
return []
data = resp.json()
# 解析邮件
messages = data.get("value", [])
emails = []
for msg in messages:
try:
email_msg = self._parse_graph_message(msg)
if email_msg:
emails.append(email_msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析 Graph API 邮件失败: {e}")
self.record_success()
return emails
except Exception as e:
self.record_failure(str(e))
logger.error(f"[{self.account.email}] Graph API 获取邮件失败: {e}")
return []
def _parse_graph_message(self, msg: dict) -> Optional[EmailMessage]:
"""
解析 Graph API 消息
Args:
msg: Graph API 消息对象
Returns:
EmailMessage 对象
"""
# 解析发件人
from_info = msg.get("from", {})
sender_info = from_info.get("emailAddress", {})
sender = sender_info.get("address", "")
# 解析收件人
recipients = []
for recipient in msg.get("toRecipients", []):
addr_info = recipient.get("emailAddress", {})
addr = addr_info.get("address", "")
if addr:
recipients.append(addr)
# 解析日期
received_at = None
received_timestamp = 0
try:
date_str = msg.get("receivedDateTime", "")
if date_str:
# ISO 8601 格式
received_at = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
received_timestamp = int(received_at.timestamp())
except Exception:
pass
# 获取正文
body_info = msg.get("body", {})
body = body_info.get("content", "")
body_preview = msg.get("bodyPreview", "")
return EmailMessage(
id=msg.get("id", ""),
subject=msg.get("subject", ""),
sender=sender,
recipients=recipients,
body=body,
body_preview=body_preview,
received_at=received_at,
received_timestamp=received_timestamp,
is_read=msg.get("isRead", False),
has_attachments=msg.get("hasAttachments", False),
)
def test_connection(self) -> bool:
"""
测试 Graph API 连接
Returns:
连接是否正常
"""
try:
# 尝试获取一封邮件来测试连接
emails = self.get_recent_emails(count=1, only_unseen=False)
return True
except Exception as e:
logger.warning(f"[{self.account.email}] Graph API 连接测试失败: {e}")
return False

View File

@@ -1,261 +1,200 @@
"""
新版 IMAP 提供者
使用 outlook.live.com:993 + consumers Token 端点
引入进程级 IMAPConnectionPool 连接复用和 IMAP IDLE
使用 outlook.live.com 服务器和 login.microsoftonline.com/consumers Token 端点
"""
import email
import imaplib
import logging
import select
import time
import threading
from datetime import datetime, timedelta, timezone
from email.header import decode_header
from email.utils import parsedate_to_datetime
from typing import Dict, List, Optional
from typing import List, Optional
from ..base import EmailMessage
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
from .imap_old import IMAPOldProvider
logger = logging.getLogger(__name__)
class IMAPConnectionPool:
"""进程级 IMAP 连接池,按 email 复用 IMAP4_SSL 连接"""
IMAP_HOST = "outlook.live.com"
IMAP_PORT = 993
def __init__(self):
self._connections: Dict[str, imaplib.IMAP4_SSL] = {}
self._lock = threading.Lock()
def get_connection(
self,
email_addr: str,
token: str,
timeout: int = 30,
) -> imaplib.IMAP4_SSL:
"""获取或新建 IMAP 连接"""
# 先在锁内检查现有连接
with self._lock:
conn = self._connections.get(email_addr)
if conn:
try:
conn.noop()
return conn
except Exception:
self._close_one(email_addr)
# 标记为「建连中」,防止重复建连
self._connections[email_addr] = None
# 锁外建立新连接(耗时操作不持锁)
try:
new_conn = imaplib.IMAP4_SSL(self.IMAP_HOST, self.IMAP_PORT, timeout=timeout)
auth_str = f"user={email_addr}\x01auth=Bearer {token}\x01\x01"
new_conn.authenticate("XOAUTH2", lambda _: auth_str.encode("utf-8"))
except Exception:
with self._lock:
# 建连失败,清除占位
if self._connections.get(email_addr) is None:
del self._connections[email_addr]
raise
with self._lock:
self._connections[email_addr] = new_conn
logger.debug(f"[{email_addr}] IMAP 新连接已建立")
return new_conn
def invalidate(self, email_addr: str):
"""废弃连接(认证失败或连接异常时调用)"""
with self._lock:
self._close_one(email_addr)
def _close_one(self, email_addr: str):
conn = self._connections.pop(email_addr, None)
if conn:
try:
conn.logout()
except Exception:
pass
# 模块级单例连接池
_imap_pool = IMAPConnectionPool()
class IMAPNewProvider(OutlookProvider):
"""
新版 IMAP 提供者
通过连接池复用连接,支持 IMAP IDLE
使用 outlook.live.com:993 和 login.microsoftonline.com/consumers Token 端点
需要 IMAP.AccessAsUser.All scope
"""
# IMAP 服务器配置
IMAP_HOST = "outlook.live.com"
IMAP_PORT = 993
@property
def provider_type(self) -> ProviderType:
return ProviderType.IMAP_NEW
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
super().__init__(account, config)
self._conn: Optional[imaplib.IMAP4_SSL] = None
self._token_manager: Optional[TokenManager] = None
self._idle_tag_counter = 0
self._idle_tag_lock = threading.Lock()
# IMAP 连接
self._conn: Optional[imaplib.IMAP4_SSL] = None
# Token 管理器
self._token_manager: Optional[TokenManager] = None
# 注意:新版 IMAP 必须使用 OAuth2
if not account.has_oauth():
logger.warning(
f"[{self.account.email}] IMAP_NEW 需要 OAuth2 配置 (client_id + refresh_token)"
f"[{self.account.email}] 新版 IMAP 提供者需要 OAuth2 配置 "
f"(client_id + refresh_token)"
)
def _get_token_manager(self) -> TokenManager:
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
proxy_url=self.config.proxy_url,
timeout=self.config.timeout,
service_id=self.config.service_id,
)
return self._token_manager
def _next_idle_tag(self) -> str:
"""生成唯一 IDLE tag避免使用私有 _new_tag"""
with self._idle_tag_lock:
self._idle_tag_counter += 1
return f"IDLE{self._idle_tag_counter:04d}"
def connect(self) -> bool:
"""从连接池获取连接"""
"""
连接到 IMAP 服务器
Returns:
是否连接成功
"""
if self._connected and self._conn:
try:
self._conn.noop()
return True
except Exception:
self.disconnect()
# 新版 IMAP 必须使用 OAuth2无 OAuth 时静默跳过,不记录健康失败
if not self.account.has_oauth():
logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW无 OAuth")
return False
try:
tm = self._get_token_manager()
token = tm.get_access_token()
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
# 创建连接
self._conn = imaplib.IMAP4_SSL(
self.IMAP_HOST,
self.IMAP_PORT,
timeout=self.config.timeout,
)
# XOAUTH2 认证
if self._authenticate_xoauth2():
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] 新版 IMAP 连接成功 (XOAUTH2)")
return True
return False
except Exception as e:
self.disconnect()
self.record_failure(str(e))
logger.error(f"[{self.account.email}] 新版 IMAP 连接失败: {e}")
return False
def _authenticate_xoauth2(self) -> bool:
"""
使用 XOAUTH2 认证
Returns:
是否认证成功
"""
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.IMAP_NEW,
self.config.proxy_url,
self.config.timeout,
)
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
logger.error(f"[{self.account.email}] 获取 IMAP Token 失败")
return False
self._conn = _imap_pool.get_connection(
self.account.email, token, self.config.timeout
)
self._connected = True
self.record_success()
logger.debug(f"[{self.account.email}] IMAP 连接就绪(连接池)")
return True
except imaplib.IMAP4.error as e:
err = str(e)
# Token 失效时强制刷新并重试一次
if "AUTHENTICATE" in err or "invalid" in err.lower():
logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试刷新 Token")
_imap_pool.invalidate(self.account.email)
try:
tm = self._get_token_manager()
token = tm.get_access_token(force_refresh=True)
if token:
self._conn = _imap_pool.get_connection(
self.account.email, token, self.config.timeout
)
self._connected = True
self.record_success()
# 构建 XOAUTH2 认证字符串
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
return True
except Exception as retry_e:
self.record_failure(str(retry_e))
logger.error(f"[{self.account.email}] Token 刷新后重连失败: {retry_e}")
else:
self.record_failure(err)
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
self._connected = False
self._conn = None
return False
except Exception as e:
self.record_failure(str(e))
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
self._connected = False
self._conn = None
logger.error(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
# 清除缓存的 Token
self._token_manager.clear_cache()
return False
def disconnect(self):
"""归还连接池(不 logout保持复用"""
self._connected = False
"""断开 IMAP 连接"""
if self._conn:
try:
self._conn.close()
except Exception:
pass
try:
self._conn.logout()
except Exception:
pass
self._conn = None
self._connected = False
def get_recent_emails(
self,
count: int = 20,
only_unseen: bool = True,
since_minutes: Optional[int] = None,
folders: Optional[List[str]] = None,
) -> List[EmailMessage]:
"""
获取最近的邮件,支持多文件夹搜索(合并去重)。
获取最近的邮件
搜索策略:
- since_minutes 指定时:用 SINCE 日期 + ALL 搜索最近N分钟内的邮件不受已读/未读限制)
- only_unseen=True 且未指定 since_minutes搜索 UNSEEN
- only_unseen=False 且未指定 since_minutes搜索全部取最近 count 封)
- folders 默认为 ["INBOX"],可传入多个文件夹(如 ["INBOX", "Junk Email"]
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
if not self._connected:
if not self.connect():
return []
if folders is None:
folders = ["INBOX"]
all_emails: List[EmailMessage] = []
seen_ids: set = set()
for folder in folders:
try:
status, _ = self._conn.select(folder, readonly=True)
if status != "OK":
logger.debug(f"[{self.account.email}] 文件夹 {folder} 不存在或无法访问,跳过")
continue
# 选择收件箱
self._conn.select("INBOX", readonly=True)
if since_minutes is not None:
since_dt = datetime.now(timezone.utc) - timedelta(minutes=since_minutes)
since_str = since_dt.strftime("%d-%b-%Y")
status, data = self._conn.search(None, f"SINCE {since_str}")
elif only_unseen:
status, data = self._conn.search(None, "UNSEEN")
else:
status, data = self._conn.search(None, "ALL")
# 搜索邮件
flag = "UNSEEN" if only_unseen else "ALL"
status, data = self._conn.search(None, flag)
if status != "OK" or not data or not data[0]:
continue
return []
# 获取最新的邮件 ID
ids = data[0].split()
recent_ids = ids[-count:][::-1] # 取最新的 count 封,倒序(最新在前)
recent_ids = ids[-count:][::-1]
emails = []
for msg_id in recent_ids:
try:
msg = self._fetch_email(msg_id)
if msg and msg.id not in seen_ids:
seen_ids.add(msg.id)
all_emails.append(msg)
email_msg = self._fetch_email(msg_id)
if email_msg:
emails.append(email_msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}, folder: {folder}): {e}")
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
return emails
except Exception as e:
self.record_failure(str(e))
logger.warning(f"[{self.account.email}] 搜索文件夹 {folder} 失败: {e}")
_imap_pool.invalidate(self.account.email)
self._connected = False
self._conn = None
break
# 按收信时间降序排列,截取 count 封
all_emails.sort(key=lambda m: m.received_timestamp, reverse=True)
return all_emails[:count]
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
return []
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
"""获取并解析单封邮件"""
@@ -272,193 +211,21 @@ class IMAPNewProvider(OutlookProvider):
if not raw:
return None
return _parse_email(raw)
return self._parse_email(raw)
def wait_for_new_email_idle(self, timeout: int = 25) -> bool:
"""
RFC 2177 IMAP IDLE 实现。
发送 IDLE 命令,等待服务器推送 EXISTS/RECENT然后发送 DONE。
Returns True 表示有新邮件推送False 表示超时或异常(调用方降级轮询)。
"""
if not self._connected:
if not self.connect():
return False
try:
self._conn.select("INBOX", readonly=True)
except Exception as e:
logger.warning(f"[{self.account.email}] IDLE 前 SELECT 失败: {e}")
return False
tag = self._next_idle_tag()
sock = self._conn.socket()
logger.info(f"[{self.account.email}] 进入 IMAP IDLE 等待模式(超时 {timeout}stag={tag}")
try:
# 发送 IDLE 命令
self._conn.send(f"{tag} IDLE\r\n".encode())
# 等待 "+" 延续响应(服务端确认进入 IDLE
deadline = time.time() + min(10.0, timeout)
buf = b""
got_continuation = False
while time.time() < deadline:
ready = select.select([sock], [], [], min(2.0, deadline - time.time()))
if ready[0]:
chunk = sock.recv(4096)
if not chunk:
break
buf += chunk
if b"+ " in buf or b"+\r\n" in buf:
got_continuation = True
break
if not got_continuation:
logger.warning(f"[{self.account.email}] 未收到 IDLE 延续响应,放弃")
return False
# 等待 EXISTS / RECENT 推送
got_new = False
buf = b""
deadline = time.time() + timeout
while time.time() < deadline:
remaining = deadline - time.time()
if remaining <= 0:
break
ready = select.select([sock], [], [], min(2.0, remaining))
if ready[0]:
chunk = sock.recv(4096)
if not chunk:
break
buf += chunk
if b"EXISTS" in buf or b"RECENT" in buf:
logger.debug(f"[{self.account.email}] IDLE 收到新邮件推送")
got_new = True
break
return got_new
except Exception as e:
logger.warning(f"[{self.account.email}] IMAP IDLE 异常: {e}")
return False
finally:
# 发送 DONE 结束 IDLE并排空服务端响应
try:
self._conn.send(b"DONE\r\n")
drain_deadline = time.time() + 5
drain_buf = b""
tag_end = f"{tag} OK".encode()
tag_no = f"{tag} NO".encode()
tag_bad = f"{tag} BAD".encode()
while time.time() < drain_deadline:
ready = select.select([sock], [], [], 1.0)
if not ready[0]:
break
chunk = sock.recv(4096)
if not chunk:
break
drain_buf += chunk
if any(t in drain_buf for t in (tag_end, tag_no, tag_bad)):
break
except Exception:
_imap_pool.invalidate(self.account.email)
self._connected = False
self._conn = None
@staticmethod
def _parse_email(raw: bytes) -> EmailMessage:
"""解析原始邮件"""
# 使用旧版提供者的解析方法
return IMAPOldProvider._parse_email(raw)
def test_connection(self) -> bool:
"""测试 IMAP 连接"""
try:
with self:
self._conn.select("INBOX", readonly=True)
self._conn.search(None, "ALL")
return True
except Exception as e:
logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}")
logger.warning(f"[{self.account.email}] 新版 IMAP 连接测试失败: {e}")
return False
def _parse_email(raw: bytes) -> EmailMessage:
"""解析原始邮件为 EmailMessage优先 text/plain次选 text/html"""
msg = email.message_from_bytes(raw)
def _decode(val):
if not val:
return ""
parts = decode_header(str(val))
result = ""
for part, charset in parts:
if isinstance(part, bytes):
try:
result += part.decode(charset or "utf-8", errors="replace")
except (LookupError, UnicodeDecodeError):
result += part.decode("utf-8", errors="replace")
else:
result += str(part)
return result
subject = _decode(msg.get("Subject", ""))
sender = _decode(msg.get("From", ""))
recipients = [_decode(msg.get("To", ""))]
received_at = None
received_ts = 0
date_str = msg.get("Date", "")
if date_str:
try:
received_at = parsedate_to_datetime(date_str)
received_ts = int(received_at.timestamp())
except Exception:
pass
# 提取正文:优先 text/plain次选 text/html
plain_body = ""
html_body = ""
if msg.is_multipart():
for part in msg.walk():
ct = part.get_content_type()
cd = str(part.get("Content-Disposition", ""))
if "attachment" in cd.lower():
continue
try:
charset = part.get_content_charset() or "utf-8"
payload = part.get_payload(decode=True)
if not payload:
continue
decoded = payload.decode(charset, errors="replace")
if ct == "text/plain" and not plain_body:
plain_body = decoded
elif ct == "text/html" and not html_body:
html_body = decoded
except Exception:
pass
else:
try:
charset = msg.get_content_charset() or "utf-8"
payload = msg.get_payload(decode=True)
if payload:
ct = msg.get_content_type()
decoded = payload.decode(charset, errors="replace")
if ct == "text/plain":
plain_body = decoded
else:
html_body = decoded
except Exception:
pass
body = plain_body or html_body
body_preview = body[:200].strip()
msg_id = msg.get("Message-ID", "").strip("<>")
if not msg_id:
msg_id = f"{sender}_{received_ts}"
return EmailMessage(
id=msg_id,
subject=subject,
sender=sender,
recipients=recipients,
body=body,
body_preview=body_preview,
received_at=received_at,
received_timestamp=received_ts,
)

View File

@@ -0,0 +1,345 @@
"""
旧版 IMAP 提供者
使用 outlook.office365.com 服务器和 login.live.com Token 端点
"""
import email
import imaplib
import logging
from email.header import decode_header
from email.utils import parsedate_to_datetime
from typing import List, Optional
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
logger = logging.getLogger(__name__)
class IMAPOldProvider(OutlookProvider):
"""
旧版 IMAP 提供者
使用 outlook.office365.com:993 和 login.live.com Token 端点
"""
# IMAP 服务器配置
IMAP_HOST = "outlook.office365.com"
IMAP_PORT = 993
@property
def provider_type(self) -> ProviderType:
return ProviderType.IMAP_OLD
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
super().__init__(account, config)
# IMAP 连接
self._conn: Optional[imaplib.IMAP4_SSL] = None
# Token 管理器
self._token_manager: Optional[TokenManager] = None
def connect(self) -> bool:
"""
连接到 IMAP 服务器
Returns:
是否连接成功
"""
if self._connected and self._conn:
# 检查现有连接
try:
self._conn.noop()
return True
except Exception:
self.disconnect()
try:
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
# 创建连接
self._conn = imaplib.IMAP4_SSL(
self.IMAP_HOST,
self.IMAP_PORT,
timeout=self.config.timeout,
)
# 尝试 XOAUTH2 认证
if self.account.has_oauth():
if self._authenticate_xoauth2():
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] IMAP 连接成功 (XOAUTH2)")
return True
else:
logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试密码认证")
# 密码认证
if self.account.password:
self._conn.login(self.account.email, self.account.password)
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] IMAP 连接成功 (密码认证)")
return True
raise ValueError("没有可用的认证方式")
except Exception as e:
self.disconnect()
self.record_failure(str(e))
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
return False
def _authenticate_xoauth2(self) -> bool:
"""
使用 XOAUTH2 认证
Returns:
是否认证成功
"""
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.IMAP_OLD,
self.config.proxy_url,
self.config.timeout,
)
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
return False
try:
# 构建 XOAUTH2 认证字符串
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
return True
except Exception as e:
logger.debug(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
# 清除缓存的 Token
self._token_manager.clear_cache()
return False
def disconnect(self):
"""断开 IMAP 连接"""
if self._conn:
try:
self._conn.close()
except Exception:
pass
try:
self._conn.logout()
except Exception:
pass
self._conn = None
self._connected = False
def get_recent_emails(
self,
count: int = 20,
only_unseen: bool = True,
) -> List[EmailMessage]:
"""
获取最近的邮件
Args:
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
if not self._connected:
if not self.connect():
return []
try:
# 选择收件箱
self._conn.select("INBOX", readonly=True)
# 搜索邮件
flag = "UNSEEN" if only_unseen else "ALL"
status, data = self._conn.search(None, flag)
if status != "OK" or not data or not data[0]:
return []
# 获取最新的邮件 ID
ids = data[0].split()
recent_ids = ids[-count:][::-1] # 倒序,最新的在前
emails = []
for msg_id in recent_ids:
try:
email_msg = self._fetch_email(msg_id)
if email_msg:
emails.append(email_msg)
except Exception as e:
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
return emails
except Exception as e:
self.record_failure(str(e))
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
return []
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
"""
获取并解析单封邮件
Args:
msg_id: 邮件 ID
Returns:
EmailMessage 对象,失败返回 None
"""
status, data = self._conn.fetch(msg_id, "(RFC822)")
if status != "OK" or not data or not data[0]:
return None
# 获取原始邮件内容
raw = b""
for part in data:
if isinstance(part, tuple) and len(part) > 1:
raw = part[1]
break
if not raw:
return None
return self._parse_email(raw)
@staticmethod
def _parse_email(raw: bytes) -> EmailMessage:
"""
解析原始邮件
Args:
raw: 原始邮件数据
Returns:
EmailMessage 对象
"""
# 移除 BOM
if raw.startswith(b"\xef\xbb\xbf"):
raw = raw[3:]
msg = email.message_from_bytes(raw)
# 解析邮件头
subject = IMAPOldProvider._decode_header(msg.get("Subject", ""))
sender = IMAPOldProvider._decode_header(msg.get("From", ""))
to = IMAPOldProvider._decode_header(msg.get("To", ""))
delivered_to = IMAPOldProvider._decode_header(msg.get("Delivered-To", ""))
x_original_to = IMAPOldProvider._decode_header(msg.get("X-Original-To", ""))
date_str = IMAPOldProvider._decode_header(msg.get("Date", ""))
# 提取正文
body = IMAPOldProvider._extract_body(msg)
# 解析日期
received_timestamp = 0
received_at = None
try:
if date_str:
received_at = parsedate_to_datetime(date_str)
received_timestamp = int(received_at.timestamp())
except Exception:
pass
# 构建收件人列表
recipients = [r for r in [to, delivered_to, x_original_to] if r]
return EmailMessage(
id=msg.get("Message-ID", ""),
subject=subject,
sender=sender,
recipients=recipients,
body=body,
received_at=received_at,
received_timestamp=received_timestamp,
is_read=False, # 搜索的是未读邮件
raw_data=raw[:500] if len(raw) > 500 else raw,
)
@staticmethod
def _decode_header(header: str) -> str:
"""解码邮件头"""
if not header:
return ""
parts = []
for chunk, encoding in decode_header(header):
if isinstance(chunk, bytes):
try:
decoded = chunk.decode(encoding or "utf-8", errors="replace")
parts.append(decoded)
except Exception:
parts.append(chunk.decode("utf-8", errors="replace"))
else:
parts.append(str(chunk))
return "".join(parts).strip()
@staticmethod
def _extract_body(msg) -> str:
"""提取邮件正文"""
import html as html_module
import re
texts = []
parts = msg.walk() if msg.is_multipart() else [msg]
for part in parts:
content_type = part.get_content_type()
if content_type not in ("text/plain", "text/html"):
continue
payload = part.get_payload(decode=True)
if not payload:
continue
charset = part.get_content_charset() or "utf-8"
try:
text = payload.decode(charset, errors="replace")
except LookupError:
text = payload.decode("utf-8", errors="replace")
# 如果是 HTML移除标签
if "<html" in text.lower():
text = re.sub(r"<[^>]+>", " ", text)
texts.append(text)
# 合并并清理文本
combined = " ".join(texts)
combined = html_module.unescape(combined)
combined = re.sub(r"\s+", " ", combined).strip()
return combined
def test_connection(self) -> bool:
"""
测试 IMAP 连接
Returns:
连接是否正常
"""
try:
with self:
self._conn.select("INBOX", readonly=True)
self._conn.search(None, "ALL")
return True
except Exception as e:
logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}")
return False

View File

@@ -1,6 +1,6 @@
"""
Outlook 邮箱服务主类(简化版)
单一 IMAP_NEW Provider + 邮件缓存 + IMAP IDLE 支持
Outlook 邮箱服务主类
支持多种 IMAP/API 连接方式,自动故障切换
"""
import logging
@@ -8,24 +8,34 @@ import threading
import time
from typing import Optional, Dict, Any, List
from ..base import BaseEmailService, EmailServiceError, EmailServiceType
from ..base import BaseEmailService, EmailServiceError, EmailServiceStatus, EmailServiceType
from ...config.constants import EmailServiceType as ServiceType
from ...config.settings import get_settings
from .account import OutlookAccount
from .base import EmailMessage
from .email_parser import get_email_parser
from .health_checker import HealthChecker
from .providers.base import ProviderConfig
from .base import ProviderType, EmailMessage
from .email_parser import EmailParser, get_email_parser
from .health_checker import HealthChecker, FailoverManager
from .providers.base import OutlookProvider, ProviderConfig
from .providers.imap_old import IMAPOldProvider
from .providers.imap_new import IMAPNewProvider
from .providers.graph_api import GraphAPIProvider
logger = logging.getLogger(__name__)
# 验证码搜索的文件夹列表(同时搜索收件箱和垃圾箱)
_OUTLOOK_SEARCH_FOLDERS = ["INBOX", "Junk Email"]
# 默认提供者优先级
# IMAP_OLD 最兼容(只需 login.live.com tokenIMAP_NEW 次之Graph API 最后
# 原因:部分 client_id 没有 Graph API 权限,但有 IMAP 权限
DEFAULT_PROVIDER_PRIORITY = [
ProviderType.IMAP_OLD,
ProviderType.IMAP_NEW,
ProviderType.GRAPH_API,
]
def _get_code_settings() -> dict:
def get_email_code_settings() -> dict:
"""获取验证码等待配置"""
settings = get_settings()
return {
"timeout": settings.email_code_timeout,
@@ -33,58 +43,56 @@ def _get_code_settings() -> dict:
}
class _EmailCache:
"""轻量级邮件内存缓存TTL=60s减少重复 IMAP 请求)"""
TTL = 60
def __init__(self):
self._cache: Dict[str, tuple] = {} # email -> (timestamp, List[EmailMessage])
self._lock = threading.Lock()
def get(self, email: str) -> Optional[List[EmailMessage]]:
with self._lock:
entry = self._cache.get(email)
if entry and time.time() - entry[0] < self.TTL:
return entry[1]
return None
def set(self, email: str, messages: List[EmailMessage]):
with self._lock:
self._cache[email] = (time.time(), messages)
def invalidate(self, email: str):
with self._lock:
self._cache.pop(email, None)
class OutlookService(BaseEmailService):
"""
Outlook 邮箱服务
使用单一 IMAP_NEW Provider支持连接池复用和 IMAP IDLE
支持多种 IMAP/API 连接方式,自动故障切换
"""
def __init__(self, config: Dict[str, Any] = None, name: str = None):
"""
初始化 Outlook 服务
Args:
config: 配置字典,支持以下键:
- accounts: Outlook 账户列表
- provider_priority: 提供者优先级列表
- health_failure_threshold: 连续失败次数阈值
- health_disable_duration: 禁用时长(秒)
- timeout: 请求超时时间
- proxy_url: 代理 URL
name: 服务名称
"""
super().__init__(ServiceType.OUTLOOK, name)
# 默认配置
default_config = {
"accounts": [],
"provider_priority": [p.value for p in DEFAULT_PROVIDER_PRIORITY],
"health_failure_threshold": 5,
"health_disable_duration": 60,
"timeout": 30,
"proxy_url": None,
}
self.config = {**default_config, **(config or {})}
# 解析提供者优先级
self.provider_priority = [
ProviderType(p) for p in self.config.get("provider_priority", [])
]
if not self.provider_priority:
self.provider_priority = DEFAULT_PROVIDER_PRIORITY
# 提供者配置
self.provider_config = ProviderConfig(
timeout=self.config.get("timeout", 30),
proxy_url=self.config.get("proxy_url"),
service_id=self.config.get("service_id"),
health_failure_threshold=self.config.get("health_failure_threshold", 3),
health_disable_duration=self.config.get("health_disable_duration", 300),
)
# 获取默认 client_id
# 获取默认 client_id(供无 client_id 的账户使用)
try:
_default_client_id = get_settings().outlook_default_client_id
except Exception:
@@ -95,120 +103,193 @@ class OutlookService(BaseEmailService):
self._current_account_index = 0
self._account_lock = threading.Lock()
# 支持两种配置格式
if "email" in self.config and "password" in self.config:
account = OutlookAccount.from_config(self.config)
if not account.client_id and _default_client_id:
account.client_id = _default_client_id
if account.validate():
if not account.has_oauth():
logger.warning(
f"[{account.email}] 跳过IMAP_NEW 仅支持 OAuth2"
f"请配置 client_id 和 refresh_token"
)
else:
self.accounts.append(account)
else:
for ac in self.config.get("accounts", []):
account = OutlookAccount.from_config(ac)
for account_config in self.config.get("accounts", []):
account = OutlookAccount.from_config(account_config)
if not account.client_id and _default_client_id:
account.client_id = _default_client_id
if account.validate():
if not account.has_oauth():
logger.warning(
f"[{account.email}] 跳过IMAP_NEW 仅支持 OAuth2"
f"请配置 client_id 和 refresh_token"
)
else:
self.accounts.append(account)
if not self.accounts:
logger.warning("未配置有效的 Outlook 账户(需要 client_id + refresh_token")
logger.warning("未配置有效的 Outlook 账户")
# 健康检查器
# 健康检查器和故障切换管理器
self.health_checker = HealthChecker(
failure_threshold=self.provider_config.health_failure_threshold,
disable_duration=self.provider_config.health_disable_duration,
)
self.failover_manager = FailoverManager(
health_checker=self.health_checker,
priority_order=self.provider_priority,
)
# 邮件解析器
self.email_parser = get_email_parser()
# Provider 实例缓存: email -> IMAPNewProvider
self._providers: Dict[str, IMAPNewProvider] = {}
# 提供者实例缓存: (email, provider_type) -> OutlookProvider
self._providers: Dict[tuple, OutlookProvider] = {}
self._provider_lock = threading.Lock()
# IMAP 并发限制(最多 5 个并发
# IMAP 连接限制(防止限流
self._imap_semaphore = threading.Semaphore(5)
# 邮件缓存
self._email_cache = _EmailCache()
# 验证码去重
# 验证码去重机制
self._used_codes: Dict[str, set] = {}
def _get_provider(self, account: OutlookAccount) -> IMAPNewProvider:
key = account.email.lower()
with self._provider_lock:
if key not in self._providers:
self._providers[key] = IMAPNewProvider(account, self.provider_config)
return self._providers[key]
def _fetch_emails(
def _get_provider(
self,
account: OutlookAccount,
count: int = 15,
only_unseen: bool = True,
since_minutes: Optional[int] = None,
use_cache: bool = False,
folders: Optional[List[str]] = None,
) -> List[EmailMessage]:
"""通过 IMAP_NEW Provider 获取邮件,可选使用内存缓存"""
if use_cache:
cached = self._email_cache.get(account.email)
if cached is not None:
return cached
provider_type: ProviderType,
) -> OutlookProvider:
"""
获取或创建提供者实例
if not self.health_checker.is_available():
logger.debug(f"[{account.email}] IMAP_NEW 不可用,跳过")
return []
Args:
account: Outlook 账户
provider_type: 提供者类型
Returns:
提供者实例
"""
cache_key = (account.email.lower(), provider_type)
with self._provider_lock:
if cache_key not in self._providers:
provider = self._create_provider(account, provider_type)
self._providers[cache_key] = provider
return self._providers[cache_key]
def _create_provider(
self,
account: OutlookAccount,
provider_type: ProviderType,
) -> OutlookProvider:
"""
创建提供者实例
Args:
account: Outlook 账户
provider_type: 提供者类型
Returns:
提供者实例
"""
if provider_type == ProviderType.IMAP_OLD:
return IMAPOldProvider(account, self.provider_config)
elif provider_type == ProviderType.IMAP_NEW:
return IMAPNewProvider(account, self.provider_config)
elif provider_type == ProviderType.GRAPH_API:
return GraphAPIProvider(account, self.provider_config)
else:
raise ValueError(f"未知的提供者类型: {provider_type}")
def _get_provider_priority_for_account(self, account: OutlookAccount) -> List[ProviderType]:
"""根据账户是否有 OAuth返回适合的提供者优先级列表"""
if account.has_oauth():
return self.provider_priority
else:
# 无 OAuth直接走旧版 IMAP密码认证跳过需要 OAuth 的提供者
return [ProviderType.IMAP_OLD]
def _try_providers_for_emails(
self,
account: OutlookAccount,
count: int = 20,
only_unseen: bool = True,
) -> List[EmailMessage]:
"""
尝试多个提供者获取邮件
Args:
account: Outlook 账户
count: 获取数量
only_unseen: 是否只获取未读
Returns:
邮件列表
"""
errors = []
# 根据账户类型选择合适的提供者优先级
priority = self._get_provider_priority_for_account(account)
# 按优先级尝试各提供者
for provider_type in priority:
# 检查提供者是否可用
if not self.health_checker.is_available(provider_type):
logger.debug(
f"[{account.email}] {provider_type.value} 不可用,跳过"
)
continue
try:
provider = self._get_provider(account)
provider = self._get_provider(account, provider_type)
with self._imap_semaphore:
with provider:
emails = provider.get_recent_emails(
count, only_unseen, since_minutes=since_minutes, folders=folders
)
emails = provider.get_recent_emails(count, only_unseen)
if emails:
self.health_checker.record_success()
if use_cache:
self._email_cache.set(account.email, emails)
# 成功获取邮件
self.health_checker.record_success(provider_type)
logger.debug(
f"[{account.email}] {provider_type.value} 获取到 {len(emails)} 封邮件"
)
return emails
except Exception as e:
err = str(e)
self.health_checker.record_failure(err)
logger.warning(f"[{account.email}] 获取邮件失败: {e}")
error_msg = str(e)
errors.append(f"{provider_type.value}: {error_msg}")
self.health_checker.record_failure(provider_type, error_msg)
logger.warning(
f"[{account.email}] {provider_type.value} 获取邮件失败: {e}"
)
logger.error(
f"[{account.email}] 所有提供者都失败: {'; '.join(errors)}"
)
return []
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
"""轮询选择可用的 Outlook 账户"""
"""
选择可用的 Outlook 账户
Args:
config: 配置参数(未使用)
Returns:
包含邮箱信息的字典
"""
if not self.accounts:
self.update_status(False, EmailServiceError("没有可用的 Outlook 账户"))
raise EmailServiceError("没有可用的 Outlook 账户")
# 轮询选择账户
with self._account_lock:
account = self.accounts[self._current_account_index]
self._current_account_index = (self._current_account_index + 1) % len(self.accounts)
logger.info(f"选择 Outlook 账户: {account.email}")
self.update_status(True)
return {
email_info = {
"email": account.email,
"service_id": account.email,
"account": {"email": account.email, "has_oauth": account.has_oauth()},
"account": {
"email": account.email,
"has_oauth": account.has_oauth()
}
}
logger.info(f"选择 Outlook 账户: {account.email}")
self.update_status(True)
return email_info
def get_verification_code(
self,
@@ -218,185 +299,114 @@ class OutlookService(BaseEmailService):
pattern: str = None,
otp_sent_at: Optional[float] = None,
) -> Optional[str]:
"""从 Outlook 邮箱获取验证码"""
account = next(
(a for a in self.accounts if a.email.lower() == email.lower()), None
)
"""
从 Outlook 邮箱获取验证码
Args:
email: 邮箱地址
email_id: 未使用
timeout: 超时时间(秒)
pattern: 验证码正则表达式(未使用)
otp_sent_at: OTP 发送时间戳
Returns:
验证码字符串
"""
# 查找对应的账户
account = None
for acc in self.accounts:
if acc.email.lower() == email.lower():
account = acc
break
if not account:
self.update_status(False, EmailServiceError(f"未找到邮箱账户: {email}"))
self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}"))
return None
code_settings = _get_code_settings()
# 获取验证码等待配置
code_settings = get_email_code_settings()
actual_timeout = timeout or code_settings["timeout"]
poll_interval = code_settings["poll_interval"]
logger.info(f"[{email}] 开始获取验证码,超时 {actual_timeout}s")
logger.info(
f"[{email}] 开始获取验证码,超时 {actual_timeout}s"
f"提供者优先级: {[p.value for p in self.provider_priority]}"
)
# 初始化验证码去重集合
if email not in self._used_codes:
self._used_codes[email] = set()
used_codes = self._used_codes[email]
# 计算最小时间戳(留出 60 秒时钟偏差)
min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0
use_idle = True
try:
use_idle = get_settings().outlook_use_idle
except Exception:
pass
start_time = time.time()
poll_count = 0
if use_idle:
code = self._wait_with_idle(
account, email, actual_timeout, min_timestamp, used_codes, otp_sent_at
while time.time() - start_time < actual_timeout:
poll_count += 1
# 渐进式邮件检查:前 3 次只检查未读
only_unseen = poll_count <= 3
try:
# 尝试多个提供者获取邮件
emails = self._try_providers_for_emails(
account,
count=15,
only_unseen=only_unseen,
)
else:
code = self._wait_with_poll(
account, email, actual_timeout, poll_interval, min_timestamp, used_codes, otp_sent_at
if emails:
logger.debug(
f"[{email}] 第 {poll_count} 次轮询获取到 {len(emails)} 封邮件"
)
# 从邮件中查找验证码
code = self.email_parser.find_verification_code_in_emails(
emails,
target_email=email,
min_timestamp=min_timestamp,
used_codes=used_codes,
)
if code:
used_codes.add(code)
elapsed = int(time.time() - start_time)
logger.info(
f"[{email}] 找到验证码: {code}"
f"总耗时 {elapsed}s轮询 {poll_count}"
)
self.update_status(True)
return code
return None
def _wait_with_poll(
self,
account: OutlookAccount,
email: str,
timeout: int,
poll_interval: int,
min_timestamp: float,
used_codes: set,
otp_sent_at: Optional[float] = None,
) -> Optional[str]:
"""轮询方式等待验证码"""
start_time = time.time()
poll_count = 0
while time.time() - start_time < timeout:
poll_count += 1
# 每次动态计算 since_minutes确保时间窗口随轮询推进而更新
if otp_sent_at:
elapsed_since_send = int((time.time() - otp_sent_at) / 60) + 2
since_minutes: Optional[int] = min(elapsed_since_send, 180)
only_unseen = False
else:
since_minutes = None
only_unseen = poll_count <= 3
try:
emails = self._fetch_emails(
account, count=15, only_unseen=only_unseen,
since_minutes=since_minutes,
folders=_OUTLOOK_SEARCH_FOLDERS,
)
if emails:
code = self.email_parser.find_verification_code_in_emails(
emails,
target_email=email,
min_timestamp=min_timestamp,
used_codes=used_codes,
)
if code:
elapsed = int(time.time() - start_time)
logger.info(
f"[{email}] 找到验证码: {code},耗时 {elapsed}s轮询 {poll_count}"
)
return code
except Exception as e:
logger.warning(f"[{email}] 轮询出错: {e}")
logger.warning(f"[{email}] 检查出错: {e}")
# 等待下次轮询
time.sleep(poll_interval)
logger.warning(f"[{email}] 验证码超时 ({timeout}s),共轮询 {poll_count}")
elapsed = int(time.time() - start_time)
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count}")
return None
def _wait_with_idle(
self,
account: OutlookAccount,
email: str,
timeout: int,
min_timestamp: float,
used_codes: set,
otp_sent_at: Optional[float] = None,
) -> Optional[str]:
"""IMAP IDLE 方式等待验证码,失败时自动降级为轮询"""
if not self.health_checker.is_available():
logger.warning(f"[{email}] IMAP_NEW 不可用,降级为轮询")
return self._wait_with_poll(
account, email, timeout, 3, min_timestamp, used_codes, otp_sent_at
)
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
"""列出所有可用的 Outlook 账户"""
return [
{
"email": account.email,
"id": account.email,
"has_oauth": account.has_oauth(),
"type": "outlook"
}
for account in self.accounts
]
# 计算 since_minutes从发送时间前2分钟开始最多180分钟
since_minutes: Optional[int] = None
if otp_sent_at:
elapsed_since_send = int((time.time() - otp_sent_at) / 60) + 2
since_minutes = min(elapsed_since_send, 180)
start_time = time.time()
try:
provider = self._get_provider(account)
with self._imap_semaphore:
with provider:
# 先做一次即时检查
emails = provider.get_recent_emails(
15, only_unseen=(since_minutes is None), since_minutes=since_minutes,
folders=_OUTLOOK_SEARCH_FOLDERS,
)
code = self.email_parser.find_verification_code_in_emails(
emails,
target_email=email,
min_timestamp=min_timestamp,
used_codes=used_codes,
)
if code:
elapsed = int(time.time() - start_time)
logger.info(f"[{email}] 找到验证码: {code},耗时 {elapsed}s即时检查")
return code
# IDLE 等待循环
while time.time() - start_time < timeout:
remaining = int(timeout - (time.time() - start_time))
if remaining <= 0:
break
arrived = provider.wait_for_new_email_idle(timeout=min(remaining, 25))
# 无效化缓存,强制重新拉取
self._email_cache.invalidate(email)
# IDLE 触发后用 since_minutes 搜索,覆盖已读邮件
fetch_since = since_minutes
if fetch_since is None:
# 没有 otp_sent_at 时用距当前时间2分钟内的邮件
fetch_since = 2
emails = provider.get_recent_emails(
15, only_unseen=False, since_minutes=fetch_since,
folders=_OUTLOOK_SEARCH_FOLDERS,
)
code = self.email_parser.find_verification_code_in_emails(
emails,
target_email=email,
min_timestamp=min_timestamp,
used_codes=used_codes,
)
if code:
elapsed = int(time.time() - start_time)
logger.info(
f"[{email}] 找到验证码: {code},耗时 {elapsed}s"
f"IDLE {'推送' if arrived else '超时检查'}"
)
return code
except Exception as e:
logger.warning(f"[{email}] IDLE 失败,降级为轮询: {e}")
elapsed = int(time.time() - start_time)
remaining = max(0, timeout - elapsed)
if remaining > 0:
code_settings = _get_code_settings()
return self._wait_with_poll(
account, email, remaining,
code_settings["poll_interval"], min_timestamp, used_codes, otp_sent_at
)
logger.warning(f"[{email}] IDLE 等待验证码超时 ({timeout}s)")
return None
def delete_email(self, email_id: str) -> bool:
"""删除邮箱Outlook 不支持删除账户)"""
logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
return False
def check_health(self) -> bool:
"""检查 Outlook 服务是否可用"""
@@ -404,48 +414,48 @@ class OutlookService(BaseEmailService):
self.update_status(False, EmailServiceError("没有配置的账户"))
return False
# 测试第一个账户的连接
test_account = self.accounts[0]
# 尝试任一提供者连接
for provider_type in self.provider_priority:
try:
provider = self._get_provider(self.accounts[0])
provider = self._get_provider(test_account, provider_type)
if provider.test_connection():
self.update_status(True)
return True
except Exception as e:
logger.warning(f"Outlook 健康检查失败: {e}")
logger.warning(
f"Outlook 健康检查失败 ({test_account.email}, {provider_type.value}): {e}"
)
self.update_status(False, EmailServiceError("健康检查失败"))
return False
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
return [
{
"email": a.email,
"id": a.email,
"has_oauth": a.has_oauth(),
"type": "outlook",
}
for a in self.accounts
]
def delete_email(self, email_id: str) -> bool:
logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
return False
def get_provider_status(self) -> Dict[str, Any]:
"""获取提供者状态"""
return self.failover_manager.get_status()
def get_account_stats(self) -> Dict[str, Any]:
"""获取账户统计信息"""
total = len(self.accounts)
oauth_count = sum(1 for a in self.accounts if a.has_oauth())
oauth_count = sum(1 for acc in self.accounts if acc.has_oauth())
return {
"total_accounts": total,
"oauth_accounts": oauth_count,
"password_accounts": total - oauth_count,
"accounts": [a.to_dict() for a in self.accounts],
"health_status": self.health_checker.get_status(),
"accounts": [acc.to_dict() for acc in self.accounts],
"provider_status": self.get_provider_status(),
}
def add_account(self, account_config: Dict[str, Any]) -> bool:
"""添加新的 Outlook 账户"""
try:
account = OutlookAccount.from_config(account_config)
if not account.validate():
return False
self.accounts.append(account)
logger.info(f"添加 Outlook 账户: {account.email}")
return True
@@ -454,13 +464,24 @@ class OutlookService(BaseEmailService):
return False
def remove_account(self, email: str) -> bool:
for i, a in enumerate(self.accounts):
if a.email.lower() == email.lower():
"""移除 Outlook 账户"""
for i, acc in enumerate(self.accounts):
if acc.email.lower() == email.lower():
self.accounts.pop(i)
logger.info(f"移除 Outlook 账户: {email}")
return True
return False
def reset_health(self):
self.health_checker.reset()
logger.info("已重置 IMAP_NEW 健康状态")
def reset_provider_health(self):
"""重置所有提供者的健康状态"""
self.health_checker.reset_all()
logger.info("已重置所有提供者的健康状态")
def force_provider(self, provider_type: ProviderType):
"""强制使用指定的提供者"""
self.health_checker.force_enable(provider_type)
# 禁用其他提供者
for pt in ProviderType:
if pt != provider_type:
self.health_checker.force_disable(pt, 60)
logger.info(f"已强制使用提供者: {provider_type.value}")

View File

@@ -1,6 +1,6 @@
"""
Token 管理器(简化版)
固定使用 consumers 端点 + IMAP scope
Token 管理器
支持多个 Microsoft Token 端点,自动选择合适的端点
"""
import json
@@ -11,98 +11,153 @@ from typing import Dict, Optional, Any
from curl_cffi import requests as _requests
from .base import TokenInfo
from .base import ProviderType, TokenEndpoint, TokenInfo
from .account import OutlookAccount
logger = logging.getLogger(__name__)
TOKEN_URL = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
IMAP_SCOPE = "https://outlook.office.com/IMAP.AccessAsUser.All offline_access"
# 各提供者的 Scope 配置
PROVIDER_SCOPES = {
ProviderType.IMAP_OLD: "", # 旧版 IMAP 不需要特定 scope
ProviderType.IMAP_NEW: "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
ProviderType.GRAPH_API: "https://graph.microsoft.com/.default",
}
# 各提供者的 Token 端点
PROVIDER_TOKEN_URLS = {
ProviderType.IMAP_OLD: TokenEndpoint.LIVE.value,
ProviderType.IMAP_NEW: TokenEndpoint.CONSUMERS.value,
ProviderType.GRAPH_API: TokenEndpoint.COMMON.value,
}
class TokenManager:
"""
Token 管理器
固定 consumers 端点,缓存 key = email
支持多端点 Token 获取和缓存
"""
_token_cache: Dict[str, TokenInfo] = {}
# Token 缓存: key = (email, provider_type) -> TokenInfo
_token_cache: Dict[tuple, TokenInfo] = {}
_cache_lock = threading.Lock()
# 默认超时时间
DEFAULT_TIMEOUT = 30
# Token 刷新提前时间(秒)
REFRESH_BUFFER = 120
def __init__(
self,
account: OutlookAccount,
provider_type: ProviderType,
proxy_url: Optional[str] = None,
timeout: int = DEFAULT_TIMEOUT,
service_id: Optional[int] = None,
):
"""
初始化 Token 管理器
Args:
account: Outlook 账户
provider_type: 提供者类型
proxy_url: 代理 URL可选
timeout: 请求超时时间
"""
self.account = account
self.provider_type = provider_type
self.proxy_url = proxy_url
self.timeout = timeout
self.service_id = service_id
def _cache_key(self) -> str:
return self.account.email.lower()
# 获取端点和 Scope
self.token_url = PROVIDER_TOKEN_URLS.get(provider_type, TokenEndpoint.LIVE.value)
self.scope = PROVIDER_SCOPES.get(provider_type, "")
def get_cached_token(self) -> Optional[TokenInfo]:
"""获取缓存的 Token"""
cache_key = (self.account.email.lower(), self.provider_type)
with self._cache_lock:
token = self._token_cache.get(self._cache_key())
token = self._token_cache.get(cache_key)
if token and not token.is_expired(self.REFRESH_BUFFER):
return token
return None
def set_cached_token(self, token: TokenInfo):
"""缓存 Token"""
cache_key = (self.account.email.lower(), self.provider_type)
with self._cache_lock:
self._token_cache[self._cache_key()] = token
self._token_cache[cache_key] = token
def clear_cache(self):
"""清除缓存"""
cache_key = (self.account.email.lower(), self.provider_type)
with self._cache_lock:
self._token_cache.pop(self._cache_key(), None)
self._token_cache.pop(cache_key, None)
def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
"""
获取 Access Token
Args:
force_refresh: 是否强制刷新
Returns:
Access Token 字符串,失败返回 None
"""
# 检查缓存
if not force_refresh:
cached = self.get_cached_token()
if cached:
logger.debug(f"[{self.account.email}] 使用缓存 Token")
logger.debug(f"[{self.account.email}] 使用缓存 Token ({self.provider_type.value})")
return cached.access_token
# 刷新 Token
try:
token = self._refresh_token()
if token:
self.set_cached_token(token)
return token.access_token
except Exception as e:
logger.error(f"[{self.account.email}] 获取 Token 失败: {e}")
logger.error(f"[{self.account.email}] 获取 Token 失败 ({self.provider_type.value}): {e}")
return None
def _refresh_token(self) -> Optional[TokenInfo]:
"""
刷新 Token
Returns:
TokenInfo 对象,失败返回 None
"""
if not self.account.client_id or not self.account.refresh_token:
raise ValueError("缺少 client_id 或 refresh_token")
logger.debug(f"[{self.account.email}] 正在刷新 Token...")
logger.debug(f"[{self.account.email}] 正在刷新 Token ({self.provider_type.value})...")
logger.debug(f"[{self.account.email}] Token URL: {self.token_url}")
# 构建请求体
data = {
"client_id": self.account.client_id,
"refresh_token": self.account.refresh_token,
"grant_type": "refresh_token",
"scope": IMAP_SCOPE,
}
# 添加 Scope如果需要
if self.scope:
data["scope"] = self.scope
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
}
proxies = None
if self.proxy_url:
proxies = {"http": self.proxy_url, "https": self.proxy_url}
try:
resp = _requests.post(
TOKEN_URL,
self.token_url,
data=data,
headers=headers,
proxies=proxies,
@@ -111,56 +166,74 @@ class TokenManager:
)
if resp.status_code != 200:
body = resp.text
error_body = resp.text
logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}")
if "service abuse" in body.lower():
logger.debug(f"[{self.account.email}] 错误响应: {error_body[:500]}")
if "service abuse" in error_body.lower():
logger.warning(f"[{self.account.email}] 账号可能被封禁")
elif "invalid_grant" in body.lower():
elif "invalid_grant" in error_body.lower():
logger.warning(f"[{self.account.email}] Refresh Token 已失效")
return None
response_data = resp.json()
token = TokenInfo.from_response(response_data, IMAP_SCOPE)
# 解析响应
token = TokenInfo.from_response(response_data, self.scope)
logger.info(
f"[{self.account.email}] Token 刷新成功"
f"[{self.account.email}] Token 刷新成功 ({self.provider_type.value}), "
f"有效期 {int(token.expires_at - time.time())}"
)
# 若响应含新 refresh_token → 写回内存 + 持久化数据库
new_rt = response_data.get("refresh_token", "")
if new_rt and new_rt != self.account.refresh_token:
self.account.refresh_token = new_rt
if self.service_id:
try:
from ...database.session import get_session_manager
from ...database.crud import update_outlook_refresh_token
with get_session_manager().session_scope() as db:
update_outlook_refresh_token(
db, self.service_id, self.account.email, new_rt
)
logger.info(f"[{self.account.email}] refresh_token 已写回数据库")
except Exception as e:
logger.warning(f"[{self.account.email}] 写回 refresh_token 失败: {e}")
return token
except json.JSONDecodeError as e:
logger.error(f"[{self.account.email}] JSON 解析错误: {e}")
return None
except Exception as e:
logger.error(f"[{self.account.email}] 未知错误: {e}")
return None
@classmethod
def clear_all_cache(cls):
"""清除所有 Token 缓存"""
with cls._cache_lock:
cls._token_cache.clear()
logger.info("已清除所有 Token 缓存")
@classmethod
def get_cache_stats(cls) -> Dict[str, Any]:
"""获取缓存统计"""
with cls._cache_lock:
return {
"cache_size": len(cls._token_cache),
"entries": list(cls._token_cache.keys()),
"entries": [
{
"email": key[0],
"provider": key[1].value,
}
for key in cls._token_cache.keys()
],
}
def create_token_manager(
account: OutlookAccount,
provider_type: ProviderType,
proxy_url: Optional[str] = None,
timeout: int = TokenManager.DEFAULT_TIMEOUT,
) -> TokenManager:
"""
创建 Token 管理器的工厂函数
Args:
account: Outlook 账户
provider_type: 提供者类型
proxy_url: 代理 URL
timeout: 超时时间
Returns:
TokenManager 实例
"""
return TokenManager(account, provider_type, proxy_url, timeout)

View File

@@ -15,7 +15,7 @@ from email.policy import default as email_policy
from html import unescape
from typing import Optional, Dict, Any, List
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from .base import BaseEmailService, EmailServiceError, EmailServiceType, RateLimitedEmailServiceError
from ..core.http_client import HTTPClient, RequestConfig
from ..config.constants import OTP_CODE_PATTERN
@@ -200,8 +200,19 @@ class TempMailService(BaseEmailService):
error_msg = f"{error_msg} - {error_data}"
except Exception:
error_msg = f"{error_msg} - {response.text[:200]}"
self.update_status(False, EmailServiceError(error_msg))
raise EmailServiceError(error_msg)
retry_after = None
if response.status_code == 429:
retry_after_header = response.headers.get("Retry-After")
if retry_after_header:
try:
retry_after = max(1, int(retry_after_header))
except ValueError:
retry_after = None
error = RateLimitedEmailServiceError(error_msg, retry_after=retry_after)
else:
error = EmailServiceError(error_msg)
self.update_status(False, error)
raise error
try:
return response.json()

View File

@@ -6,6 +6,7 @@ import re
import time
import logging
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone
from .base import BaseEmailService, EmailServiceError, EmailServiceType
from ..core.http_client import HTTPClient, RequestConfig
@@ -59,6 +60,35 @@ class TempmailService(BaseEmailService):
self._email_cache: Dict[str, Dict[str, Any]] = {}
self._last_check_time: float = 0
def _parse_message_time(self, value: Any) -> Optional[float]:
"""解析 Tempmail 邮件时间,兼容 Unix 时间戳与 ISO 8601。"""
if value is None or value == "":
return None
if isinstance(value, (int, float)):
timestamp = float(value)
else:
text = str(value).strip()
if not text:
return None
try:
timestamp = float(text)
except ValueError:
try:
normalized = text.replace("Z", "+00:00")
timestamp = datetime.fromisoformat(normalized).astimezone(timezone.utc).timestamp()
except Exception:
return None
while timestamp > 1e11:
timestamp /= 1000.0
return timestamp if timestamp > 0 else None
def _get_received_timestamp(self, message: Dict[str, Any]) -> Optional[float]:
"""返回 Tempmail 邮件的接收时间戳。"""
return self._parse_message_time(message.get("received_at"))
def _save_token_to_db(self, email: str, token: str) -> None:
"""将邮箱 token 持久化到 Setting 表key=tempmail_token:{email}"""
try:
@@ -154,7 +184,7 @@ class TempmailService(BaseEmailService):
email_id: 邮箱 token如果不提供从缓存中查找
timeout: 超时时间(秒)
pattern: 验证码正则表达式
otp_sent_at: OTP 发送时间戳Tempmail 服务暂不使用此参数)
otp_sent_at: OTP 发送时间戳,只允许使用严格晚于该锚点的邮件
Returns:
验证码字符串,如果超时或未找到返回 None
@@ -209,11 +239,20 @@ class TempmailService(BaseEmailService):
if not isinstance(msg, dict):
continue
# 使用 date 作为唯一标识
msg_date = msg.get("date", 0)
if not msg_date or msg_date in seen_ids:
msg_timestamp = self._get_received_timestamp(msg)
if otp_sent_at is not None:
if msg_timestamp is None or msg_timestamp <= otp_sent_at:
continue
seen_ids.add(msg_date)
message_id = str(
msg.get("id")
or msg.get("date")
or msg.get("createdAt")
or f"{msg.get('from', '')}:{msg.get('subject', '')}:{msg_timestamp}"
).strip()
if not message_id or message_id in seen_ids:
continue
seen_ids.add(message_id)
sender = str(msg.get("from", "")).lower()
subject = str(msg.get("subject", ""))

View File

@@ -15,9 +15,11 @@ from fastapi import FastAPI, Request, Form
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from ..config.settings import get_settings
from ..database import crud
from ..database.session import get_db
from .routes import api_router
from .routes.websocket import router as ws_router
from .task_manager import task_manager
@@ -108,9 +110,9 @@ def create_app() -> FastAPI:
async def login_page(request: Request, next: Optional[str] = "/"):
"""登录页面"""
return templates.TemplateResponse(
request,
"login.html",
{"error": "", "next": next or "/"}
request=request,
name="login.html",
context={"request": request, "error": "", "next": next or "/"}
)
@app.post("/login")
@@ -119,9 +121,9 @@ def create_app() -> FastAPI:
expected = get_settings().webui_access_password.get_secret_value()
if not secrets.compare_digest(password, expected):
return templates.TemplateResponse(
request,
"login.html",
{"error": "密码错误", "next": next or "/"},
request=request,
name="login.html",
context={"request": request, "error": "密码错误", "next": next or "/"},
status_code=401
)
@@ -136,38 +138,48 @@ def create_app() -> FastAPI:
response.delete_cookie("webui_auth")
return response
@app.get("/favicon.ico", include_in_schema=False)
async def favicon_ico():
"""兼容浏览器对根路径 favicon 的默认请求。"""
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
@app.get("/favicon.svg", include_in_schema=False)
async def favicon_svg():
"""提供统一的站点图标资源。"""
return FileResponse(STATIC_DIR / "favicon.svg", media_type="image/svg+xml")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""首页 - 注册页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse(request, "index.html")
return templates.TemplateResponse(request=request, name="index.html", context={"request": request})
@app.get("/accounts", response_class=HTMLResponse)
async def accounts_page(request: Request):
"""账号管理页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse(request, "accounts.html")
return templates.TemplateResponse(request=request, name="accounts.html", context={"request": request})
@app.get("/email-services", response_class=HTMLResponse)
async def email_services_page(request: Request):
"""邮箱服务管理页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse(request, "email_services.html")
return templates.TemplateResponse(request=request, name="email_services.html", context={"request": request})
@app.get("/settings", response_class=HTMLResponse)
async def settings_page(request: Request):
"""设置页面"""
if not _is_authenticated(request):
return _redirect_to_login(request)
return templates.TemplateResponse(request, "settings.html")
return templates.TemplateResponse(request=request, name="settings.html", context={"request": request})
@app.get("/payment", response_class=HTMLResponse)
async def payment_page(request: Request):
"""支付页面"""
return templates.TemplateResponse(request, "payment.html")
return templates.TemplateResponse(request=request, name="payment.html", context={"request": request})
@app.on_event("startup")
async def startup_event():
@@ -185,6 +197,12 @@ def create_app() -> FastAPI:
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
stale_error = "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
with get_db() as db:
stale_tasks = crud.fail_incomplete_registration_tasks(db, stale_error)
if stale_tasks:
logger.warning("已收敛 %s 个僵尸任务: %s", len(stale_tasks), ", ".join(task[:8] for task in stale_tasks))
logger.info("=" * 50)
logger.info(f"{settings.app_name} v{settings.app_version} 启动中...")
logger.info(f"调试模式: {settings.debug}")

View File

@@ -1028,11 +1028,9 @@ def _build_inbox_config(db, service_type, email: str) -> dict:
EmailServiceModel.enabled == True
)
if service_type == EST.OUTLOOK:
# 按 config.email 精确匹配,不受 enabled 限制(收件箱是账号自己的邮箱)
all_outlook = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == db_type
).all()
svc = next((s for s in all_outlook if (s.config or {}).get("email", "").lower() == email.lower()), None)
# 按 config.email 匹配账号 email
services = query.all()
svc = next((s for s in services if (s.config or {}).get("email") == email), None)
else:
svc = query.order_by(EmailServiceModel.priority.asc()).first()

View File

@@ -67,7 +67,7 @@ class ServiceTestResult(BaseModel):
class OutlookBatchImportRequest(BaseModel):
"""Outlook 批量导入请求"""
data: str # 多行数据,每行格式: 邮箱----密码----client_id----refresh_token
data: str # 多行数据,每行格式: 邮箱----密码 或 邮箱----密码----client_id----refresh_token
enabled: bool = True
priority: int = 0
@@ -461,8 +461,11 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
"""
批量导入 Outlook 邮箱账户
格式(每行):邮箱----密码----client_id----refresh_token
使用四个连字符(----)分隔字段
支持两种格式:
- 格式一(密码认证):邮箱----密码
- 格式二XOAUTH2 认证):邮箱----密码----client_id----refresh_token
每行一个账户,使用四个连字符(----)分隔字段
"""
lines = request.data.strip().split("\n")
total = len(lines)
@@ -481,18 +484,14 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
parts = line.split("----")
# 必须是四字段格式
if len(parts) < 4:
# 验证格式
if len(parts) < 2:
failed += 1
errors.append(
f"{i+1}: 格式错误,必须为 邮箱----密码----client_id----refresh_token"
)
errors.append(f"{i+1}: 格式错误,至少需要邮箱和密码")
continue
email = parts[0].strip()
password = parts[1].strip()
client_id = parts[2].strip()
refresh_token = parts[3].strip()
# 验证邮箱格式
if "@" not in email:
@@ -500,12 +499,6 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
errors.append(f"{i+1}: 无效的邮箱地址: {email}")
continue
# 验证 OAuth 字段非空
if not client_id or not refresh_token:
failed += 1
errors.append(f"{i+1}: [{email}] client_id 或 refresh_token 不能为空")
continue
# 检查是否已存在
existing = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook",
@@ -520,11 +513,17 @@ async def batch_import_outlook(request: OutlookBatchImportRequest):
# 构建配置
config = {
"email": email,
"password": password,
"client_id": client_id,
"refresh_token": refresh_token,
"password": password
}
# 检查是否有 OAuth 信息(格式二)
if len(parts) >= 4:
client_id = parts[2].strip()
refresh_token = parts[3].strip()
if client_id and refresh_token:
config["client_id"] = client_id
config["refresh_token"] = refresh_token
# 创建服务记录
try:
service = EmailServiceModel(
@@ -609,86 +608,3 @@ async def test_tempmail_service(request: TempmailTestRequest):
except Exception as e:
logger.error(f"测试临时邮箱失败: {e}")
return {"success": False, "message": f"测试失败: {str(e)}"}
# ============== 收件箱 ==============
@router.get("/{service_id}/inbox")
async def get_outlook_inbox(
service_id: int,
count: int = Query(30, ge=1, le=100),
only_unseen: bool = Query(False),
):
"""获取 Outlook 收件箱邮件列表"""
with get_db() as db:
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
if not service:
raise HTTPException(status_code=404, detail="服务不存在")
if service.service_type != "outlook":
raise HTTPException(status_code=400, detail="仅支持 Outlook 类型服务")
config = service.config or {}
email_addr = config.get("email", "")
client_id = config.get("client_id", "")
refresh_token = config.get("refresh_token", "")
# client_id 为空时尝试使用全局默认值
if not client_id:
from ...config.settings import get_settings
client_id = get_settings().outlook_default_client_id or ""
if not client_id or not refresh_token:
raise HTTPException(status_code=400, detail="该账户缺少 OAuth 配置client_id / refresh_token无法读取收件箱")
try:
from ...services.outlook.account import OutlookAccount
from ...services.outlook.token_manager import TokenManager
from ...services.outlook.providers.imap_new import IMAPNewProvider
from ...services.outlook.providers.base import ProviderConfig
account = OutlookAccount(
email=email_addr,
password=config.get("password", ""),
client_id=client_id,
refresh_token=refresh_token,
)
provider_config = ProviderConfig(
proxy_url=None,
timeout=30,
service_id=service_id,
)
provider = IMAPNewProvider(account, provider_config)
connected = provider.connect()
if not connected:
raise HTTPException(status_code=502, detail="IMAP 连接失败,请检查 OAuth 配置")
try:
messages = provider.get_recent_emails(count=count, only_unseen=only_unseen)
finally:
provider.disconnect()
emails = []
for m in messages:
received_str = m.received_at.isoformat() if m.received_at else None
emails.append({
"id": m.id or "",
"subject": m.subject or "",
"sender": m.sender or "",
"received_at": received_str,
"body_preview": m.body_preview or (m.body or "")[:200],
"body": m.body or "",
"is_read": m.is_read,
})
return {
"email": email_addr,
"total": len(emails),
"emails": emails,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取收件箱失败 service_id={service_id}: {e}")
raise HTTPException(status_code=500, detail=f"获取收件箱失败: {str(e)}")

View File

@@ -7,8 +7,9 @@ import logging
import uuid
import random
import re
import time
from datetime import datetime
from typing import List, Optional, Dict, Tuple, Any
from typing import List, Optional, Dict, Tuple
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel, Field
@@ -16,9 +17,13 @@ from pydantic import BaseModel, Field
from ...database import crud
from ...database.session import get_db
from ...database.models import RegistrationTask, Proxy
from ...core.login import LoginEngine
from ...core.register import RegistrationResult
from ...core.register import (
ERROR_OTP_TIMEOUT_SECONDARY,
RegistrationEngine,
RegistrationResult,
)
from ...services import EmailServiceFactory, EmailServiceType
from ...services.base import EmailProviderBackoffState, OTPTimeoutEmailServiceError
from ...config.settings import get_settings
from ..task_manager import task_manager
@@ -29,6 +34,7 @@ router = APIRouter()
running_tasks: dict = {}
# 批量任务存储
batch_tasks: Dict[str, dict] = {}
email_service_circuit_breakers: Dict[int, EmailProviderBackoffState] = {}
# ============== Proxy Helper Functions ==============
@@ -253,6 +259,176 @@ def _normalize_email_service_config(
return normalized
def _get_email_service_backoff_state(service_id: Optional[int]) -> EmailProviderBackoffState:
if service_id is None:
return EmailProviderBackoffState()
return email_service_circuit_breakers.get(service_id, EmailProviderBackoffState())
def _store_email_service_backoff_state(
service_id: Optional[int],
backoff_state: Optional[EmailProviderBackoffState],
) -> Optional[EmailProviderBackoffState]:
if service_id is None or backoff_state is None:
return None
if backoff_state.failures == 0 and backoff_state.delay_seconds == 0:
email_service_circuit_breakers.pop(service_id, None)
return backoff_state
email_service_circuit_breakers[service_id] = backoff_state
return backoff_state
def _get_phase_result(phase_history, phase_name: str):
for phase_result in phase_history or []:
if getattr(phase_result, "phase", None) == phase_name:
return phase_result
return None
def _is_email_service_circuit_open(service_id: Optional[int], now: Optional[float] = None) -> bool:
if service_id is None:
return False
return _get_email_service_backoff_state(service_id).is_open(now)
def _trip_email_service_circuit(
service_id: Optional[int],
backoff_state: Optional[EmailProviderBackoffState],
) -> int:
if service_id is None or backoff_state is None:
return 0
_store_email_service_backoff_state(service_id, backoff_state)
return backoff_state.delay_seconds
def _record_email_service_timeout_backoff(
service_id: Optional[int],
email_service,
previous_backoff_state: EmailProviderBackoffState,
error_code: str,
error_message: str,
) -> Optional[EmailProviderBackoffState]:
if service_id is None:
return None
timeout_error = OTPTimeoutEmailServiceError(
error_message or "等待验证码超时",
error_code=error_code,
)
if hasattr(email_service, "apply_provider_backoff_state"):
email_service.apply_provider_backoff_state(previous_backoff_state)
if hasattr(email_service, "update_status"):
email_service.update_status(False, timeout_error)
backoff_state = getattr(email_service, "provider_backoff_state", None)
return _store_email_service_backoff_state(service_id, backoff_state)
def _build_email_service_candidates(
db,
service_type: EmailServiceType,
actual_proxy_url: Optional[str],
email_service_id: Optional[int],
email_service_config: Optional[dict],
) -> List[Dict[str, object]]:
from ...database.models import EmailService as EmailServiceModel, Account
settings = get_settings()
candidates: List[Dict[str, object]] = []
def append_candidate(candidate_type: EmailServiceType, config: dict, db_service=None) -> None:
candidates.append({
"service_type": candidate_type,
"config": config,
"db_service": db_service,
})
def append_database_candidates(db_service_type: str) -> None:
services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == db_service_type,
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc(), EmailServiceModel.id.asc()).all()
for db_service in services:
if _is_email_service_circuit_open(db_service.id):
continue
candidate_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(candidate_type, db_service.config, actual_proxy_url)
append_candidate(candidate_type, config, db_service=db_service)
if email_service_id:
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == email_service_id,
EmailServiceModel.enabled == True
).first()
if not db_service:
raise ValueError(f"邮箱服务不存在或已禁用: {email_service_id}")
if _is_email_service_circuit_open(db_service.id):
raise ValueError(f"邮箱服务处于熔断状态: {db_service.name}")
candidate_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(candidate_type, db_service.config, actual_proxy_url)
append_candidate(candidate_type, config, db_service=db_service)
return candidates
if service_type == EmailServiceType.TEMPMAIL:
append_candidate(service_type, {
"base_url": settings.tempmail_base_url,
"timeout": settings.tempmail_timeout,
"max_retries": settings.tempmail_max_retries,
"proxy_url": actual_proxy_url,
})
elif service_type == EmailServiceType.MOE_MAIL:
append_database_candidates("moe_mail")
if not candidates:
if settings.custom_domain_base_url and settings.custom_domain_api_key:
append_candidate(service_type, {
"base_url": settings.custom_domain_base_url,
"api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "",
"proxy_url": actual_proxy_url,
})
else:
raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置")
elif service_type == EmailServiceType.OUTLOOK:
services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc(), EmailServiceModel.id.asc()).all()
if not services:
raise ValueError("没有可用的 Outlook 账户,请先在设置中导入账户")
for db_service in services:
if _is_email_service_circuit_open(db_service.id):
continue
email = db_service.config.get("email") if db_service.config else None
if not email:
continue
existing = db.query(Account).filter(Account.email == email).first()
if existing:
logger.info(f"跳过已注册的 Outlook 账户: {email}")
continue
config = _normalize_email_service_config(service_type, db_service.config, actual_proxy_url)
append_candidate(service_type, config, db_service=db_service)
if not candidates:
raise ValueError("所有 Outlook 账户都已注册过,或当前均处于熔断状态")
elif service_type == EmailServiceType.DUCK_MAIL:
append_database_candidates("duck_mail")
if not candidates:
raise ValueError("没有可用的 DuckMail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.FREEMAIL:
append_database_candidates("freemail")
if not candidates:
raise ValueError("没有可用的 Freemail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.IMAP_MAIL:
append_database_candidates("imap_mail")
if not candidates:
raise ValueError("没有可用的 IMAP 邮箱服务,请先在邮箱服务中添加")
else:
append_candidate(service_type, email_service_config or {})
return candidates
def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None, log_prefix: str = "", batch_id: str = "", auto_upload_cpa: bool = False, cpa_service_ids: List[int] = None, auto_upload_sub2api: bool = False, sub2api_service_ids: List[int] = None, auto_upload_tm: bool = False, tm_service_ids: List[int] = None):
"""
在线程池中执行的同步注册任务
@@ -261,165 +437,31 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
"""
with get_db() as db:
try:
# 检查是否已取消
if task_manager.is_cancelled(task_uuid):
logger.info(f"任务 {task_uuid} 已取消,跳过执行")
return
# 更新任务状态为运行中
task = crud.update_registration_task(
db, task_uuid,
status="running",
started_at=datetime.utcnow()
)
if not task:
logger.error(f"任务不存在: {task_uuid}")
return
resolved_email_service_id = email_service_id or task.email_service_id
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "running")
settings = get_settings()
log_callback = task_manager.create_log_callback(task_uuid, prefix=log_prefix, batch_id=batch_id)
def build_email_service(active_proxy_url: Optional[str]):
requested_service_type = EmailServiceType(email_service_type)
if resolved_email_service_id:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.id == resolved_email_service_id,
EmailServiceModel.enabled == True
).first()
if db_service:
selected_service_type = EmailServiceType(db_service.service_type)
config = _normalize_email_service_config(selected_service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(
f"使用数据库邮箱服务: {db_service.name} "
f"(ID: {db_service.id}, 类型: {selected_service_type.value})"
)
email_service = EmailServiceFactory.create(selected_service_type, config)
return email_service, selected_service_type
raise ValueError(f"邮箱服务不存在或已禁用: {resolved_email_service_id}")
service_type = requested_service_type
if service_type == EmailServiceType.TEMPMAIL:
config = {
"base_url": settings.tempmail_base_url,
"timeout": settings.tempmail_timeout,
"max_retries": settings.tempmail_max_retries,
"proxy_url": active_proxy_url,
}
elif service_type == EmailServiceType.MOE_MAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "moe_mail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库自定义域名服务: {db_service.name}")
elif settings.custom_domain_base_url and settings.custom_domain_api_key:
config = {
"base_url": settings.custom_domain_base_url,
"api_key": settings.custom_domain_api_key.get_secret_value() if settings.custom_domain_api_key else "",
"proxy_url": active_proxy_url,
}
else:
raise ValueError("没有可用的自定义域名邮箱服务,请先在设置中配置")
elif service_type == EmailServiceType.OUTLOOK:
from ...database.models import EmailService as EmailServiceModel, Account
outlook_services = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "outlook",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).all()
if not outlook_services:
raise ValueError("没有可用的 Outlook 账户,请先在设置中导入账户")
# 找到一个未注册的 Outlook 账户
selected_service = None
for svc in outlook_services:
email = svc.config.get("email") if svc.config else None
if not email:
continue
existing = db.query(Account).filter(Account.email == email).first()
if not existing:
selected_service = svc
logger.info(f"选择未注册的 Outlook 账户: {email}")
break
logger.info(f"跳过已注册的 Outlook 账户: {email}")
if selected_service and selected_service.config:
config = selected_service.config.copy()
config['service_id'] = selected_service.id
crud.update_registration_task(db, task_uuid, email_service_id=selected_service.id)
logger.info(f"使用数据库 Outlook 账户: {selected_service.name}")
else:
raise ValueError("所有 Outlook 账户都已注册过 OpenAI 账号,请添加新的 Outlook 账户")
elif service_type == EmailServiceType.DUCK_MAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "duck_mail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 DuckMail 服务: {db_service.name}")
else:
raise ValueError("没有可用的 DuckMail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.FREEMAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "freemail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 Freemail 服务: {db_service.name}")
else:
raise ValueError("没有可用的 Freemail 邮箱服务,请先在邮箱服务页面添加服务")
elif service_type == EmailServiceType.IMAP_MAIL:
from ...database.models import EmailService as EmailServiceModel
db_service = db.query(EmailServiceModel).filter(
EmailServiceModel.service_type == "imap_mail",
EmailServiceModel.enabled == True
).order_by(EmailServiceModel.priority.asc()).first()
if db_service and db_service.config:
config = _normalize_email_service_config(service_type, db_service.config, active_proxy_url)
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(f"使用数据库 IMAP 邮箱服务: {db_service.name}")
else:
raise ValueError("没有可用的 IMAP 邮箱服务,请先在邮箱服务中添加")
else:
config = email_service_config or {}
email_service = EmailServiceFactory.create(service_type, config)
return email_service, service_type
requested_proxy = proxy
exhausted_proxy_ids = set()
result = None
active_service_type = EmailServiceType(email_service_type)
result = RegistrationResult(success=False, logs=[])
active_service_type = requested_service_type
proxy_id = None
while True:
actual_proxy_url = requested_proxy
proxy_id = None
if not actual_proxy_url:
actual_proxy_url, proxy_id = get_proxy_for_registration(
db,
@@ -429,20 +471,126 @@ def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy:
logger.info(f"任务 {task_uuid} 使用代理: {actual_proxy_url[:50]}...")
crud.update_registration_task(db, task_uuid, proxy=actual_proxy_url)
email_service, active_service_type = build_email_service(actual_proxy_url)
service_candidates = _build_email_service_candidates(
db,
requested_service_type,
actual_proxy_url,
email_service_id,
email_service_config,
)
should_retry_with_new_proxy = False
for attempt_index, candidate in enumerate(service_candidates, start=1):
selected_service_type = candidate["service_type"]
candidate_config = candidate["config"]
db_service = candidate.get("db_service")
active_service_type = selected_service_type
if db_service is not None:
crud.update_registration_task(db, task_uuid, email_service_id=db_service.id)
logger.info(
f"任务 {task_uuid} 使用数据库邮箱服务: {db_service.name} "
f"(ID: {db_service.id}, 类型: {selected_service_type.value}, 尝试: {attempt_index}/{len(service_candidates)})"
)
log_callback(
f"[系统] 使用邮箱服务: {db_service.name} "
f"({selected_service_type.value}, 尝试 {attempt_index}/{len(service_candidates)})"
)
else:
crud.update_registration_task(db, task_uuid, email_service_id=None)
task_manager.update_status(task_uuid, "running", email_service=active_service_type.value)
engine = LoginEngine(
email_service = EmailServiceFactory.create(
selected_service_type,
candidate_config,
name=db_service.name if db_service is not None else None,
)
provider_backoff_before_run = EmailProviderBackoffState()
if db_service is not None:
provider_backoff_before_run = _get_email_service_backoff_state(db_service.id)
if db_service is not None and hasattr(email_service, "apply_provider_backoff_state"):
email_service.apply_provider_backoff_state(provider_backoff_before_run)
engine = RegistrationEngine(
email_service=email_service,
proxy_url=actual_proxy_url,
callback_logger=log_callback,
task_uuid=task_uuid
)
try:
result = engine.run()
finally:
close_engine = getattr(engine, "close", None)
if callable(close_engine):
close_engine()
email_prepare_phase = _get_phase_result(
getattr(engine, "phase_history", []),
"email_prepare",
)
if db_service is not None and email_prepare_phase is not None:
_store_email_service_backoff_state(
db_service.id,
getattr(email_prepare_phase, "provider_backoff", None),
)
if result.success:
break
if is_retryable_proxy_error(result.error_message):
should_retry_with_new_proxy = True
break
can_failover = (
db_service is not None
and attempt_index < len(service_candidates)
and email_prepare_phase is not None
and not email_prepare_phase.success
and email_prepare_phase.error_code == "EMAIL_PROVIDER_RATE_LIMITED"
and email_prepare_phase.provider_backoff is not None
)
if not can_failover:
if (
db_service is not None
and result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
):
timeout_backoff = _record_email_service_timeout_backoff(
db_service.id,
email_service,
provider_backoff_before_run,
result.error_code,
result.error_message,
)
if timeout_backoff is not None:
logger.warning(
f"邮箱服务 OTP 超时,已退避 {db_service.name} "
f"{timeout_backoff.delay_seconds} 秒,连续失败 "
f"{timeout_backoff.failures}"
)
log_callback(
f"[系统] 邮箱服务 OTP 超时,退避 "
f"{timeout_backoff.delay_seconds} 秒: {db_service.name} "
f"(连续失败 {timeout_backoff.failures} 次)"
)
break
backoff_state = email_prepare_phase.provider_backoff
cooldown = _trip_email_service_circuit(db_service.id, backoff_state)
logger.warning(
f"邮箱服务限流,已退避 {db_service.name} {cooldown} 秒,"
f"连续失败 {backoff_state.failures} 次,"
f"任务 {task_uuid} 将切换到下一个服务"
)
log_callback(
f"[系统] 邮箱服务限流,退避 {cooldown} 秒并切换: "
f"{db_service.name} (连续失败 {backoff_state.failures} 次)"
)
if result.success:
break
if should_retry_with_new_proxy:
log_callback(f"[代理] 检测到可重试网络错误: {result.error_message}")
if proxy_id and disable_proxy_for_network_error(db, proxy_id, result.error_message):
exhausted_proxy_ids.add(proxy_id)
@@ -652,53 +800,34 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
def _init_batch_state(batch_id: str, task_uuids: List[str]):
"""初始化批量任务内存状态"""
task_manager.init_batch(batch_id, len(task_uuids))
metadata = batch_tasks.get(batch_id, {}).copy()
metadata["task_uuids"] = task_uuids
batch_tasks[batch_id] = metadata
batch_tasks[batch_id] = {
"total": len(task_uuids),
"completed": 0,
"success": 0,
"failed": 0,
"cancelled": False,
"task_uuids": task_uuids,
"current_index": 0,
"logs": [],
"finished": False
}
def _make_batch_helpers(batch_id: str):
"""返回 add_batch_log 和 update_batch_status 辅助函数"""
def add_batch_log(msg: str):
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
return add_batch_log, update_batch_status
def _get_batch_snapshot(batch_id: str) -> Optional[Dict[str, Any]]:
"""聚合批量任务元数据与实时状态。"""
metadata = batch_tasks.get(batch_id)
if metadata is None:
return None
status = task_manager.get_batch_status(batch_id) or {}
initial_skipped = int(metadata.get("initial_skipped", 0) or 0)
total = status.get("total", metadata.get("total", 0))
completed = status.get("completed", 0)
success = status.get("success", 0)
failed = status.get("failed", 0)
skipped = status.get("skipped", 0) + initial_skipped
return {
"batch_id": batch_id,
"total": total,
"completed": completed,
"success": success,
"failed": failed,
"skipped": skipped,
"current_index": status.get("current_index", 0),
"cancelled": status.get("cancelled", False),
"finished": status.get("finished", False),
"status": status.get("status", metadata.get("status", "pending")),
"logs": task_manager.get_batch_logs(batch_id),
"task_uuids": metadata.get("task_uuids", []),
"service_ids": metadata.get("service_ids", []),
}
async def run_batch_parallel(
batch_id: str,
task_uuids: List[str],
@@ -737,19 +866,21 @@ async def run_batch_parallel(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{prefix} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{prefix} [失败] 注册失败: {t.error_message}")
task_manager.record_batch_task_result(batch_id, t.status)
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
try:
await asyncio.gather(*[_run_one(i, u) for i, u in enumerate(task_uuids)], return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
snapshot = task_manager.get_batch_status(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
update_batch_status(finished=True, status="completed")
else:
update_batch_status(finished=True, status="cancelled")
@@ -757,6 +888,8 @@ async def run_batch_parallel(
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_pipeline(
@@ -799,17 +932,22 @@ async def run_batch_pipeline(
t = crud.get_registration_task(db, uuid)
if t:
async with counter_lock:
new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if t.status == "completed":
new_success += 1
add_batch_log(f"{pfx} [成功] 注册成功")
elif t.status == "failed":
new_failed += 1
add_batch_log(f"{pfx} [失败] 注册失败: {t.error_message}")
task_manager.record_batch_task_result(batch_id, t.status)
update_batch_status(completed=new_completed, success=new_success, failed=new_failed)
finally:
semaphore.release()
try:
for i, task_uuid in enumerate(task_uuids):
if task_manager.is_batch_cancelled(batch_id):
if task_manager.is_batch_cancelled(batch_id) or batch_tasks[batch_id]["cancelled"]:
with get_db() as db:
for remaining_uuid in task_uuids[i:]:
crud.update_registration_task(db, remaining_uuid, status="cancelled")
@@ -833,15 +971,14 @@ async def run_batch_pipeline(
await asyncio.gather(*running_tasks_list, return_exceptions=True)
if not task_manager.is_batch_cancelled(batch_id):
snapshot = task_manager.get_batch_status(batch_id) or {}
add_batch_log(
f"[完成] 批量任务完成!成功: {snapshot.get('success', 0)}, 失败: {snapshot.get('failed', 0)}"
)
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}")
update_batch_status(finished=True, status="completed")
except Exception as e:
logger.error(f"批量任务 {batch_id} 异常: {e}")
add_batch_log(f"[错误] 批量任务异常: {str(e)}")
update_batch_status(finished=True, status="failed")
finally:
batch_tasks[batch_id]["finished"] = True
async def run_batch_registration(
@@ -974,7 +1111,6 @@ async def start_batch_registration(
# 创建批量任务
batch_id = str(uuid.uuid4())
task_uuids = []
batch_tasks[batch_id] = {"total": request.count}
with get_db() as db:
for _ in range(request.count):
@@ -1021,33 +1157,34 @@ async def start_batch_registration(
@router.get("/batch/{batch_id}")
async def get_batch_status(batch_id: str):
"""获取批量任务状态"""
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return {
"batch_id": batch_id,
"total": snapshot["total"],
"completed": snapshot["completed"],
"success": snapshot["success"],
"failed": snapshot["failed"],
"current_index": snapshot["current_index"],
"cancelled": snapshot["cancelled"],
"finished": snapshot["finished"],
"progress": f"{snapshot['completed']}/{snapshot['total']}"
"total": batch["total"],
"completed": batch["completed"],
"success": batch["success"],
"failed": batch["failed"],
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"progress": f"{batch['completed']}/{batch['total']}"
}
@router.post("/batch/{batch_id}/cancel")
async def cancel_batch(batch_id: str):
"""取消批量任务"""
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
if snapshot.get("finished"):
batch = batch_tasks[batch_id]
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}
@@ -1528,11 +1665,18 @@ async def start_outlook_batch_registration(
# 创建批量任务
batch_id = str(uuid.uuid4())
# 记录额外元数据,由 task_manager 维护实时状态
# 初始化批量任务状态
batch_tasks[batch_id] = {
"total": len(actual_service_ids),
"initial_skipped": skipped_count,
"completed": 0,
"success": 0,
"failed": 0,
"skipped": 0,
"cancelled": False,
"service_ids": actual_service_ids,
"current_index": 0,
"logs": [],
"finished": False
}
# 在后台运行批量注册
@@ -1566,35 +1710,37 @@ async def start_outlook_batch_registration(
@router.get("/outlook-batch/{batch_id}")
async def get_outlook_batch_status(batch_id: str):
"""获取 Outlook 批量任务状态"""
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
return {
"batch_id": batch_id,
"total": snapshot["total"],
"completed": snapshot["completed"],
"success": snapshot["success"],
"failed": snapshot["failed"],
"skipped": snapshot["skipped"],
"current_index": snapshot["current_index"],
"cancelled": snapshot["cancelled"],
"finished": snapshot["finished"],
"logs": snapshot["logs"],
"progress": f"{snapshot['completed']}/{snapshot['total']}"
"total": batch["total"],
"completed": batch["completed"],
"success": batch["success"],
"failed": batch["failed"],
"skipped": batch.get("skipped", 0),
"current_index": batch["current_index"],
"cancelled": batch["cancelled"],
"finished": batch.get("finished", False),
"logs": batch.get("logs", []),
"progress": f"{batch['completed']}/{batch['total']}"
}
@router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务"""
snapshot = _get_batch_snapshot(batch_id)
if snapshot is None:
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
if snapshot.get("finished"):
batch = batch_tasks[batch_id]
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
# 同时更新两个系统的取消状态
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}

View File

@@ -469,60 +469,6 @@ class ProxyUpdateRequest(BaseModel):
priority: Optional[int] = None
def _test_proxy_connectivity(proxy_url: str) -> dict:
"""测试代理连通性并返回统一结果。"""
import time
from curl_cffi import requests as cffi_requests
test_url = "https://api.ipify.org?format=json"
start_time = time.time()
proxies_dict = {
"http": proxy_url,
"https": proxy_url
}
response = cffi_requests.get(
test_url,
proxies=proxies_dict,
timeout=3,
impersonate="chrome110"
)
elapsed_time = time.time() - start_time
if response.status_code == 200:
ip_info = response.json()
return {
"success": True,
"ip": ip_info.get("ip", ""),
"response_time": round(elapsed_time * 1000),
"message": f"代理连接成功,出口 IP: {ip_info.get('ip', 'unknown')}"
}
return {
"success": False,
"message": f"状态码: {response.status_code}"
}
def _auto_disable_proxy_on_failure(db, proxy, message: str) -> dict:
"""代理测试失败时自动禁用,并返回统一提示。"""
auto_disabled = False
if proxy.enabled:
crud.update_proxy(db, proxy.id, enabled=False)
auto_disabled = True
final_message = message
if auto_disabled:
final_message = f"{message},已自动禁用"
return {
"success": False,
"auto_disabled": auto_disabled,
"message": final_message,
}
@router.get("/proxies")
async def get_proxies_list(enabled: Optional[bool] = None):
"""获取代理列表"""
@@ -613,59 +559,107 @@ async def set_proxy_default(proxy_id: int):
@router.post("/proxies/{proxy_id}/test")
async def test_proxy_item(proxy_id: int):
"""测试单个代理"""
import time
from curl_cffi import requests as cffi_requests
with get_db() as db:
proxy = crud.get_proxy_by_id(db, proxy_id)
if not proxy:
raise HTTPException(status_code=404, detail="代理不存在")
proxy_url = proxy.proxy_url
test_url = "https://api.ipify.org?format=json"
start_time = time.time()
try:
result = _test_proxy_connectivity(proxy.proxy_url)
if result["success"]:
return result
return _auto_disable_proxy_on_failure(db, proxy, f"代理返回错误状态码: {result['message'].removeprefix('状态码: ')}")
proxies = {
"http": proxy_url,
"https": proxy_url
}
response = cffi_requests.get(
test_url,
proxies=proxies,
timeout=3,
impersonate="chrome110"
)
elapsed_time = time.time() - start_time
if response.status_code == 200:
ip_info = response.json()
return {
"success": True,
"ip": ip_info.get("ip", ""),
"response_time": round(elapsed_time * 1000),
"message": f"代理连接成功,出口 IP: {ip_info.get('ip', 'unknown')}"
}
else:
return {
"success": False,
"message": f"代理返回错误状态码: {response.status_code}"
}
except Exception as e:
return _auto_disable_proxy_on_failure(db, proxy, f"代理连接失败: {str(e)}")
return {
"success": False,
"message": f"代理连接失败: {str(e)}"
}
@router.post("/proxies/test-all")
async def test_all_proxies():
"""测试所有启用的代理"""
import time
from curl_cffi import requests as cffi_requests
with get_db() as db:
proxies = crud.get_enabled_proxies(db)
results = []
auto_disabled_count = 0
for proxy in proxies:
proxy_url = proxy.proxy_url
test_url = "https://api.ipify.org?format=json"
start_time = time.time()
try:
result = _test_proxy_connectivity(proxy.proxy_url)
if result["success"]:
proxies_dict = {
"http": proxy_url,
"https": proxy_url
}
response = cffi_requests.get(
test_url,
proxies=proxies_dict,
timeout=3,
impersonate="chrome110"
)
elapsed_time = time.time() - start_time
if response.status_code == 200:
ip_info = response.json()
results.append({
"id": proxy.id,
"name": proxy.name,
"success": True,
"ip": result.get("ip", ""),
"response_time": result.get("response_time"),
"auto_disabled": False,
"ip": ip_info.get("ip", ""),
"response_time": round(elapsed_time * 1000)
})
else:
failure_result = _auto_disable_proxy_on_failure(
db,
proxy,
f"代理返回错误状态码: {result['message'].removeprefix('状态码: ')}"
)
auto_disabled_count += 1 if failure_result["auto_disabled"] else 0
results.append({
"id": proxy.id,
"name": proxy.name,
**failure_result,
"success": False,
"message": f"状态码: {response.status_code}"
})
except Exception as e:
failure_result = _auto_disable_proxy_on_failure(db, proxy, f"代理连接失败: {str(e)}")
auto_disabled_count += 1 if failure_result["auto_disabled"] else 0
results.append({
"id": proxy.id,
"name": proxy.name,
**failure_result,
"success": False,
"message": str(e)
})
success_count = sum(1 for r in results if r["success"])
@@ -673,7 +667,6 @@ async def test_all_proxies():
"total": len(proxies),
"success": success_count,
"failed": len(proxies) - success_count,
"auto_disabled": auto_disabled_count,
"results": results
}
@@ -698,14 +691,6 @@ async def disable_proxy(proxy_id: int):
return {"success": True, "message": "代理已禁用"}
@router.delete("/proxies/disabled/batch-delete")
async def delete_disabled_proxies():
"""批量删除所有已禁用代理"""
with get_db() as db:
deleted = crud.delete_disabled_proxies(db)
return {"success": True, "deleted": deleted, "message": f"已删除 {deleted} 个禁用代理"}
# ============== Outlook 设置 ==============
class OutlookSettings(BaseModel):
@@ -720,6 +705,7 @@ async def get_outlook_settings():
return {
"default_client_id": settings.outlook_default_client_id,
"provider_priority": settings.outlook_provider_priority,
"health_failure_threshold": settings.outlook_health_failure_threshold,
"health_disable_duration": settings.outlook_health_disable_duration,
}

View File

@@ -7,12 +7,33 @@ import asyncio
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ...database import crud
from ...database.session import get_db
from ..task_manager import task_manager
logger = logging.getLogger(__name__)
router = APIRouter()
def _restore_task_snapshot(task_uuid: str) -> tuple[dict, list[str]]:
"""从数据库恢复任务状态和历史日志,解决服务重启后的监控空白。"""
with get_db() as db:
task = crud.get_registration_task(db, task_uuid)
if not task:
return {}, []
status = {"status": task.status}
if task.result and task.result.get("email"):
status["email"] = task.result["email"]
if task.error_message:
status["error"] = task.error_message
logs = task.logs.splitlines() if task.logs else []
task_manager.sync_task_state(task_uuid, status=status, logs=logs)
return status, logs
@router.websocket("/ws/task/{task_uuid}")
async def task_websocket(websocket: WebSocket, task_uuid: str):
"""
@@ -25,14 +46,15 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
- 客户端发送: {"type": "cancel"} - 取消任务
"""
await websocket.accept()
restored_status, restored_logs = _restore_task_snapshot(task_uuid)
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
task_manager.register_websocket(task_uuid, websocket)
# 注册连接,并取得注册时刻的历史日志快照,避免与后续实时推送串扰
history_logs = task_manager.register_websocket(task_uuid, websocket)
logger.info(f"WebSocket 连接已建立: {task_uuid}")
try:
# 发送当前状态
status = task_manager.get_status(task_uuid)
status = task_manager.get_status(task_uuid) or restored_status
if status:
await websocket.send_json({
"type": "status",
@@ -40,9 +62,8 @@ async def task_websocket(websocket: WebSocket, task_uuid: str):
**status
})
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
history_logs = task_manager.get_unsent_logs(task_uuid, websocket)
for log in history_logs:
# 发送历史日志。服务重启后 _restore_task_snapshot 会先把数据库快照回填到内存。
for log in history_logs or restored_logs:
await websocket.send_json({
"type": "log",
"task_uuid": task_uuid,
@@ -107,8 +128,8 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
"""
await websocket.accept()
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
task_manager.register_batch_websocket(batch_id, websocket)
# 注册连接,并取得注册时刻的历史日志快照,避免漏发/重复发送
history_logs = task_manager.register_batch_websocket(batch_id, websocket)
logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}")
try:
@@ -121,8 +142,6 @@ async def batch_websocket(websocket: WebSocket, batch_id: str):
**status
})
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
for log in history_logs:
await websocket.send_json({
"type": "log",

View File

@@ -144,20 +144,22 @@ class TaskManager:
except Exception as e:
logger.warning(f"WebSocket 发送状态失败: {e}")
def register_websocket(self, task_uuid: str, websocket):
"""注册 WebSocket 连接"""
def register_websocket(self, task_uuid: str, websocket) -> List[str]:
"""注册 WebSocket 连接,并返回注册时刻的历史日志快照"""
history_logs: List[str] = []
with _ws_lock:
if task_uuid not in _ws_connections:
_ws_connections[task_uuid] = []
# 避免重复注册同一个连接
if websocket not in _ws_connections[task_uuid]:
_ws_connections[task_uuid].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _get_log_lock(task_uuid):
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
history_logs = _log_queues.get(task_uuid, []).copy()
_ws_sent_index[task_uuid][id(websocket)] = len(history_logs)
_ws_connections[task_uuid].append(websocket)
logger.info(f"WebSocket 连接已注册: {task_uuid}")
else:
logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
return history_logs
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
"""获取未发送给该 WebSocket 的日志"""
@@ -190,15 +192,32 @@ class TaskManager:
with _get_log_lock(task_uuid):
return _log_queues.get(task_uuid, []).copy()
def sync_task_state(
self,
task_uuid: str,
status: Optional[dict] = None,
logs: Optional[List[str]] = None
):
"""将数据库中的任务快照回填到内存态,便于重连恢复。"""
if status:
current_status = _task_status.get(task_uuid, {}).copy()
current_status.update(status)
_task_status[task_uuid] = current_status
if logs is not None:
with _get_log_lock(task_uuid):
cached_logs = _log_queues.get(task_uuid, [])
if len(logs) >= len(cached_logs):
_log_queues[task_uuid] = list(logs)
def update_status(self, task_uuid: str, status: str, **kwargs):
"""更新任务状态并推送到 WebSocket"""
"""更新任务状态"""
if task_uuid not in _task_status:
_task_status[task_uuid] = {}
_task_status[task_uuid]["status"] = status
_task_status[task_uuid].update(kwargs)
# 推送状态变更到 WebSocket线程安全兼容同步线程调用
if self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
@@ -206,7 +225,7 @@ class TaskManager:
self._loop
)
except Exception as e:
logger.warning(f"推送状态到 WebSocket 失败: {e}")
logger.warning(f"广播任务状态失败: {e}")
def get_status(self, task_uuid: str) -> Optional[dict]:
"""获取任务状态"""
@@ -223,7 +242,6 @@ class TaskManager:
def init_batch(self, batch_id: str, total: int):
"""初始化批量任务"""
with _get_batch_lock(batch_id):
_batch_status[batch_id] = {
"status": "running",
"total": total,
@@ -232,8 +250,7 @@ class TaskManager:
"failed": 0,
"skipped": 0,
"current_index": 0,
"finished": False,
"cancelled": False,
"finished": False
}
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
@@ -278,10 +295,10 @@ class TaskManager:
def update_batch_status(self, batch_id: str, **kwargs):
"""更新批量任务状态"""
with _get_batch_lock(batch_id):
if batch_id not in _batch_status:
logger.warning(f"批量任务 {batch_id} 不存在")
return
_batch_status[batch_id].update(kwargs)
# 异步广播状态更新
@@ -294,35 +311,6 @@ class TaskManager:
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def record_batch_task_result(self, batch_id: str, task_status: str) -> Optional[dict]:
"""原子记录单个子任务的终态并返回快照。"""
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
if status is None:
logger.warning(f"批量任务 {batch_id} 不存在")
return None
status["completed"] += 1
if task_status == "completed":
status["success"] += 1
elif task_status == "failed":
status["failed"] += 1
elif task_status == "cancelled":
status["skipped"] += 1
snapshot = status.copy()
if self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
return snapshot
async def _broadcast_batch_status(self, batch_id: str):
"""广播批量任务状态"""
with _ws_lock:
@@ -343,9 +331,7 @@ class TaskManager:
def get_batch_status(self, batch_id: str) -> Optional[dict]:
"""获取批量任务状态"""
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id)
return status.copy() if status else None
return _batch_status.get(batch_id)
def get_batch_logs(self, batch_id: str) -> List[str]:
"""获取批量任务日志"""
@@ -354,44 +340,33 @@ class TaskManager:
def is_batch_cancelled(self, batch_id: str) -> bool:
"""检查批量任务是否已取消"""
with _get_batch_lock(batch_id):
status = _batch_status.get(batch_id, {})
return status.get("cancelled", False)
def cancel_batch(self, batch_id: str):
"""取消批量任务"""
changed = False
with _get_batch_lock(batch_id):
if batch_id in _batch_status:
_batch_status[batch_id]["cancelled"] = True
_batch_status[batch_id]["status"] = "cancelling"
changed = True
logger.info(f"批量任务 {batch_id} 已标记为取消")
if changed and self._loop and self._loop.is_running():
try:
asyncio.run_coroutine_threadsafe(
self._broadcast_batch_status(batch_id),
self._loop
)
except Exception as e:
logger.warning(f"广播批量状态失败: {e}")
def register_batch_websocket(self, batch_id: str, websocket):
"""注册批量任务 WebSocket 连接"""
def register_batch_websocket(self, batch_id: str, websocket) -> List[str]:
"""注册批量任务 WebSocket 连接,并返回注册时刻的历史日志快照"""
key = f"batch_{batch_id}"
history_logs: List[str] = []
with _ws_lock:
if key not in _ws_connections:
_ws_connections[key] = []
# 避免重复注册同一个连接
if websocket not in _ws_connections[key]:
_ws_connections[key].append(websocket)
# 记录已发送的日志数量,用于发送历史日志时避免重复
with _get_batch_lock(batch_id):
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
history_logs = _batch_logs.get(batch_id, []).copy()
_ws_sent_index[key][id(websocket)] = len(history_logs)
_ws_connections[key].append(websocket)
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
else:
logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
return history_logs
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
"""获取未发送给该 WebSocket 的批量任务日志"""

5
static/favicon.svg Normal file
View File

@@ -0,0 +1,5 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">
<rect width="64" height="64" rx="14" fill="#111827"/>
<path d="M18 20h28v6H18zm0 10h20v6H18zm0 10h28v6H18z" fill="#f9fafb"/>
<circle cx="46" cy="33" r="6" fill="#22c55e"/>
</svg>

After

Width:  |  Height:  |  Size: 246 B

View File

@@ -1055,6 +1055,7 @@ function resetButtons() {
elements.cancelBtn.disabled = true;
currentTask = null;
currentBatch = null;
isBatchMode = false;
// 重置完成标志
taskCompleted = false;
batchCompleted = false;
@@ -1280,13 +1281,12 @@ function connectBatchWebSocket(batchId) {
if (!toastShown) {
toastShown = true;
if (data.status === 'completed') {
const batchLabel = isOutlookBatchMode ? 'Outlook 批量' : '批量';
addLog('success', `[完成] ${batchLabel}任务完成!成功: ${data.success}, 失败: ${data.failed}, 跳过: ${data.skipped || 0}`);
addLog('success', `[完成] Outlook 批量任务完成!成功: ${data.success}, 失败: ${data.failed}, 跳过: ${data.skipped || 0}`);
if (data.success > 0) {
toast.success(`${batchLabel}注册完成,成功 ${data.success}`);
toast.success(`Outlook 批量注册完成,成功 ${data.success}`);
loadRecentAccounts();
} else {
toast.warning(`${batchLabel}注册完成,但没有成功注册任何账号`);
toast.warning('Outlook 批量注册完成,但没有成功注册任何账号');
}
} else if (data.status === 'failed') {
addLog('error', '[错误] 批量任务执行失败');

View File

@@ -72,25 +72,6 @@ const elements = {
editOutlookForm: document.getElementById('edit-outlook-form'),
closeEditOutlookModal: document.getElementById('close-edit-outlook-modal'),
cancelEditOutlook: document.getElementById('cancel-edit-outlook'),
// 收件箱模态框
inboxModal: document.getElementById('inbox-modal'),
closeInboxModal: document.getElementById('close-inbox-modal'),
inboxRefreshBtn: document.getElementById('inbox-refresh-btn'),
inboxOnlyUnseen: document.getElementById('inbox-only-unseen'),
inboxLoading: document.getElementById('inbox-loading'),
inboxTable: document.getElementById('inbox-table'),
inboxTbody: document.getElementById('inbox-tbody'),
inboxEmpty: document.getElementById('inbox-empty'),
inboxModalEmail: document.getElementById('inbox-modal-email'),
// 邮件正文模态框
emailDetailModal: document.getElementById('email-detail-modal'),
closeEmailDetailModal: document.getElementById('close-email-detail-modal'),
emailDetailSubject: document.getElementById('email-detail-subject'),
emailDetailSender: document.getElementById('email-detail-sender'),
emailDetailDate: document.getElementById('email-detail-date'),
emailDetailBody: document.getElementById('email-detail-body'),
};
const CUSTOM_SUBTYPE_LABELS = {
@@ -183,12 +164,6 @@ function initEventListeners() {
document.addEventListener('click', () => {
document.querySelectorAll('.dropdown-menu.active').forEach(m => m.classList.remove('active'));
});
// 收件箱模态框事件
elements.closeInboxModal.addEventListener('click', () => elements.inboxModal.classList.remove('active'));
elements.closeEmailDetailModal.addEventListener('click', () => elements.emailDetailModal.classList.remove('active'));
elements.inboxRefreshBtn.addEventListener('click', () => loadInbox(currentInboxServiceId, true));
elements.inboxOnlyUnseen.addEventListener('change', () => loadInbox(currentInboxServiceId));
}
function toggleEmailMoreMenu(btn) {
@@ -272,7 +247,6 @@ async function loadOutlookServices() {
<td>${format.date(service.last_used)}</td>
<td>
<div style="display:flex;gap:4px;align-items:center;white-space:nowrap;">
<button class="btn btn-secondary btn-sm" onclick="openInboxModal(${service.id}, '${escapeHtml(service.config?.email || service.name)}')">收件箱</button>
<button class="btn btn-secondary btn-sm" onclick="editOutlookService(${service.id})">编辑</button>
<div class="dropdown" style="position:relative;">
<button class="btn btn-secondary btn-sm" onclick="event.stopPropagation();toggleEmailMoreMenu(this)">更多</button>
@@ -813,57 +787,3 @@ async function handleEditOutlook(e) {
toast.error('更新失败: ' + error.message);
}
}
// ============== 收件箱 ==============
let currentInboxServiceId = null;
async function openInboxModal(serviceId, email) {
currentInboxServiceId = serviceId;
elements.inboxModalEmail.textContent = email;
elements.inboxOnlyUnseen.checked = false;
elements.inboxModal.classList.add('active');
await loadInbox(serviceId);
}
async function loadInbox(serviceId) {
if (!serviceId) return;
const onlyUnseen = elements.inboxOnlyUnseen.checked;
elements.inboxLoading.style.display = 'block';
elements.inboxTable.style.display = 'none';
elements.inboxEmpty.style.display = 'none';
elements.inboxEmpty.textContent = '暂无邮件';
try {
const params = new URLSearchParams({ count: 50, only_unseen: onlyUnseen });
const data = await api.get(`/email-services/${serviceId}/inbox?${params}`);
const emails = data.emails || [];
elements.inboxLoading.style.display = 'none';
if (emails.length === 0) {
elements.inboxEmpty.style.display = 'block';
return;
}
elements.inboxTbody.innerHTML = emails.map(m => {
const dataAttr = escapeHtml(JSON.stringify(m));
return `<tr style="cursor:pointer;" onclick="showEmailDetail(JSON.parse(this.dataset.mail))" data-mail="${dataAttr}">
<td style="text-align:center;">${m.is_read ? '' : '<span style="color:var(--primary);font-size:10px;">●</span>'}</td>
<td style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;" title="${escapeHtml(m.subject)}">${escapeHtml(m.subject) || '(无主题)'}</td>
<td style="max-width:180px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;" title="${escapeHtml(m.sender)}">${escapeHtml(m.sender)}</td>
<td>${format.date(m.received_at)}</td>
</tr>`;
}).join('');
elements.inboxTable.style.display = 'table';
} catch (e) {
elements.inboxLoading.style.display = 'none';
elements.inboxEmpty.style.display = 'block';
elements.inboxEmpty.textContent = '加载失败:' + (e.message || '未知错误');
console.error('加载收件箱失败:', e);
}
}
function showEmailDetail(mail) {
elements.emailDetailSubject.textContent = mail.subject || '(无主题)';
elements.emailDetailSender.textContent = mail.sender || '';
elements.emailDetailDate.textContent = format.date(mail.received_at);
elements.emailDetailBody.textContent = mail.body || mail.body_preview || '(无正文)';
elements.emailDetailModal.classList.add('active');
}

View File

@@ -31,7 +31,6 @@ const elements = {
proxiesTable: document.getElementById('proxies-table'),
addProxyBtn: document.getElementById('add-proxy-btn'),
testAllProxiesBtn: document.getElementById('test-all-proxies-btn'),
deleteDisabledProxiesBtn: document.getElementById('delete-disabled-proxies-btn'),
addProxyModal: document.getElementById('add-proxy-modal'),
proxyItemForm: document.getElementById('proxy-item-form'),
closeProxyModal: document.getElementById('close-proxy-modal'),
@@ -207,10 +206,6 @@ function initEventListeners() {
elements.testAllProxiesBtn.addEventListener('click', handleTestAllProxies);
}
if (elements.deleteDisabledProxiesBtn) {
elements.deleteDisabledProxiesBtn.addEventListener('click', handleDeleteDisabledProxies);
}
if (elements.closeProxyModal) {
elements.closeProxyModal.addEventListener('click', closeProxyModal);
}
@@ -678,16 +673,16 @@ async function handleOutlookBatchImport() {
lines.forEach((line, index) => {
const parts = line.split('----').map(p => p.trim());
if (parts.length < 4) {
errors.push(`${index + 1} 行格式错误,必须为 邮箱----密码----client_id----refresh_token`);
if (parts.length < 2) {
errors.push(`${index + 1} 行格式错误`);
return;
}
const account = {
email: parts[0],
password: parts[1],
client_id: parts[2],
refresh_token: parts[3],
client_id: parts[2] || null,
refresh_token: parts[3] || null,
enabled: enabled,
priority: priority
};
@@ -697,11 +692,6 @@ async function handleOutlookBatchImport() {
return;
}
if (!account.client_id || !account.refresh_token) {
errors.push(`${index + 1} 行 client_id 或 refresh_token 不能为空`);
return;
}
accounts.push(account);
});
@@ -777,13 +767,11 @@ async function loadProxies() {
try {
const data = await api.get('/settings/proxies');
renderProxies(data.proxies);
updateProxyBulkActions(data.proxies || []);
} catch (error) {
console.error('加载代理列表失败:', error);
updateProxyBulkActions([]);
elements.proxiesTable.innerHTML = `
<tr>
<td colspan="8">
<td colspan="7">
<div class="empty-state">
<div class="empty-state-icon">❌</div>
<div class="empty-state-title">加载失败</div>
@@ -799,7 +787,7 @@ function renderProxies(proxies) {
if (!proxies || proxies.length === 0) {
elements.proxiesTable.innerHTML = `
<tr>
<td colspan="8">
<td colspan="7">
<div class="empty-state">
<div class="empty-state-icon">🌐</div>
<div class="empty-state-title">暂无代理</div>
@@ -843,17 +831,6 @@ function renderProxies(proxies) {
`).join('');
}
function updateProxyBulkActions(proxies) {
if (!elements.deleteDisabledProxiesBtn) return;
const disabledCount = (proxies || []).filter(proxy => !proxy.enabled).length;
elements.deleteDisabledProxiesBtn.disabled = disabledCount === 0;
elements.deleteDisabledProxiesBtn.dataset.count = String(disabledCount);
elements.deleteDisabledProxiesBtn.textContent = disabledCount > 0
? `🧹 删除禁用项 (${disabledCount})`
: '🧹 删除禁用项';
}
function toggleSettingsMoreMenu(btn) {
const menu = btn.nextElementSibling;
const isActive = menu.classList.contains('active');
@@ -948,14 +925,9 @@ async function testProxyItem(id) {
const result = await api.post(`/settings/proxies/${id}/test`);
if (result.success) {
toast.success(result.message);
} else {
if (result.auto_disabled) {
toast.warning(result.message);
await loadProxies();
} else {
toast.error(result.message);
}
}
} catch (error) {
toast.error('测试失败: ' + error.message);
}
@@ -987,22 +959,6 @@ async function deleteProxyItem(id) {
}
}
async function handleDeleteDisabledProxies() {
const count = Number(elements.deleteDisabledProxiesBtn?.dataset.count || 0);
if (!count) return;
const confirmed = await confirm(`确定要删除全部 ${count} 个已禁用代理吗?此操作不可恢复。`);
if (!confirmed) return;
try {
const result = await api.delete('/settings/proxies/disabled/batch-delete');
toast.success(result.message);
await loadProxies();
} catch (error) {
toast.error('批量删除失败: ' + error.message);
}
}
// 测试所有代理
async function handleTestAllProxies() {
elements.testAllProxiesBtn.disabled = true;
@@ -1010,13 +966,8 @@ async function handleTestAllProxies() {
try {
const result = await api.post('/settings/proxies/test-all');
const summary = `测试完成: 成功 ${result.success}, 失败 ${result.failed}`;
if (result.auto_disabled > 0) {
toast.warning(`${summary},已自动禁用 ${result.auto_disabled}`);
} else {
toast.info(summary);
}
await loadProxies();
toast.info(`测试完成: 成功 ${result.success}, 失败 ${result.failed}`);
loadProxies();
} catch (error) {
toast.error('测试失败: ' + error.message);
} finally {

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>账号管理 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>📋</text></svg>">
<link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style>
.password-cell {
font-family: var(--font-mono);

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>邮箱服务 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>📧</text></svg>">
<link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
</head>
<body>
<div class="container">
@@ -62,13 +62,16 @@
</div>
<div class="card-body" id="outlook-import-body" style="display: none;">
<div class="import-info">
<p><strong>格式(每行一个账户)</strong></p>
<p><code>邮箱----密码----client_id----refresh_token</code></p>
<p>使用四个连字符(----)分隔字段,以 # 开头的行将被忽略。</p>
<p><strong>支持格式:</strong></p>
<ul>
<li><code>邮箱----密码</code> (密码认证)</li>
<li><code>邮箱----密码----client_id----refresh_token</code> XOAUTH2 认证,推荐)</li>
</ul>
<p>每行一个账户,使用四个连字符(----)分隔字段。以 # 开头的行将被忽略。</p>
</div>
<div class="form-group">
<label for="outlook-import-data">批量导入数据</label>
<textarea id="outlook-import-data" rows="8" placeholder="example@outlook.com----password123----client_id----refresh_token"></textarea>
<textarea id="outlook-import-data" rows="8" placeholder="example@outlook.com----password123&#10;test@outlook.com----password456----client_id----refresh_token"></textarea>
</div>
<div class="form-row">
<div class="form-group">
@@ -513,55 +516,6 @@
</div>
<!-- 收件箱模态框 -->
<div class="modal" id="inbox-modal">
<div class="modal-content" style="max-width:800px;width:95%;">
<div class="modal-header">
<h3>📬 收件箱 — <span id="inbox-modal-email"></span></h3>
<div style="display:flex;gap:8px;align-items:center;">
<label style="display:flex;align-items:center;gap:4px;font-size:13px;">
<input type="checkbox" id="inbox-only-unseen"> 仅未读
</label>
<button class="btn btn-secondary btn-sm" id="inbox-refresh-btn">刷新</button>
<button class="modal-close" id="close-inbox-modal">&times;</button>
</div>
</div>
<div class="modal-body" style="padding:0;max-height:70vh;overflow-y:auto;">
<div id="inbox-loading" style="padding:32px;text-align:center;">加载中...</div>
<table class="data-table" id="inbox-table" style="display:none;">
<thead>
<tr>
<th style="width:36px;"></th>
<th>主题</th>
<th style="width:200px;">发件人</th>
<th style="width:150px;">时间</th>
</tr>
</thead>
<tbody id="inbox-tbody"></tbody>
</table>
<div id="inbox-empty" style="display:none;padding:32px;text-align:center;color:var(--text-muted);">暂无邮件</div>
</div>
</div>
</div>
<!-- 邮件正文模态框 -->
<div class="modal" id="email-detail-modal">
<div class="modal-content" style="max-width:700px;width:95%;">
<div class="modal-header">
<div>
<h3 id="email-detail-subject" style="margin:0;"></h3>
<div style="font-size:12px;color:var(--text-muted);margin-top:4px;">
<span id="email-detail-sender"></span> · <span id="email-detail-date"></span>
</div>
</div>
<button class="modal-close" id="close-email-detail-modal">&times;</button>
</div>
<div class="modal-body" style="max-height:65vh;overflow-y:auto;">
<div id="email-detail-body" style="white-space:pre-wrap;word-break:break-word;font-size:13px;"></div>
</div>
</div>
</div>
<script src="/static/js/utils.js?v={{ static_version }}"></script>
<script src="/static/js/email_services.js?v={{ static_version }}"></script>
</body>

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>注册控制台 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>🚀</text></svg>">
<link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style>
/* 两栏布局 */
.two-column-layout {

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>访问验证 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>🔒</text></svg>">
<link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style>
.login-wrap {
max-width: 420px;

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>支付升级 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>💳</text></svg>">
<link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
<style>
.plan-cards {
display: grid;

View File

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>系统设置 - OpenAI 注册系统</title>
<link rel="stylesheet" href="/static/css/style.css?v={{ static_version }}">
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>⚙️</text></svg>">
<link rel="icon" type="image/svg+xml" href="/static/favicon.svg?v={{ static_version }}">
</head>
<body>
<div class="container">
@@ -95,7 +95,6 @@
<h3>代理列表</h3>
<div style="display: flex; gap: var(--spacing-sm);">
<button class="btn btn-secondary btn-sm" id="test-all-proxies-btn">🔌 测试全部</button>
<button class="btn btn-danger btn-sm" id="delete-disabled-proxies-btn" disabled>🧹 删除禁用项</button>
<button class="btn btn-primary btn-sm" id="add-proxy-btn"> 添加代理</button>
</div>
</div>
@@ -116,7 +115,7 @@
</thead>
<tbody id="proxies-table">
<tr>
<td colspan="8">
<td colspan="7">
<div class="empty-state">
<div class="empty-state-icon">🌐</div>
<div class="empty-state-title">暂无代理</div>

View File

@@ -1,21 +1,90 @@
from concurrent.futures import ThreadPoolExecutor
import asyncio
from contextlib import contextmanager
from types import SimpleNamespace
from src.web.routes import registration as registration_routes
from src.web.task_manager import task_manager
def test_record_batch_task_result_is_atomic_under_threads():
batch_id = "batch-atomic-test"
task_manager.init_batch(batch_id, 100)
def test_init_batch_state_keeps_batch_tasks_and_task_manager_in_sync():
batch_id = "batch-sync-init"
task_uuids = ["task-1", "task-2", "task-3"]
statuses = ["completed"] * 60 + ["failed"] * 40
registration_routes.batch_tasks.pop(batch_id, None)
registration_routes._init_batch_state(batch_id, task_uuids)
with ThreadPoolExecutor(max_workers=16) as executor:
list(executor.map(lambda status: task_manager.record_batch_task_result(batch_id, status), statuses))
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["total"] == manager_snapshot["total"] == 3
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 0
assert batch_snapshot["success"] == manager_snapshot["success"] == 0
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 0
assert batch_snapshot["finished"] is False
assert manager_snapshot["finished"] is False
assert manager_snapshot["status"] == "running"
assert snapshot is not None
assert snapshot["completed"] == 100
assert snapshot["success"] == 60
assert snapshot["failed"] == 40
assert snapshot["skipped"] == 0
def test_run_batch_parallel_keeps_counter_updates_in_sync(monkeypatch):
batch_id = "batch-sync-parallel"
task_uuids = ["task-ok-1", "task-fail-1", "task-ok-2"]
task_statuses = {
"task-ok-1": "completed",
"task-fail-1": "failed",
"task-ok-2": "completed",
}
async def fake_run_registration_task(
task_uuid,
email_service_type,
proxy,
email_service_config,
email_service_id,
log_prefix="",
batch_id="",
auto_upload_cpa=False,
cpa_service_ids=None,
auto_upload_sub2api=False,
sub2api_service_ids=None,
auto_upload_tm=False,
tm_service_ids=None,
):
assert task_uuid in task_statuses
@contextmanager
def fake_get_db():
yield object()
def fake_get_registration_task(db, task_uuid):
status = task_statuses[task_uuid]
error_message = None if status == "completed" else f"{task_uuid}-error"
return SimpleNamespace(status=status, error_message=error_message)
registration_routes.batch_tasks.pop(batch_id, None)
monkeypatch.setattr(registration_routes, "run_registration_task", fake_run_registration_task)
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes.crud, "get_registration_task", fake_get_registration_task)
asyncio.run(
registration_routes.run_batch_parallel(
batch_id=batch_id,
task_uuids=task_uuids,
email_service_type="tempmail",
proxy=None,
email_service_config=None,
email_service_id=None,
concurrency=2,
)
)
batch_snapshot = registration_routes.batch_tasks[batch_id]
manager_snapshot = task_manager.get_batch_status(batch_id)
assert manager_snapshot is not None
assert batch_snapshot["completed"] == manager_snapshot["completed"] == 3
assert batch_snapshot["success"] == manager_snapshot["success"] == 2
assert batch_snapshot["failed"] == manager_snapshot["failed"] == 1
assert batch_snapshot["finished"] is True
assert manager_snapshot["finished"] is True
assert manager_snapshot["status"] == "completed"

View File

@@ -1,14 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2026/3/21 14:48
from src.core.utils import base64_payload_decode, base64_decode
if __name__ == '__main__':
print(base64_payload_decode("eyJzZXNzaW9uX2lkIjoiYXV0aHNlc3NfcUE5eFByY3RaZmtHWXJnSlJGdUpxRXBPIiwiY291bnRyeV9jb2RlX2hpbnQiOiJVUyIsImF1dGhfc2Vzc2lvbl9sb2dnaW5nX2lkIjoiMTk0ZDg5OGQtM2Q0ZC00MzU5LWI1NTQtYmJjMjc1YTJlYjU1IiwicHJvbW8iOiIiLCJzaWdudXBfc291cmNlIjoiIiwib3BlbmFpX2NsaWVudF9pZCI6ImFwcF9FTW9hbUVFWjczZjBDa1hhWHA3aHJhbm4iLCJhcHBfbmFtZV9lbnVtIjoib2FpY2xpIiwiYWFzX2VuYWJsZWQiOmZhbHNlLCJvcmlnaW5hbF9zY3JlZW5faGludCI6ImxvZ2luIiwicGFzc3dvcmRsZXNzX2Rpc2FibGVkIjpmYWxzZSwicGFzc3dvcmRsZXNzX290cF9mcm9tX3Bhc3N3b3JkX3JlZGlyZWN0IjpmYWxzZSwiZW1haWxfdmVyaWZpY2F0aW9uX21vZGUiOiJwYXNzd29yZGxlc3NfbG9naW4iLCJlbWFpbCI6Imxob2xsYW5kNTcwQGdzb2xleWZveWxlLm9yZy51ayIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJuYW1lIjoiZXJyIiwid29ya3NwYWNlcyI6W3siaWQiOiI4NjhmZGNmYi1kNjI3LTRhZTItYTQ4Mi1jMTQxMjA1MGZhYTYiLCJuYW1lIjpudWxsLCJraW5kIjoicGVyc29uYWwiLCJwcm9maWxlX3BpY3R1cmVfYWx0X3RleHQiOiJlcnIifV19"))

View File

@@ -0,0 +1,61 @@
from src.services.base import (
EmailProviderBackoffState,
OTPTimeoutEmailServiceError,
RateLimitedEmailServiceError,
apply_adaptive_backoff,
calculate_adaptive_backoff_delay,
)
def test_calculate_adaptive_backoff_delay_uses_failure_count_progression():
assert calculate_adaptive_backoff_delay(0) == 30
assert calculate_adaptive_backoff_delay(1) == 30
assert calculate_adaptive_backoff_delay(2) == 60
assert calculate_adaptive_backoff_delay(3) == 120
def test_apply_adaptive_backoff_tracks_timeout_failures_to_one_hour():
state = EmailProviderBackoffState()
first = apply_adaptive_backoff(
state,
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
now=1000.0,
)
second = apply_adaptive_backoff(
first,
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
now=1031.0,
)
third = apply_adaptive_backoff(
second,
OTPTimeoutEmailServiceError("等待验证码超时", error_code="OTP_TIMEOUT_SECONDARY"),
now=1092.0,
)
assert first.failures == 1
assert first.delay_seconds == 30
assert first.opened_until == 1030.0
assert second.failures == 2
assert second.delay_seconds == 60
assert second.opened_until == 1091.0
assert third.failures == 3
assert third.delay_seconds == 3600
assert third.opened_until == 4692.0
def test_apply_adaptive_backoff_keeps_normal_rate_limit_on_exponential_curve():
state = EmailProviderBackoffState(failures=2, delay_seconds=60, opened_until=1060.0)
next_state = apply_adaptive_backoff(
state,
RateLimitedEmailServiceError("请求失败: 429", retry_after=7),
now=1100.0,
)
assert next_state.failures == 3
assert next_state.delay_seconds == 120
assert next_state.opened_until == 1220.0
assert next_state.retry_after == 7

View File

@@ -1,57 +0,0 @@
import base64
import json
from types import SimpleNamespace
from src.core.login import LoginEngine
def _build_auth_cookie(workspace_id: str) -> str:
payload = base64.urlsafe_b64encode(
json.dumps({"workspaces": [{"id": workspace_id}]}).encode("utf-8")
).decode("ascii").rstrip("=")
return f"{payload}.signature"
def test_get_workspace_id_retries_with_exponential_backoff(monkeypatch):
engine = LoginEngine.__new__(LoginEngine)
engine.logs = []
engine._log = lambda message, level="info": engine.logs.append((level, message))
auth_cookie = _build_auth_cookie("ws-123")
cookies = SimpleNamespace()
calls = {"count": 0}
def fake_get(name):
assert name == "oai-client-auth-session"
calls["count"] += 1
if calls["count"] < 4:
return None
return auth_cookie
cookies.get = fake_get
engine.session = SimpleNamespace(cookies=cookies)
sleeps = []
monkeypatch.setattr("src.core.login.time.sleep", lambda seconds: sleeps.append(seconds))
workspace_id = engine._get_workspace_id()
assert workspace_id == "ws-123"
assert calls["count"] == 4
assert sleeps == [1, 2, 4]
def test_run_always_closes_resources_on_early_return():
engine = LoginEngine.__new__(LoginEngine)
engine.logs = []
engine._log = lambda message, level="info": None
engine.close_called = False
engine.close = lambda: setattr(engine, "close_called", True)
engine._check_ip_location = lambda: (False, "blocked")
result = engine.run()
assert result.success is False
assert result.error_message == "IP 地理位置不支持: blocked"
assert engine.close_called is True

View File

@@ -1,439 +0,0 @@
from types import SimpleNamespace
import pytest
from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES
from src.core import register
from src.services.base import BaseEmailService
class DummyEmailService(BaseEmailService):
def __init__(self):
super().__init__(EmailServiceType.TEMPMAIL, name="dummy")
def create_email(self, config=None):
return {"email": "tester@example.com", "service_id": "svc-1"}
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
return "123456"
def list_emails(self, **kwargs):
return []
def delete_email(self, email_id):
return True
def check_health(self):
return True
def refresh_session(self):
return None
class FakeResponse:
def __init__(self, status_code=200, payload=None, text=""):
self.status_code = status_code
self._payload = payload if payload is not None else {}
self.text = text
def json(self):
return self._payload
class BrokenJSONResponse(FakeResponse):
def json(self):
raise ValueError("bad json")
class FakeSession:
def __init__(self, post_handler=None, get_handler=None, cookies=None):
self.post_handler = post_handler
self.get_handler = get_handler
self.cookies = cookies or {}
self.post_calls = []
self.get_calls = []
def post(self, url, **kwargs):
self.post_calls.append({"url": url, "kwargs": kwargs})
if self.post_handler is None:
raise AssertionError("unexpected post call")
return self.post_handler(url, **kwargs)
def get(self, url, **kwargs):
self.get_calls.append({"url": url, "kwargs": kwargs})
if self.get_handler is None:
raise AssertionError("unexpected get call")
return self.get_handler(url, **kwargs)
class DummyHTTPClient:
def __init__(self, proxy_url=None):
self.proxy_url = proxy_url
self.session = FakeSession()
self.closed = False
def close(self):
self.closed = True
def post(self, url, **kwargs):
raise AssertionError("unexpected http client post")
class DummyOAuthManager:
def __init__(self, **kwargs):
self.kwargs = kwargs
def start_oauth(self):
return SimpleNamespace(
auth_url="https://auth.example/start",
state="state-1",
code_verifier="verifier-1",
redirect_uri="http://localhost/callback",
)
def make_engine(monkeypatch, email_service=None):
monkeypatch.setattr(
register,
"get_settings",
lambda: SimpleNamespace(
openai_client_id="client-id",
openai_auth_url="https://auth.example/authorize",
openai_token_url="https://auth.example/token",
openai_redirect_uri="http://localhost/callback",
openai_scope="openid email profile offline_access",
),
)
monkeypatch.setattr(register, "OpenAIHTTPClient", DummyHTTPClient)
monkeypatch.setattr(register, "OAuthManager", DummyOAuthManager)
engine = register.RegistrationEngine(email_service or DummyEmailService())
engine.email = "tester@example.com"
engine.email_info = {"email": "tester@example.com", "service_id": "svc-1"}
return engine
@pytest.mark.parametrize(
"page_type",
[
"login_password",
OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"],
"consent_required",
"some_other_page",
],
)
def test_submit_login_form_accepts_any_http_200_page(monkeypatch, page_type):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(
status_code=200,
payload={"page": {"type": page_type}},
)
)
result = engine._submit_login_form("did-1", "sen-1")
assert result.success is True
assert result.page_type == page_type
assert result.error_message == ""
def test_submit_login_form_accepts_http_200_even_when_json_is_invalid(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
)
result = engine._submit_login_form("did-1", "sen-1")
assert result.success is True
assert result.page_type == ""
assert result.response_data == {}
assert result.error_message == ""
def test_send_passwordless_otp_posts_empty_body(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(status_code=200)
)
success = engine._send_passwordless_otp()
assert success is True
assert len(engine.session.post_calls) == 1
call = engine.session.post_calls[0]
assert call["url"] == OPENAI_API_ENDPOINTS["send_passwordless_otp"]
assert call["kwargs"]["data"] == ""
assert engine._otp_sent_at is not None
def test_send_passwordless_otp_does_not_update_timestamp_on_failure(monkeypatch):
engine = make_engine(monkeypatch)
engine._otp_sent_at = 1234.5
engine.session = FakeSession(
post_handler=lambda url, **kwargs: FakeResponse(status_code=500, text="server error")
)
success = engine._send_passwordless_otp()
assert success is False
assert engine._otp_sent_at == 1234.5
def test_get_verification_code_passes_explicit_otp_timestamp(monkeypatch):
captured = {}
class RecordingEmailService(DummyEmailService):
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
captured["email"] = email
captured["email_id"] = email_id
captured["timeout"] = timeout
captured["pattern"] = pattern
captured["otp_sent_at"] = otp_sent_at
return "654321"
engine = make_engine(monkeypatch, email_service=RecordingEmailService())
code = engine._get_verification_code(otp_sent_at=1234.5)
assert code == "654321"
assert captured["email"] == "tester@example.com"
assert captured["email_id"] == "svc-1"
assert captured["timeout"] == 120
assert captured["otp_sent_at"] == 1234.5
def test_validate_verification_code_accepts_http_200_even_when_json_is_invalid(monkeypatch):
engine = make_engine(monkeypatch)
engine.session = FakeSession(
post_handler=lambda url, **kwargs: BrokenJSONResponse(status_code=200)
)
result = engine._validate_verification_code("123456")
assert result.success is True
assert result.continue_url == ""
assert result.response_data == {}
def test_run_closes_http_client_on_early_failure(monkeypatch):
engine = make_engine(monkeypatch)
tracking_client = DummyHTTPClient()
engine.http_client = tracking_client
monkeypatch.setattr(engine, "_check_ip_location", lambda: (False, None))
result = engine.run()
assert result.success is False
assert tracking_client.closed is True
assert engine.session is None
def test_fallback_to_login_flow_forces_otp_and_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
steps = []
captured = {}
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: steps.append("submit_login_form")
or register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
steps.append("send_passwordless_otp")
engine._otp_sent_at = 4567.89
return True
def fake_get_verification_code(otp_sent_at=None):
steps.append("get_verification_code")
captured["otp_sent_at"] = otp_sent_at
return "123456"
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(engine, "_get_verification_code", fake_get_verification_code)
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: steps.append("validate_verification_code")
or register.OTPValidationResult(success=True, continue_url="https://auth.example/continue"),
)
def fake_try_upgrade(continue_url, stage):
steps.append("get_continue_url_and_parse_workspace")
captured["continue_url"] = continue_url
captured["stage"] = stage
return "ws-123"
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fake_try_upgrade)
workspace_id = engine._fallback_to_login_flow()
assert workspace_id == "ws-123"
assert steps == [
"reset_session",
"get_device_id",
"check_sentinel",
"submit_login_form",
"send_passwordless_otp",
"get_verification_code",
"validate_verification_code",
"get_continue_url_and_parse_workspace",
]
assert captured["continue_url"] == "https://auth.example/continue"
assert captured["stage"] == "降级登录 Continue URL"
assert captured["otp_sent_at"] == 4567.89
def test_fallback_to_login_flow_requires_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: True)
monkeypatch.setattr(engine, "_get_device_id", lambda: "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
engine._otp_sent_at = 9876.5
return True
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(engine, "_get_verification_code", lambda otp_sent_at=None: "123456")
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: register.OTPValidationResult(success=True, continue_url=""),
)
def fail_if_called(*args, **kwargs):
raise AssertionError("continue_url 缺失时不应尝试升级 Cookie")
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
assert engine._fallback_to_login_flow() is None
def test_fallback_to_login_flow_accepts_workspace_without_continue_url(monkeypatch):
engine = make_engine(monkeypatch)
steps = []
monkeypatch.setattr(engine, "_reset_oauth_session", lambda: steps.append("reset_session") or True)
monkeypatch.setattr(engine, "_get_device_id", lambda: steps.append("get_device_id") or "did-1")
monkeypatch.setattr(engine, "_check_sentinel", lambda did: steps.append("check_sentinel") or "sen-1")
monkeypatch.setattr(
engine,
"_submit_login_form",
lambda did, sen: steps.append("submit_login_form")
or register.SignupFormResult(success=True, page_type="login_password"),
)
def fake_send_passwordless_otp():
steps.append("send_passwordless_otp")
engine._otp_sent_at = 4567.89
return True
monkeypatch.setattr(engine, "_send_passwordless_otp", fake_send_passwordless_otp)
monkeypatch.setattr(
engine,
"_get_verification_code",
lambda otp_sent_at=None: steps.append("get_verification_code") or "123456",
)
monkeypatch.setattr(
engine,
"_validate_verification_code",
lambda code: steps.append("validate_verification_code")
or register.OTPValidationResult(success=True, continue_url=""),
)
monkeypatch.setattr(
engine,
"_get_workspace_id",
lambda log_missing=True: steps.append("get_workspace_id") or "ws-cookie",
)
def fail_if_called(*args, **kwargs):
raise AssertionError("已有 workspace 时不应继续访问 continue_url")
monkeypatch.setattr(engine, "_try_upgrade_cookie_with_continue_url", fail_if_called)
workspace_id = engine._fallback_to_login_flow()
assert workspace_id == "ws-cookie"
assert steps == [
"reset_session",
"get_device_id",
"check_sentinel",
"submit_login_form",
"send_passwordless_otp",
"get_verification_code",
"validate_verification_code",
"get_workspace_id",
]
def test_get_verification_code_uses_provider_timeout_and_refreshes_once(monkeypatch):
captured = {"calls": [], "refresh_count": 0}
class RefreshableOutlookService(DummyEmailService):
def __init__(self):
super().__init__()
self.service_type = EmailServiceType.OUTLOOK
def get_verification_code(self, email, email_id=None, timeout=120, pattern=None, otp_sent_at=None):
captured["calls"].append(
{
"timeout": timeout,
"otp_sent_at": otp_sent_at,
"email": email,
"email_id": email_id,
}
)
if len(captured["calls"]) == 1:
return None
return "987654"
def refresh_session(self):
captured["refresh_count"] += 1
engine = make_engine(monkeypatch, email_service=RefreshableOutlookService())
code = engine._get_verification_code(otp_sent_at=2468.0)
assert code == "987654"
assert captured["refresh_count"] == 1
assert len(captured["calls"]) == 2
assert captured["calls"][0]["timeout"] == 180
assert captured["calls"][1]["timeout"] == 180
assert captured["calls"][0]["otp_sent_at"] == 2468.0
def test_try_upgrade_cookie_with_continue_url_retries_with_second_probe(monkeypatch):
engine = make_engine(monkeypatch)
sleep_calls = []
workspace_results = iter([None, None, None, None, None, "ws-delayed"])
engine.session = FakeSession(
get_handler=lambda url, **kwargs: FakeResponse(status_code=302),
cookies={},
)
monkeypatch.setattr(engine, "_log_cookie_state", lambda *args, **kwargs: None)
monkeypatch.setattr(engine, "_get_workspace_id", lambda log_missing=False: next(workspace_results))
monkeypatch.setattr(register.time, "sleep", lambda seconds: sleep_calls.append(seconds))
workspace_id = engine._try_upgrade_cookie_with_continue_url(
"https://auth.example/continue",
"降级登录 Continue URL",
)
assert workspace_id == "ws-delayed"
assert len(engine.session.get_calls) == 3
assert sleep_calls == [1.0, 2.0, 4.0]

View File

@@ -0,0 +1,82 @@
import json
from types import SimpleNamespace
import src.core.register as register_module
from src.config.constants import OPENAI_PAGE_TYPES
from src.core.register import RegistrationEngine
from src.services import EmailServiceType
class DummySettings:
openai_client_id = "client-id"
openai_auth_url = "https://auth.example.test"
openai_token_url = "https://token.example.test"
openai_redirect_uri = "https://callback.example.test"
openai_scope = "openid profile email"
class FakeResponse:
def __init__(self, status_code=200, payload=None, text=""):
self.status_code = status_code
self._payload = payload or {}
self.text = text
def json(self):
return self._payload
class FakeSession:
def __init__(self, response):
self.response = response
self.calls = []
def post(self, url, **kwargs):
self.calls.append({
"url": url,
**kwargs,
})
return self.response
def _build_engine(monkeypatch):
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
email_service = SimpleNamespace(service_type=EmailServiceType.DUCK_MAIL)
return RegistrationEngine(email_service=email_service)
def test_submit_signup_form_uses_stable_protocol_body(monkeypatch):
engine = _build_engine(monkeypatch)
session = FakeSession(FakeResponse(
status_code=200,
payload={"page": {"type": OPENAI_PAGE_TYPES["PASSWORD_REGISTRATION"]}},
))
engine.session = session
engine.email = "tester@example.com"
result = engine._submit_signup_form("did-1", None)
assert result.success is True
assert result.is_existing_account is False
assert (
session.calls[0]["data"]
== '{"username":{"value":"tester@example.com","kind":"email"},"screen_hint":"signup"}'
)
def test_register_password_uses_stable_protocol_body(monkeypatch):
engine = _build_engine(monkeypatch)
session = FakeSession(FakeResponse(status_code=200))
engine.session = session
engine.email = "tester@example.com"
monkeypatch.setattr(engine, "_generate_password", lambda length=0: "Pass12345")
success, password = engine._register_password()
assert success is True
assert password == "Pass12345"
assert session.calls[0]["data"] == json.dumps(
{
"password": "Pass12345",
"username": "tester@example.com",
}
)

View File

@@ -0,0 +1,414 @@
from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
import src.services.base as base_module
from src.core.register import (
ERROR_OTP_TIMEOUT_SECONDARY,
PhaseResult,
RegistrationResult,
)
from src.database.models import Base, EmailService, RegistrationTask
from src.database.session import DatabaseSessionManager
from src.services import EmailServiceType
from src.services.base import BaseEmailService, EmailProviderBackoffState
from src.web.routes import registration as registration_routes
class DummyTaskManager:
def __init__(self):
self.status_updates = []
self.logs = {}
def is_cancelled(self, task_uuid):
return False
def update_status(self, task_uuid, status, email=None, error=None, **kwargs):
self.status_updates.append((task_uuid, status, email, error, kwargs))
def create_log_callback(self, task_uuid, prefix="", batch_id=""):
def callback(message):
self.logs.setdefault(task_uuid, []).append(message)
return callback
class BackoffAwareEmailService(BaseEmailService):
def __init__(self, service_type, config=None, name=None):
super().__init__(service_type=service_type, name=name)
self.config = config or {}
def create_email(self, config=None):
return {"email": "tester@example.com", "service_id": "svc-1"}
def get_verification_code(self, **kwargs):
return None
def list_emails(self, **kwargs):
return []
def delete_email(self, email_id: str) -> bool:
return True
def check_health(self) -> bool:
return True
def test_registration_task_fails_over_after_rate_limit(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_failover.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-rate-limit-failover"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
session.add_all([
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
),
EmailService(
service_type="duck_mail",
name="duck-secondary",
config={
"base_url": "https://mail-2.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=1,
),
])
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
attempts = []
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
def run(self):
attempts.append(self.email_service.name)
if self.email_service.name == "duck-primary":
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=False,
error_message="创建邮箱失败",
error_code="EMAIL_PROVIDER_RATE_LIMITED",
retryable=True,
next_action="switch_provider",
provider_backoff=EmailProviderBackoffState(
failures=1,
delay_seconds=30,
opened_until=1030.0,
retry_after=7,
last_error="请求失败: 429",
),
)
]
return RegistrationResult(
success=False,
error_message="创建邮箱失败: 请求失败: 429",
logs=[],
)
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=True,
provider_backoff=EmailProviderBackoffState(),
)
]
return RegistrationResult(
success=True,
email="tester@example.com",
password="Pass12345",
account_id="acct-1",
workspace_id="ws-1",
access_token="access-token",
refresh_token="refresh-token",
id_token="id-token",
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: SimpleNamespace(
service_type=service_type,
name=name or service_type.value,
config=config,
),
)
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
registration_routes.email_service_circuit_breakers.clear()
registration_routes._run_sync_registration_task(
task_uuid=task_uuid,
email_service_type=EmailServiceType.DUCK_MAIL.value,
proxy=None,
email_service_config=None,
)
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
services = session.query(EmailService).order_by(EmailService.priority.asc()).all()
task_status = task.status
task_email_service_id = task.email_service_id
primary_service_id = services[0].id
secondary_service_id = services[1].id
assert attempts == ["duck-primary", "duck-secondary"]
assert task_status == "completed"
assert task_email_service_id == secondary_service_id
assert registration_routes.email_service_circuit_breakers[primary_service_id].failures == 1
assert registration_routes.email_service_circuit_breakers[primary_service_id].delay_seconds == 30
def test_registration_task_enters_deep_cooldown_after_three_otp_timeouts(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_otp_timeout_backoff.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuids = [
"task-otp-timeout-1",
"task-otp-timeout-2",
"task-otp-timeout-3",
]
with manager.session_scope() as session:
session.add_all([RegistrationTask(task_uuid=task_uuid, status="pending") for task_uuid in task_uuids])
session.add(
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
current_time = {"value": 1000.0}
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = []
def run(self):
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=True,
provider_backoff=EmailProviderBackoffState(),
)
]
return RegistrationResult(
success=False,
error_message="等待验证码超时",
error_code=ERROR_OTP_TIMEOUT_SECONDARY,
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: BackoffAwareEmailService(
service_type=service_type,
config=config,
name=name,
),
)
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
monkeypatch.setattr(base_module.time, "time", lambda: current_time["value"])
registration_routes.email_service_circuit_breakers.clear()
with manager.session_scope() as session:
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
expected_delays = [30, 60, 3600]
for attempt_index, task_uuid in enumerate(task_uuids, start=1):
registration_routes._run_sync_registration_task(
task_uuid=task_uuid,
email_service_type=EmailServiceType.DUCK_MAIL.value,
proxy=None,
email_service_config=None,
)
with manager.session_scope() as session:
task = session.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
assert task.status == "failed"
assert task.error_message == "等待验证码超时"
state = registration_routes.email_service_circuit_breakers[service_id]
assert state.failures == attempt_index
assert state.delay_seconds == expected_delays[attempt_index - 1]
assert state.opened_until == current_time["value"] + expected_delays[attempt_index - 1]
if attempt_index < len(task_uuids):
current_time["value"] = state.opened_until + 1
final_state = registration_routes.email_service_circuit_breakers[service_id]
assert final_state.delay_seconds == 3600
assert final_state.failures == 3
def test_registration_task_success_clears_email_service_backoff(monkeypatch):
runtime_dir = Path("tests_runtime")
runtime_dir.mkdir(exist_ok=True)
db_path = runtime_dir / "registration_success_clears_backoff.db"
if db_path.exists():
db_path.unlink()
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
task_uuid = "task-success-clears-backoff"
with manager.session_scope() as session:
session.add(RegistrationTask(task_uuid=task_uuid, status="pending"))
session.add(
EmailService(
service_type="duck_mail",
name="duck-primary",
config={
"base_url": "https://mail-1.example.test",
"default_domain": "mail.example.test",
},
enabled=True,
priority=0,
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
class DummySettings:
pass
class FakeRegistrationEngine:
def __init__(self, email_service, proxy_url=None, callback_logger=None, task_uuid=None):
self.email_service = email_service
self.phase_history = [
PhaseResult(
phase="email_prepare",
success=True,
provider_backoff=EmailProviderBackoffState(),
)
]
def run(self):
return RegistrationResult(
success=True,
email="tester@example.com",
password="Pass12345",
account_id="acct-1",
workspace_id="ws-1",
access_token="access-token",
refresh_token="refresh-token",
id_token="id-token",
logs=[],
)
def save_to_database(self, result):
return True
def close(self):
return None
monkeypatch.setattr(registration_routes, "get_db", fake_get_db)
monkeypatch.setattr(registration_routes, "get_settings", lambda: DummySettings())
monkeypatch.setattr(registration_routes, "task_manager", DummyTaskManager())
monkeypatch.setattr(registration_routes, "RegistrationEngine", FakeRegistrationEngine)
monkeypatch.setattr(
registration_routes.EmailServiceFactory,
"create",
lambda service_type, config, name=None: BackoffAwareEmailService(
service_type=service_type,
config=config,
name=name,
),
)
monkeypatch.setattr(registration_routes, "update_proxy_usage", lambda db, proxy_id: None)
registration_routes.email_service_circuit_breakers.clear()
with manager.session_scope() as session:
service_id = session.query(EmailService.id).filter(EmailService.name == "duck-primary").scalar()
registration_routes.email_service_circuit_breakers[service_id] = EmailProviderBackoffState(
failures=2,
delay_seconds=60,
opened_until=9999.0,
last_error="等待验证码超时",
)
registration_routes._run_sync_registration_task(
task_uuid=task_uuid,
email_service_type=EmailServiceType.DUCK_MAIL.value,
proxy=None,
email_service_config=None,
)
assert service_id not in registration_routes.email_service_circuit_breakers

View File

@@ -0,0 +1,71 @@
import src.core.register as register_module
from src.core.register import (
ERROR_OTP_TIMEOUT_SECONDARY,
PhaseContext,
RegistrationEngine,
)
from src.services import EmailServiceType
class DummySettings:
openai_client_id = "client-id"
openai_auth_url = "https://auth.example.test"
openai_token_url = "https://token.example.test"
openai_redirect_uri = "https://callback.example.test"
openai_scope = "openid profile email"
class FakeEmailService:
def __init__(self, code):
self.service_type = EmailServiceType.TEMPMAIL
self.code = code
self.calls = []
def get_verification_code(self, **kwargs):
self.calls.append(kwargs)
return self.code
def _build_engine(monkeypatch, email_service):
monkeypatch.setattr(register_module, "get_settings", lambda: DummySettings())
return RegistrationEngine(email_service=email_service)
def test_phase_otp_secondary_uses_remaining_budget_from_start_timestamp(monkeypatch):
email_service = FakeEmailService(code="654321")
engine = _build_engine(monkeypatch, email_service)
engine.email = "tester@example.com"
engine.email_info = {"service_id": "svc-1"}
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
code, phase_result = engine._phase_otp_secondary(
PhaseContext(otp_sent_at=77.0),
started_at=100.0,
)
assert code == "654321"
assert phase_result.success is True
assert email_service.calls[0]["timeout"] == 100
assert email_service.calls[0]["otp_sent_at"] == 77.0
assert email_service.calls[0]["email"] == "tester@example.com"
assert email_service.calls[0]["email_id"] == "svc-1"
def test_phase_otp_secondary_returns_dedicated_timeout_error_code(monkeypatch):
email_service = FakeEmailService(code=None)
engine = _build_engine(monkeypatch, email_service)
engine.email = "tester@example.com"
engine.email_info = {"service_id": "svc-1"}
monkeypatch.setattr(register_module.time, "time", lambda: 120.0)
code, phase_result = engine._phase_otp_secondary(
PhaseContext(otp_sent_at=80.0),
started_at=100.0,
)
assert code is None
assert phase_result.success is False
assert phase_result.error_code == ERROR_OTP_TIMEOUT_SECONDARY
assert engine.phase_history[0].error_code == ERROR_OTP_TIMEOUT_SECONDARY

View File

@@ -42,7 +42,13 @@ def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch,
monkeypatch.setattr(
registration,
"EmailServiceFactory",
SimpleNamespace(create=lambda service_type, config: SimpleNamespace(service_type=service_type, config=config)),
SimpleNamespace(
create=lambda service_type, config, name=None: SimpleNamespace(
service_type=service_type,
config=config,
name=name or service_type.value,
)
),
)
attempted_proxies = []
@@ -73,6 +79,7 @@ def test_run_sync_registration_task_disables_bad_proxy_and_retries(monkeypatch,
return True
monkeypatch.setattr(registration, "RegistrationEngine", FakeRegistrationEngine)
registration.email_service_circuit_breakers.clear()
registration._run_sync_registration_task(
task_uuid="task-proxy-failover",

View File

@@ -15,6 +15,7 @@ def test_static_asset_version_is_non_empty_string():
def test_email_services_template_uses_versioned_static_assets():
template = Path("templates/email_services.html").read_text(encoding="utf-8")
assert '/static/favicon.svg?v={{ static_version }}' in template
assert '/static/css/style.css?v={{ static_version }}' in template
assert '/static/js/utils.js?v={{ static_version }}' in template
assert '/static/js/email_services.js?v={{ static_version }}' in template
@@ -23,6 +24,7 @@ def test_email_services_template_uses_versioned_static_assets():
def test_index_template_uses_versioned_static_assets():
template = Path("templates/index.html").read_text(encoding="utf-8")
assert '/static/favicon.svg?v={{ static_version }}' in template
assert '/static/css/style.css?v={{ static_version }}' in template
assert '/static/js/utils.js?v={{ static_version }}' in template
assert '/static/js/app.js?v={{ static_version }}' in template

143
tests/test_task_recovery.py Normal file
View File

@@ -0,0 +1,143 @@
from contextlib import contextmanager
import asyncio
from fastapi import WebSocketDisconnect
from src.database import crud
from src.database.models import Base, RegistrationTask
from src.database.session import DatabaseSessionManager
from src.web.routes import websocket as websocket_routes
from src.web.task_manager import TaskManager
def test_fail_incomplete_registration_tasks_marks_pending_and_running_failed(tmp_path):
db_path = tmp_path / "recovery.db"
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
with manager.session_scope() as session:
session.add_all([
RegistrationTask(task_uuid="task-pending", status="pending"),
RegistrationTask(task_uuid="task-running", status="running", logs="[01:00:00] still running"),
RegistrationTask(task_uuid="task-done", status="completed"),
])
with manager.session_scope() as session:
cleaned = crud.fail_incomplete_registration_tasks(
session,
"服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
)
assert cleaned == ["task-pending", "task-running"]
with manager.session_scope() as session:
pending_task = crud.get_registration_task_by_uuid(session, "task-pending")
running_task = crud.get_registration_task_by_uuid(session, "task-running")
done_task = crud.get_registration_task_by_uuid(session, "task-done")
assert pending_task.status == "failed"
assert running_task.status == "failed"
assert pending_task.error_message == "服务启动时检测到未完成的历史任务,已标记失败,请重新发起。"
assert running_task.completed_at is not None
assert "[系统] 服务启动时检测到未完成的历史任务,已标记失败,请重新发起。" in running_task.logs
assert done_task.status == "completed"
def test_restore_task_snapshot_loads_status_and_logs_from_database(monkeypatch, tmp_path):
db_path = tmp_path / "websocket.db"
manager = DatabaseSessionManager(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=manager.engine)
with manager.session_scope() as session:
session.add(
RegistrationTask(
task_uuid="task-websocket",
status="failed",
logs="[01:00:00] step 1\n[01:00:01] step 2",
result={"email": "tester@example.com"},
error_message="boom"
)
)
@contextmanager
def fake_get_db():
session = manager.SessionLocal()
try:
yield session
finally:
session.close()
monkeypatch.setattr(websocket_routes, "get_db", fake_get_db)
status, logs = websocket_routes._restore_task_snapshot("task-websocket")
assert status == {
"status": "failed",
"email": "tester@example.com",
"error": "boom",
}
assert logs == ["[01:00:00] step 1", "[01:00:01] step 2"]
def test_sync_task_state_prefers_longer_persisted_log_history():
manager = TaskManager()
task_uuid = "task-sync"
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["a", "b"])
manager.sync_task_state(task_uuid, logs=["a"])
assert manager.get_status(task_uuid) == {"status": "running"}
assert manager.get_logs(task_uuid) == ["a", "b"]
def test_register_websocket_returns_snapshot_and_keeps_live_cursor():
manager = TaskManager()
task_uuid = "task-live"
websocket = object()
manager.sync_task_state(task_uuid, status={"status": "running"}, logs=["log-1", "log-2"])
history_logs = manager.register_websocket(task_uuid, websocket)
assert history_logs == ["log-1", "log-2"]
assert manager.get_unsent_logs(task_uuid, websocket) == []
manager.add_log(task_uuid, "log-3")
assert manager.get_unsent_logs(task_uuid, websocket) == ["log-3"]
class _FakeWebSocket:
def __init__(self):
self.messages = []
self.accepted = False
async def accept(self):
self.accepted = True
async def send_json(self, payload):
self.messages.append(payload)
async def receive_json(self):
raise WebSocketDisconnect()
def test_batch_websocket_replays_history_logs_from_registration_snapshot(monkeypatch):
manager = TaskManager()
batch_id = "batch-history"
websocket = _FakeWebSocket()
manager.init_batch(batch_id, total=2)
manager.add_batch_log(batch_id, "[01:00:00] first")
manager.add_batch_log(batch_id, "[01:00:01] second")
monkeypatch.setattr(websocket_routes, "task_manager", manager)
asyncio.run(websocket_routes.batch_websocket(websocket, batch_id))
assert websocket.accepted is True
assert websocket.messages[0]["type"] == "status"
assert [msg["message"] for msg in websocket.messages[1:]] == [
"[01:00:00] first",
"[01:00:01] second",
]

View File

@@ -0,0 +1,98 @@
from datetime import datetime, timezone
from src.services.tempmail import TempmailService
class FakeResponse:
def __init__(self, payload, status_code=200):
self._payload = payload
self.status_code = status_code
def json(self):
return self._payload
class FakeHTTPClient:
def __init__(self, responses):
self.responses = list(responses)
self.calls = []
def get(self, url, **kwargs):
self.calls.append({"url": url, "kwargs": kwargs})
if not self.responses:
raise AssertionError(f"未准备响应: GET {url}")
return self.responses.pop(0)
def _to_timestamp(value: str) -> float:
return datetime.fromisoformat(value.replace("Z", "+00:00")).astimezone(timezone.utc).timestamp()
def test_get_verification_code_ignores_messages_received_before_otp_sent_at():
service = TempmailService({"base_url": "https://api.tempmail.test"})
service._email_cache["tester@example.com"] = {"token": "token-1"}
service.http_client = FakeHTTPClient([
FakeResponse(
{
"emails": [
{
"id": "old-mail",
"received_at": "2026-03-23T10:00:00Z",
"from": "noreply@openai.com",
"subject": "Old code",
"body": "111111",
},
{
"id": "new-mail",
"received_at": "2026-03-23T10:00:05Z",
"from": "noreply@openai.com",
"subject": "New code",
"body": "222222",
},
]
}
)
])
code = service.get_verification_code(
email="tester@example.com",
timeout=1,
otp_sent_at=_to_timestamp("2026-03-23T10:00:02Z"),
)
assert code == "222222"
def test_get_verification_code_requires_received_at_when_otp_sent_at_is_present():
service = TempmailService({"base_url": "https://api.tempmail.test"})
service._email_cache["tester@example.com"] = {"token": "token-1"}
service.http_client = FakeHTTPClient([
FakeResponse(
{
"emails": [
{
"id": "legacy-mail",
"date": "2026-03-23T10:00:06Z",
"from": "noreply@openai.com",
"subject": "Legacy code",
"body": "333333",
},
{
"id": "received-mail",
"received_at": "2026-03-23T10:00:07Z",
"from": "noreply@openai.com",
"subject": "Received code",
"body": "444444",
},
]
}
)
])
code = service.get_verification_code(
email="tester@example.com",
timeout=1,
otp_sent_at=_to_timestamp("2026-03-23T10:00:05Z"),
)
assert code == "444444"