feat(app): 重构outlook邮箱服务

This commit is contained in:
cnlimiter
2026-03-15 17:58:39 +08:00
parent 75f5bb439b
commit 1628552b92
21 changed files with 3245 additions and 4 deletions

View File

@@ -332,4 +332,38 @@ TIME_CONSTANTS = {
"HOUR": 3600,
"DAY": 86400,
"WEEK": 604800,
}
}
# ============================================================================
# Microsoft / Outlook constants
# ============================================================================

# Microsoft OAuth2 token endpoints, keyed by the flow that uses them.
MICROSOFT_TOKEN_ENDPOINTS = {
    # Endpoint used by the legacy IMAP flow
    "LIVE": "https://login.live.com/oauth20_token.srf",
    # Endpoint used by the new IMAP flow (requires a specific scope)
    "CONSUMERS": "https://login.microsoftonline.com/consumers/oauth2/v2.0/token",
    # Endpoint used by the Graph API flow
    "COMMON": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
}

# IMAP server host names.
OUTLOOK_IMAP_SERVERS = {
    "OLD": "outlook.office365.com",  # legacy IMAP
    "NEW": "outlook.live.com",  # new IMAP
}

# Microsoft OAuth2 scopes requested for each flow.
MICROSOFT_SCOPES = {
    # The legacy IMAP flow does not request a specific scope
    "IMAP_OLD": "",
    # Scope required by the new IMAP flow
    "IMAP_NEW": "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
    # Scope required by the Graph API flow
    "GRAPH_API": "https://graph.microsoft.com/.default",
}

# Default order in which Outlook providers are tried.
# NOTE(review): the `outlook_provider_priority` setting defaults to
# ["graph_api", "imap_new", "imap_old"] - confirm which order is intended.
OUTLOOK_PROVIDER_PRIORITY = ["imap_new", "imap_old", "graph_api"]

View File

@@ -4,7 +4,7 @@
"""
import os
from typing import Optional, Dict, Any, Type
from typing import Optional, Dict, Any, Type, List
from enum import Enum
from pydantic import BaseModel, field_validator
from pydantic.types import SecretStr
@@ -297,6 +297,32 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
category=SettingCategory.EMAIL,
description="验证码轮询间隔(秒)"
),
# Outlook 配置
"outlook_provider_priority": SettingDefinition(
db_key="outlook.provider_priority",
default_value=["graph_api", "imap_new", "imap_old"],
category=SettingCategory.EMAIL,
description="Outlook 提供者优先级"
),
"outlook_health_failure_threshold": SettingDefinition(
db_key="outlook.health_failure_threshold",
default_value=5,
category=SettingCategory.EMAIL,
description="Outlook 提供者连续失败次数阈值"
),
"outlook_health_disable_duration": SettingDefinition(
db_key="outlook.health_disable_duration",
default_value=60,
category=SettingCategory.EMAIL,
description="Outlook 提供者禁用时长(秒)"
),
"outlook_default_client_id": SettingDefinition(
db_key="outlook.default_client_id",
default_value="24d9a0ed-8787-4584-883c-2fd79308940a",
category=SettingCategory.EMAIL,
description="Outlook OAuth 默认 Client ID"
),
}
# 属性名到数据库键名的映射(用于向后兼容)
@@ -320,6 +346,9 @@ SETTING_TYPES: Dict[str, Type] = {
"cpa_enabled": bool,
"email_code_timeout": int,
"email_code_poll_interval": int,
"outlook_provider_priority": list,
"outlook_health_failure_threshold": int,
"outlook_health_disable_duration": int,
}
# 需要作为 SecretStr 处理的字段
@@ -346,6 +375,11 @@ def _convert_value(attr_name: str, value: str) -> Any:
return value
import json
return json.loads(value) if value else {}
elif target_type == list:
if isinstance(value, list):
return value
import json
return json.loads(value) if value else []
else:
return value
@@ -356,7 +390,7 @@ def _value_to_string(value: Any) -> str:
return value.get_secret_value()
elif isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, dict):
elif isinstance(value, (dict, list)):
import json
return json.dumps(value)
elif value is None:
@@ -533,6 +567,12 @@ class Settings(BaseModel):
email_code_timeout: int = 120
email_code_poll_interval: int = 3
# Outlook 配置
outlook_provider_priority: List[str] = ["graph_api", "imap_new", "imap_old"]
outlook_health_failure_threshold: int = 5
outlook_health_disable_duration: int = 60
outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a"
# 全局配置实例
_settings: Optional[Settings] = None

View File

@@ -19,14 +19,43 @@ EmailServiceFactory.register(EmailServiceType.TEMPMAIL, TempmailService)
EmailServiceFactory.register(EmailServiceType.OUTLOOK, OutlookService)
EmailServiceFactory.register(EmailServiceType.CUSTOM_DOMAIN, CustomDomainEmailService)
# 导出 Outlook 模块的额外内容
from .outlook.base import (
ProviderType,
EmailMessage,
TokenInfo,
ProviderHealth,
ProviderStatus,
)
from .outlook.account import OutlookAccount
from .outlook.providers import (
OutlookProvider,
IMAPOldProvider,
IMAPNewProvider,
GraphAPIProvider,
)
__all__ = [
# 基类
'BaseEmailService',
'EmailServiceError',
'EmailServiceStatus',
'EmailServiceFactory',
'create_email_service',
'EmailServiceType',
# 服务类
'TempmailService',
'OutlookService',
'CustomDomainEmailService',
# Outlook 模块
'ProviderType',
'EmailMessage',
'TokenInfo',
'ProviderHealth',
'ProviderStatus',
'OutlookAccount',
'OutlookProvider',
'IMAPOldProvider',
'IMAPNewProvider',
'GraphAPIProvider',
]

View File

@@ -0,0 +1,8 @@
"""
Outlook 邮箱服务模块
支持多种 IMAP/API 连接方式,自动故障切换
"""
from .service import OutlookService
__all__ = ['OutlookService']

View File

@@ -0,0 +1,51 @@
"""
Outlook 账户数据类
"""
from dataclasses import dataclass
from typing import Dict, Any, Optional
@dataclass
class OutlookAccount:
    """Credentials for a single Outlook mailbox.

    Every field is a plain string; an empty string means "not provided".
    """

    email: str
    password: str = ""
    client_id: str = ""
    refresh_token: str = ""

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount":
        """Build an account from a config mapping.

        Reads the keys ``email``, ``password``, ``client_id`` and
        ``refresh_token``; any missing key becomes an empty string.
        """
        return cls(
            email=config.get("email", ""),
            password=config.get("password", ""),
            client_id=config.get("client_id", ""),
            refresh_token=config.get("refresh_token", ""),
        )

    def has_oauth(self) -> bool:
        """Return True when both ``client_id`` and ``refresh_token`` are set."""
        return bool(self.client_id and self.refresh_token)

    def validate(self) -> bool:
        """Return True when the account carries usable credentials.

        An email address is always required, together with either a password
        or a complete OAuth2 pair. Previously an OAuth2 pair without an email
        was accepted even though the address is needed to authenticate.
        """
        if not self.email:
            return False
        return bool(self.password) or self.has_oauth()

    def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
        """Serialize the account to a dict.

        Args:
            include_sensitive: When True, also include ``password``,
                ``client_id`` and the first 20 characters of
                ``refresh_token`` followed by ``...``.

        Returns:
            A dict with ``email`` and ``has_oauth`` plus, optionally, the
            sensitive fields.
        """
        result = {
            "email": self.email,
            "has_oauth": self.has_oauth(),
        }
        if include_sensitive:
            result.update({
                "password": self.password,
                "client_id": self.client_id,
                # Only a prefix of the refresh token is exposed.
                "refresh_token": self.refresh_token[:20] + "..." if self.refresh_token else "",
            })
        return result

    def __str__(self) -> str:
        """Return a short representation that shows only the email."""
        return f"OutlookAccount({self.email})"

View File

@@ -0,0 +1,153 @@
"""
Outlook 服务基础定义
包含枚举类型和数据类
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional, Dict, Any, List
class ProviderType(str, Enum):
    """Kinds of Outlook mail provider."""

    # Legacy IMAP (outlook.office365.com)
    IMAP_OLD = "imap_old"
    # New IMAP (outlook.live.com)
    IMAP_NEW = "imap_new"
    # Microsoft Graph API
    GRAPH_API = "graph_api"


class TokenEndpoint(str, Enum):
    """OAuth2 token endpoint URLs."""

    LIVE = "https://login.live.com/oauth20_token.srf"
    CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
    COMMON = "https://login.microsoftonline.com/common/oauth2/v2.0/token"


class IMAPServer(str, Enum):
    """IMAP server host names."""

    OLD = "outlook.office365.com"
    NEW = "outlook.live.com"


class ProviderStatus(str, Enum):
    """Health state of a provider."""

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    DISABLED = "disabled"
@dataclass
class EmailMessage:
    """A single mail message in provider-independent form."""

    id: str  # message identifier
    subject: str
    sender: str
    recipients: List[str] = field(default_factory=list)
    body: str = ""
    body_preview: str = ""
    received_at: Optional[datetime] = None
    received_timestamp: int = 0  # numeric receive time; 0 when unknown
    is_read: bool = False
    has_attachments: bool = False
    raw_data: Optional[bytes] = None  # raw payload, kept for debugging

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a dict.

        ``received_at`` is rendered with ``isoformat()`` (or None) and
        ``raw_data`` is left out.
        """
        received_iso = None
        if self.received_at:
            received_iso = self.received_at.isoformat()

        payload: Dict[str, Any] = {}
        payload["id"] = self.id
        payload["subject"] = self.subject
        payload["sender"] = self.sender
        payload["recipients"] = self.recipients
        payload["body"] = self.body
        payload["body_preview"] = self.body_preview
        payload["received_at"] = received_iso
        payload["received_timestamp"] = self.received_timestamp
        payload["is_read"] = self.is_read
        payload["has_attachments"] = self.has_attachments
        return payload
@dataclass
class TokenInfo:
    """An OAuth2 access token together with its expiry time."""

    access_token: str
    expires_at: float  # absolute expiry time on the time.time() clock
    token_type: str = "Bearer"
    scope: str = ""
    refresh_token: Optional[str] = None

    def is_expired(self, buffer_seconds: int = 120) -> bool:
        """Return True when the token expires within ``buffer_seconds``.

        Args:
            buffer_seconds: Safety margin subtracted from ``expires_at``.
        """
        import time

        return time.time() >= (self.expires_at - buffer_seconds)

    @classmethod
    def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo":
        """Build a TokenInfo from a token-endpoint JSON response.

        Args:
            data: Decoded response body. Reads ``access_token``,
                ``expires_in``, ``token_type``, ``scope`` and
                ``refresh_token``.
            scope: Scope to record; when empty, the response's ``scope``
                value is used.

        Returns:
            A TokenInfo whose ``expires_at`` is now plus ``expires_in``.
        """
        import time

        # `expires_in` may arrive as a numeric string or be null; coerce it
        # instead of letting `float + str` raise, and fall back to one hour.
        try:
            lifetime = float(data.get("expires_in", 3600))
        except (TypeError, ValueError):
            lifetime = 3600.0

        return cls(
            access_token=data.get("access_token", ""),
            expires_at=time.time() + lifetime,
            token_type=data.get("token_type", "Bearer"),
            scope=scope or data.get("scope", ""),
            refresh_token=data.get("refresh_token"),
        )
@dataclass
class ProviderHealth:
    """Mutable health record for one provider type.

    Note that ``status`` is only changed by the methods below; it is not
    reset on its own when ``disabled_until`` passes.
    """

    provider_type: ProviderType
    status: ProviderStatus = ProviderStatus.HEALTHY
    failure_count: int = 0  # consecutive failures since the last success
    last_success: Optional[datetime] = None  # time of the last success
    last_failure: Optional[datetime] = None  # time of the last failure
    last_error: str = ""  # message of the last failure
    disabled_until: Optional[datetime] = None  # end of the disable window

    def record_success(self):
        """Mark the provider healthy and clear failure and disable state."""
        self.disabled_until = None
        self.last_success = datetime.now()
        self.failure_count = 0
        self.status = ProviderStatus.HEALTHY

    def record_failure(self, error: str):
        """Count one more consecutive failure and remember its message."""
        self.last_error = error
        self.last_failure = datetime.now()
        self.failure_count = self.failure_count + 1

    def should_disable(self, threshold: int = 3) -> bool:
        """Return True once ``failure_count`` has reached ``threshold``."""
        return not (self.failure_count < threshold)

    def is_disabled(self) -> bool:
        """Return True while the current time is before ``disabled_until``."""
        deadline = self.disabled_until
        if not deadline:
            return False
        return datetime.now() < deadline

    def disable(self, duration_seconds: int = 300):
        """Set status to DISABLED for ``duration_seconds`` from now."""
        from datetime import timedelta

        window = timedelta(seconds=duration_seconds)
        self.status = ProviderStatus.DISABLED
        self.disabled_until = datetime.now() + window

    def enable(self):
        """Set status to HEALTHY and clear the disable window and counter."""
        self.failure_count = 0
        self.disabled_until = None
        self.status = ProviderStatus.HEALTHY

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a dict; datetimes become ISO strings or None."""

        def _iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return {
            "provider_type": self.provider_type.value,
            "status": self.status.value,
            "failure_count": self.failure_count,
            "last_success": _iso(self.last_success),
            "last_failure": _iso(self.last_failure),
            "last_error": self.last_error,
            "disabled_until": _iso(self.disabled_until),
        }

View File

@@ -0,0 +1,228 @@
"""
邮件解析和验证码提取
"""
import logging
import re
from typing import Optional, List, Dict, Any
from ...config.constants import (
OTP_CODE_SIMPLE_PATTERN,
OTP_CODE_SEMANTIC_PATTERN,
OPENAI_EMAIL_SENDERS,
OPENAI_VERIFICATION_KEYWORDS,
)
from .base import EmailMessage
logger = logging.getLogger(__name__)
class EmailParser:
    """Recognize OpenAI verification mails and pull the code out of them."""

    def __init__(self):
        # Compile both patterns once so every lookup reuses them.
        self._simple_pattern = re.compile(OTP_CODE_SIMPLE_PATTERN)
        self._semantic_pattern = re.compile(OTP_CODE_SEMANTIC_PATTERN, re.IGNORECASE)

    def is_openai_verification_email(
        self,
        email: EmailMessage,
        target_email: Optional[str] = None,
    ) -> bool:
        """Decide whether ``email`` is an OpenAI verification mail.

        Args:
            email: Message to inspect.
            target_email: Accepted for interface compatibility; not used.

        Returns:
            True when the lower-cased sender contains one of
            OPENAI_EMAIL_SENDERS and the lower-cased subject plus body
            contains one of OPENAI_VERIFICATION_KEYWORDS.
        """
        sender = email.sender.lower()

        # 1. The sender has to be OpenAI.
        sender_ok = False
        for known in OPENAI_EMAIL_SENDERS:
            if known in sender:
                sender_ok = True
                break
        if not sender_ok:
            logger.debug(f"邮件发件人非 OpenAI: {sender}")
            return False

        # 2. Subject or body has to mention a verification keyword.
        subject = email.subject.lower()
        haystack = f"{subject} {email.body.lower()}"
        keyword_ok = False
        for keyword in OPENAI_VERIFICATION_KEYWORDS:
            if keyword in haystack:
                keyword_ok = True
                break
        if not keyword_ok:
            logger.debug(f"邮件未包含验证关键词: {subject[:50]}")
            return False

        # 3. No recipient check: for alias mails the recipient header may
        #    not match, so sender plus keywords is the whole test.
        logger.debug(f"识别为 OpenAI 验证邮件: {subject[:50]}")
        return True

    def extract_verification_code(
        self,
        email: EmailMessage,
    ) -> Optional[str]:
        """Extract the verification code from ``email``.

        Tried in order: the simple pattern on the subject, the semantic
        pattern on the body, then the simple pattern on the body.

        Returns:
            The first code found, or None.
        """
        attempts = (
            (self._extract_from_subject, email.subject, "从主题提取验证码"),
            (self._extract_semantic, email.body, "从正文语义提取验证码"),
            (self._extract_simple, email.body, "从正文兜底提取验证码"),
        )
        for finder, text, label in attempts:
            code = finder(text)
            if code:
                logger.debug(f"{label}: {code}")
                return code
        return None

    def _extract_from_subject(self, subject: str) -> Optional[str]:
        """Return group 1 of the simple pattern found in ``subject``, if any."""
        found = self._simple_pattern.search(subject)
        return found.group(1) if found else None

    def _extract_semantic(self, body: str) -> Optional[str]:
        """Return group 1 of the semantic pattern found in ``body``, if any."""
        found = self._semantic_pattern.search(body)
        return found.group(1) if found else None

    def _extract_simple(self, body: str) -> Optional[str]:
        """Return group 1 of the simple pattern found in ``body``, if any."""
        found = self._simple_pattern.search(body)
        return found.group(1) if found else None

    def find_verification_code_in_emails(
        self,
        emails: List[EmailMessage],
        target_email: Optional[str] = None,
        min_timestamp: int = 0,
        used_codes: Optional[set] = None,
    ) -> Optional[str]:
        """Scan ``emails`` in order and return the first fresh code.

        Args:
            emails: Messages to scan.
            target_email: Address used in the log line.
            min_timestamp: When positive, messages with a known
                (positive) timestamp below it are skipped.
            used_codes: Codes to ignore.

        Returns:
            The first unused verification code, or None.
        """
        seen = used_codes or set()

        for message in emails:
            stamp = message.received_timestamp
            # Messages without a timestamp (0) are never filtered by age.
            if min_timestamp > 0 and stamp > 0 and stamp < min_timestamp:
                logger.debug(f"跳过旧邮件: {message.subject[:50]}")
                continue

            if not self.is_openai_verification_email(message, target_email):
                continue

            code = self.extract_verification_code(message)
            if not code:
                continue

            if code in seen:
                logger.debug(f"跳过已使用的验证码: {code}")
                continue

            logger.info(
                f"[{target_email or 'unknown'}] 找到验证码: {code}, "
                f"邮件主题: {message.subject[:30]}"
            )
            return code

        return None

    def filter_emails_by_sender(
        self,
        emails: List[EmailMessage],
        sender_patterns: List[str],
    ) -> List[EmailMessage]:
        """Keep messages whose sender contains any pattern (case-insensitive)."""
        return [
            message
            for message in emails
            if any(pattern.lower() in message.sender.lower() for pattern in sender_patterns)
        ]

    def filter_emails_by_subject(
        self,
        emails: List[EmailMessage],
        keywords: List[str],
    ) -> List[EmailMessage]:
        """Keep messages whose subject contains any keyword (case-insensitive)."""
        return [
            message
            for message in emails
            if any(keyword.lower() in message.subject.lower() for keyword in keywords)
        ]
# Module-wide parser instance, created on first use.
_parser: Optional[EmailParser] = None


def get_email_parser() -> EmailParser:
    """Return the shared EmailParser, creating it on the first call."""
    global _parser
    if _parser is not None:
        return _parser
    _parser = EmailParser()
    return _parser

View File

@@ -0,0 +1,312 @@
"""
健康检查和故障切换管理
"""
import logging
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from .base import ProviderType, ProviderHealth, ProviderStatus
from .providers.base import OutlookProvider
logger = logging.getLogger(__name__)
class HealthChecker:
    """Track per-provider health and decide which providers may be used.

    A provider is disabled for ``disable_duration`` seconds once it reaches
    ``failure_threshold`` consecutive failures. When that window has elapsed
    the provider is re-enabled, either lazily by ``is_available`` or in bulk
    by ``check_and_recover``.
    """

    def __init__(
        self,
        failure_threshold: int = 3,
        disable_duration: int = 300,
        recovery_check_interval: int = 60,
    ):
        """Create a checker with one healthy record per ProviderType.

        Args:
            failure_threshold: Consecutive failures that trigger a disable.
            disable_duration: Length of the disable window in seconds.
            recovery_check_interval: Recovery check interval in seconds
                (stored only; nothing in this class reads it).
        """
        self.failure_threshold = failure_threshold
        self.disable_duration = disable_duration
        self.recovery_check_interval = recovery_check_interval

        # ProviderType -> ProviderHealth, guarded by _lock.
        self._health_status: Dict[ProviderType, ProviderHealth] = {}
        self._lock = threading.Lock()

        for provider_type in ProviderType:
            self._health_status[provider_type] = ProviderHealth(
                provider_type=provider_type
            )

    @staticmethod
    def _disable_window_elapsed(health: ProviderHealth) -> bool:
        """Return True when a disable deadline is set and has passed."""
        deadline = health.disabled_until
        return deadline is not None and datetime.now() >= deadline

    def get_health(self, provider_type: ProviderType) -> ProviderHealth:
        """Return the health record, or a fresh one for an unknown type."""
        with self._lock:
            return self._health_status.get(provider_type, ProviderHealth(provider_type=provider_type))

    def record_success(self, provider_type: ProviderType):
        """Record a successful operation for ``provider_type``."""
        with self._lock:
            health = self._health_status.get(provider_type)
            if health:
                health.record_success()
                logger.debug(f"{provider_type.value} 记录成功")

    def record_failure(self, provider_type: ProviderType, error: str):
        """Record a failure and disable the provider at the threshold."""
        with self._lock:
            health = self._health_status.get(provider_type)
            if health:
                health.record_failure(error)

                if health.should_disable(self.failure_threshold):
                    health.disable(self.disable_duration)
                    logger.warning(
                        f"{provider_type.value} 已禁用 {self.disable_duration} 秒,"
                        f"原因: {error}"
                    )

    def is_available(self, provider_type: ProviderType) -> bool:
        """Return True when the provider may be used right now.

        A provider whose disable window has elapsed is re-enabled here.
        Previously it stayed unavailable forever because its status was
        never reset once the deadline passed.
        """
        with self._lock:
            health = self._health_status.get(provider_type)
            if health is None:
                # Unknown types have no record and count as healthy.
                return True

            # Read the deadline once, under the lock, so a concurrent
            # success cannot clear it between the check and the subtraction.
            deadline = health.disabled_until
            if deadline is not None:
                now = datetime.now()
                if now >= deadline:
                    health.enable()
                    logger.info(f"{provider_type.value} 已自动恢复")
                    return True

                remaining = (deadline - now).total_seconds()
                logger.debug(
                    f"{provider_type.value} 已被禁用,剩余 {int(remaining)}"
                )
                return False

            return health.status != ProviderStatus.DISABLED

    def get_available_providers(
        self,
        priority_order: Optional[List[ProviderType]] = None,
    ) -> List[ProviderType]:
        """Return the available providers, keeping ``priority_order``.

        Args:
            priority_order: Order to test; defaults to
                [IMAP_NEW, IMAP_OLD, GRAPH_API].
        """
        if priority_order is None:
            priority_order = [
                ProviderType.IMAP_NEW,
                ProviderType.IMAP_OLD,
                ProviderType.GRAPH_API,
            ]

        return [
            provider_type
            for provider_type in priority_order
            if self.is_available(provider_type)
        ]

    def get_next_available_provider(
        self,
        priority_order: Optional[List[ProviderType]] = None,
    ) -> Optional[ProviderType]:
        """Return the first available provider, or None when there is none."""
        available = self.get_available_providers(priority_order)
        return available[0] if available else None

    def force_disable(self, provider_type: ProviderType, duration: Optional[int] = None):
        """Disable a provider regardless of its failure count.

        Args:
            provider_type: Provider to disable.
            duration: Window in seconds; a falsy value uses
                ``disable_duration``.
        """
        with self._lock:
            health = self._health_status.get(provider_type)
            if health:
                health.disable(duration or self.disable_duration)
                logger.warning(f"{provider_type.value} 已强制禁用")

    def force_enable(self, provider_type: ProviderType):
        """Re-enable a provider immediately."""
        with self._lock:
            health = self._health_status.get(provider_type)
            if health:
                health.enable()
                logger.info(f"{provider_type.value} 已启用")

    def get_all_health_status(self) -> Dict[str, Any]:
        """Return ``{provider value: health dict}`` for every provider."""
        with self._lock:
            return {
                provider_type.value: health.to_dict()
                for provider_type, health in self._health_status.items()
            }

    def check_and_recover(self):
        """Re-enable every provider whose disable window has elapsed.

        The earlier version only looked at providers still inside their
        window and therefore never recovered any of them.
        """
        with self._lock:
            for provider_type, health in self._health_status.items():
                if self._disable_window_elapsed(health):
                    health.enable()
                    logger.info(f"{provider_type.value} 已自动恢复")

    def reset_all(self):
        """Replace every record with a fresh healthy one."""
        with self._lock:
            for provider_type in ProviderType:
                self._health_status[provider_type] = ProviderHealth(
                    provider_type=provider_type
                )
            logger.info("已重置所有提供者的健康状态")
class FailoverManager:
    """Pick the provider to use and rotate between available ones."""

    def __init__(
        self,
        health_checker: HealthChecker,
        priority_order: Optional[List[ProviderType]] = None,
    ):
        """Create a manager.

        Args:
            health_checker: Source of availability and health records.
            priority_order: Provider order; a falsy value selects
                [IMAP_NEW, IMAP_OLD, GRAPH_API].
        """
        self.health_checker = health_checker
        if priority_order:
            self.priority_order = priority_order
        else:
            self.priority_order = [
                ProviderType.IMAP_NEW,
                ProviderType.IMAP_OLD,
                ProviderType.GRAPH_API,
            ]

        # Position inside the list of currently available providers.
        self._current_index = 0
        self._lock = threading.Lock()

    def get_current_provider(self) -> Optional[ProviderType]:
        """Return the provider at the current index, or None if none is available.

        Falls back to the first available provider when the index is past
        the end of the available list.
        """
        candidates = self.health_checker.get_available_providers(self.priority_order)
        if not candidates:
            return None

        with self._lock:
            position = self._current_index
            if position < len(candidates):
                return candidates[position]
            return candidates[0]

    def switch_to_next(self) -> Optional[ProviderType]:
        """Advance to the next available provider (wrapping) and return it."""
        candidates = self.health_checker.get_available_providers(self.priority_order)
        if not candidates:
            return None

        with self._lock:
            position = (self._current_index + 1) % len(candidates)
            self._current_index = position
            chosen = candidates[position]
            logger.info(f"切换到提供者: {chosen.value}")
            return chosen

    def on_provider_success(self, provider_type: ProviderType):
        """Record a success and point the index at that provider if available."""
        self.health_checker.record_success(provider_type)

        with self._lock:
            candidates = self.health_checker.get_available_providers(self.priority_order)
            try:
                self._current_index = candidates.index(provider_type)
            except ValueError:
                # Provider is not available right now; keep the index.
                pass

    def on_provider_failure(self, provider_type: ProviderType, error: str):
        """Forward a failure to the health checker."""
        self.health_checker.record_failure(provider_type, error)

    def get_status(self) -> Dict[str, Any]:
        """Return current provider, priority order, availability and health."""
        active = self.get_current_provider()
        usable = self.health_checker.get_available_providers(self.priority_order)
        return {
            "current_provider": active.value if active else None,
            "priority_order": [p.value for p in self.priority_order],
            "available_providers": [p.value for p in usable],
            "health_status": self.health_checker.get_all_health_status(),
        }

View File

@@ -0,0 +1,29 @@
"""
Outlook 提供者模块
"""
from .base import OutlookProvider, ProviderConfig
from .imap_old import IMAPOldProvider
from .imap_new import IMAPNewProvider
from .graph_api import GraphAPIProvider
__all__ = [
'OutlookProvider',
'ProviderConfig',
'IMAPOldProvider',
'IMAPNewProvider',
'GraphAPIProvider',
]
# Registry mapping a provider type string to its implementing class.
PROVIDER_REGISTRY = {
    'imap_old': IMAPOldProvider,
    'imap_new': IMAPNewProvider,
    'graph_api': GraphAPIProvider,
}


def get_provider_class(provider_type: str):
    """Return the class registered for ``provider_type``, or None if unknown."""
    try:
        return PROVIDER_REGISTRY[provider_type]
    except KeyError:
        return None

View File

@@ -0,0 +1,180 @@
"""
Outlook 提供者抽象基类
"""
import abc
import logging
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus
from ..account import OutlookAccount
logger = logging.getLogger(__name__)
@dataclass
class ProviderConfig:
    """Tunable settings shared by all Outlook providers."""

    timeout: int = 30
    max_retries: int = 3
    proxy_url: Optional[str] = None

    # Health tracking: consecutive failures before a provider is disabled,
    # and how long (in seconds) it then stays disabled.
    health_failure_threshold: int = 3
    health_disable_duration: int = 300
class OutlookProvider(abc.ABC):
    """Abstract base class that every Outlook mail provider implements.

    Besides the abstract connection interface it keeps a per-instance
    ProviderHealth record. The provider is disabled after
    ``config.health_failure_threshold`` consecutive failures and becomes
    usable again once the disable window has elapsed.
    """

    def __init__(
        self,
        account: OutlookAccount,
        config: Optional[ProviderConfig] = None,
    ):
        """Store the account and config and start with a healthy record.

        Args:
            account: Outlook account to operate on.
            config: Provider settings; defaults to ``ProviderConfig()``.
        """
        self.account = account
        self.config = config or ProviderConfig()

        # Per-instance health record.
        self._health = ProviderHealth(provider_type=self.provider_type)

        # Connection state.
        self._connected = False
        self._last_error: Optional[str] = None

    @property
    @abc.abstractmethod
    def provider_type(self) -> ProviderType:
        """Return the ProviderType implemented by the subclass."""
        pass

    @property
    def health(self) -> ProviderHealth:
        """Return the health record of this provider instance."""
        return self._health

    def _recover_if_expired(self) -> None:
        """Re-enable the provider once its disable window has elapsed.

        Without this the status stayed DISABLED after ``disabled_until``
        passed, so the health checks below kept reporting False forever.
        """
        from datetime import datetime

        deadline = self._health.disabled_until
        if deadline is not None and datetime.now() >= deadline:
            self._health.enable()
            logger.info(f"[{self.account.email}] {self.provider_type.value} 已自动恢复")

    @property
    def is_healthy(self) -> bool:
        """Return True when status is HEALTHY and the provider is not disabled."""
        self._recover_if_expired()
        return (
            self._health.status == ProviderStatus.HEALTHY
            and not self._health.is_disabled()
        )

    @property
    def is_connected(self) -> bool:
        """Return True when the provider considers itself connected."""
        return self._connected

    @abc.abstractmethod
    def connect(self) -> bool:
        """Connect to the service.

        Returns:
            True when the connection succeeded.
        """
        pass

    @abc.abstractmethod
    def disconnect(self):
        """Close the connection."""
        pass

    @abc.abstractmethod
    def get_recent_emails(
        self,
        count: int = 20,
        only_unseen: bool = True,
    ) -> List[EmailMessage]:
        """Fetch recent messages.

        Args:
            count: Maximum number of messages to fetch.
            only_unseen: Restrict the result to unread messages.

        Returns:
            List of messages.
        """
        pass

    @abc.abstractmethod
    def test_connection(self) -> bool:
        """Check that the connection works.

        Returns:
            True when the connection is usable.
        """
        pass

    def record_success(self):
        """Record a successful operation and clear the last error."""
        self._health.record_success()
        self._last_error = None
        logger.debug(f"[{self.account.email}] {self.provider_type.value} 操作成功")

    def record_failure(self, error: str):
        """Record a failure; disable the provider once the threshold is hit."""
        self._health.record_failure(error)
        self._last_error = error

        if self._health.should_disable(self.config.health_failure_threshold):
            self._health.disable(self.config.health_disable_duration)
            logger.warning(
                f"[{self.account.email}] {self.provider_type.value} 已禁用 "
                f"{self.config.health_disable_duration} 秒,原因: {error}"
            )
        else:
            logger.warning(
                f"[{self.account.email}] {self.provider_type.value} 操作失败 "
                f"({self._health.failure_count}/{self.config.health_failure_threshold}): {error}"
            )

    def check_health(self) -> bool:
        """Return True when the provider may be used.

        Returns:
            False while the disable window is active, otherwise whether the
            status is HEALTHY or DEGRADED.
        """
        self._recover_if_expired()

        if self._health.is_disabled():
            logger.debug(
                f"[{self.account.email}] {self.provider_type.value} 已被禁用,"
                f"将在 {self._health.disabled_until} 后恢复"
            )
            return False

        return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED)

    def __enter__(self):
        """Connect on entering a ``with`` block; the result of connect() is not checked."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving a ``with`` block; exceptions propagate."""
        self.disconnect()
        return False

    def __str__(self) -> str:
        """Return ``ClassName(email)``."""
        return f"{self.__class__.__name__}({self.account.email})"

    def __repr__(self) -> str:
        return self.__str__()

View File

@@ -0,0 +1,249 @@
"""
Graph API 提供者
使用 Microsoft Graph REST API
"""
import json
import logging
from typing import List, Optional
from datetime import datetime
from curl_cffi import requests as _requests
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
logger = logging.getLogger(__name__)
class GraphAPIProvider(OutlookProvider):
    """Provider that reads the inbox through the Microsoft Graph REST API.

    Requires OAuth2 credentials (client_id + refresh_token).
    """

    # Graph API endpoint parts.
    GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
    MESSAGES_ENDPOINT = "/me/mailFolders/inbox/messages"

    @property
    def provider_type(self) -> ProviderType:
        return ProviderType.GRAPH_API

    def __init__(
        self,
        account: OutlookAccount,
        config: Optional[ProviderConfig] = None,
    ):
        """Create the provider; warns when the account has no OAuth2 data."""
        super().__init__(account, config)

        # Created lazily by connect().
        self._token_manager: Optional[TokenManager] = None

        # Graph API can only be used with OAuth2.
        if not account.has_oauth():
            logger.warning(
                f"[{self.account.email}] Graph API 提供者需要 OAuth2 配置 "
                f"(client_id + refresh_token)"
            )

    def connect(self) -> bool:
        """Verify access by obtaining an access token.

        Returns:
            True when a token could be obtained.
        """
        if not self.account.has_oauth():
            error = "Graph API 需要 OAuth2 配置"
            self.record_failure(error)
            logger.error(f"[{self.account.email}] {error}")
            return False

        if not self._token_manager:
            self._token_manager = TokenManager(
                self.account,
                ProviderType.GRAPH_API,
                self.config.proxy_url,
                self.config.timeout,
            )

        token = self._token_manager.get_access_token()
        if token:
            self._connected = True
            self.record_success()
            logger.info(f"[{self.account.email}] Graph API 连接成功")
            return True

        return False

    def disconnect(self):
        """Mark the provider as disconnected (there is no socket to close)."""
        self._connected = False

    def get_recent_emails(
        self,
        count: int = 20,
        only_unseen: bool = True,
    ) -> List[EmailMessage]:
        """Fetch the most recent inbox messages.

        Args:
            count: Maximum number of messages to request.
            only_unseen: Restrict the result to unread messages.

        Returns:
            Parsed messages, newest first; an empty list on any failure.
        """
        emails = self._fetch_messages(count, only_unseen)
        return emails if emails is not None else []

    def _fetch_messages(
        self,
        count: int,
        only_unseen: bool,
    ) -> Optional[List[EmailMessage]]:
        """Request messages; return None on failure so callers can tell it from "no mail"."""
        if not self._connected:
            if not self.connect():
                return None

        try:
            token = self._token_manager.get_access_token()
            if not token:
                self.record_failure("无法获取 Access Token")
                return None

            url = f"{self.GRAPH_API_BASE}{self.MESSAGES_ENDPOINT}"
            params = {
                "$top": count,
                "$select": "id,subject,from,toRecipients,receivedDateTime,isRead,hasAttachments,bodyPreview,body",
                "$orderby": "receivedDateTime desc",
            }
            if only_unseen:
                params["$filter"] = "isRead eq false"

            proxies = None
            if self.config.proxy_url:
                proxies = {"http": self.config.proxy_url, "https": self.config.proxy_url}

            # The query parameters are passed via `params`, not built into the URL.
            resp = _requests.get(
                url,
                params=params,
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/json",
                    "Prefer": "outlook.body-content-type='text'",
                },
                proxies=proxies,
                timeout=self.config.timeout,
                impersonate="chrome110",
            )

            if resp.status_code == 401:
                # The token was rejected: drop it so the next call refreshes.
                if self._token_manager:
                    self._token_manager.clear_cache()
                self.record_failure("HTTP 401: Token 失效")
                logger.error(f"[{self.account.email}] Graph API Token 失效")
                return None

            if resp.status_code != 200:
                error_body = resp.text[:200]
                self.record_failure(f"HTTP {resp.status_code}: {error_body}")
                logger.error(f"[{self.account.email}] Graph API 请求失败: HTTP {resp.status_code}")
                return None

            data = resp.json()

            emails = []
            for msg in data.get("value", []):
                # One malformed message must not discard the whole batch.
                try:
                    email_msg = self._parse_graph_message(msg)
                    if email_msg:
                        emails.append(email_msg)
                except Exception as e:
                    logger.warning(f"[{self.account.email}] 解析 Graph API 邮件失败: {e}")

            self.record_success()
            return emails

        except Exception as e:
            self.record_failure(str(e))
            logger.error(f"[{self.account.email}] Graph API 获取邮件失败: {e}")
            return None

    def _parse_graph_message(self, msg: dict) -> Optional[EmailMessage]:
        """Convert one Graph API message object into an EmailMessage.

        Fields that are missing or JSON null are replaced by empty values.

        Args:
            msg: Decoded Graph API message object.

        Returns:
            The parsed EmailMessage.
        """
        # Sender ("from" may be absent or null).
        from_info = msg.get("from") or {}
        sender_info = from_info.get("emailAddress") or {}
        sender = sender_info.get("address") or ""

        # Recipients.
        recipients = []
        for recipient in msg.get("toRecipients") or []:
            addr_info = recipient.get("emailAddress") or {}
            addr = addr_info.get("address", "")
            if addr:
                recipients.append(addr)

        # Receive time; left unset when it cannot be parsed (best effort).
        received_at = None
        received_timestamp = 0
        try:
            date_str = msg.get("receivedDateTime", "")
            if date_str:
                # fromisoformat() on older Pythons rejects a trailing "Z".
                received_at = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
                received_timestamp = int(received_at.timestamp())
        except Exception:
            pass

        body_info = msg.get("body") or {}
        body = body_info.get("content") or ""
        body_preview = msg.get("bodyPreview") or ""

        return EmailMessage(
            id=msg.get("id", ""),
            subject=msg.get("subject") or "",
            sender=sender,
            recipients=recipients,
            body=body,
            body_preview=body_preview,
            received_at=received_at,
            received_timestamp=received_timestamp,
            is_read=msg.get("isRead", False),
            has_attachments=msg.get("hasAttachments", False),
        )

    def test_connection(self) -> bool:
        """Check the connection by requesting a single message.

        Returns:
            True only when the request succeeded. Previously this always
            returned True because fetch errors were swallowed.
        """
        try:
            return self._fetch_messages(count=1, only_unseen=False) is not None
        except Exception as e:
            logger.warning(f"[{self.account.email}] Graph API 连接测试失败: {e}")
            return False

View File

@@ -0,0 +1,231 @@
"""
新版 IMAP 提供者
使用 outlook.live.com 服务器和 login.microsoftonline.com/consumers Token 端点
"""
import email
import imaplib
import logging
from email.header import decode_header
from email.utils import parsedate_to_datetime
from typing import List, Optional
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
from .imap_old import IMAPOldProvider
logger = logging.getLogger(__name__)
class IMAPNewProvider(OutlookProvider):
"""
新版 IMAP 提供者
使用 outlook.live.com:993 和 login.microsoftonline.com/consumers Token 端点
需要 IMAP.AccessAsUser.All scope
"""
# IMAP 服务器配置
IMAP_HOST = "outlook.live.com"
IMAP_PORT = 993
@property
def provider_type(self) -> ProviderType:
return ProviderType.IMAP_NEW
def __init__(
self,
account: OutlookAccount,
config: Optional[ProviderConfig] = None,
):
super().__init__(account, config)
# IMAP 连接
self._conn: Optional[imaplib.IMAP4_SSL] = None
# Token 管理器
self._token_manager: Optional[TokenManager] = None
# 注意:新版 IMAP 必须使用 OAuth2
if not account.has_oauth():
logger.warning(
f"[{self.account.email}] 新版 IMAP 提供者需要 OAuth2 配置 "
f"(client_id + refresh_token)"
)
def connect(self) -> bool:
"""
连接到 IMAP 服务器
Returns:
是否连接成功
"""
if self._connected and self._conn:
try:
self._conn.noop()
return True
except Exception:
self.disconnect()
# 新版 IMAP 必须使用 OAuth2无 OAuth 时静默跳过,不记录健康失败
if not self.account.has_oauth():
logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW无 OAuth")
return False
try:
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
# 创建连接
self._conn = imaplib.IMAP4_SSL(
self.IMAP_HOST,
self.IMAP_PORT,
timeout=self.config.timeout,
)
# XOAUTH2 认证
if self._authenticate_xoauth2():
self._connected = True
self.record_success()
logger.info(f"[{self.account.email}] 新版 IMAP 连接成功 (XOAUTH2)")
return True
return False
except Exception as e:
self.disconnect()
self.record_failure(str(e))
logger.error(f"[{self.account.email}] 新版 IMAP 连接失败: {e}")
return False
def _authenticate_xoauth2(self) -> bool:
"""
使用 XOAUTH2 认证
Returns:
是否认证成功
"""
if not self._token_manager:
self._token_manager = TokenManager(
self.account,
ProviderType.IMAP_NEW,
self.config.proxy_url,
self.config.timeout,
)
# 获取 Access Token
token = self._token_manager.get_access_token()
if not token:
logger.error(f"[{self.account.email}] 获取 IMAP Token 失败")
return False
try:
# 构建 XOAUTH2 认证字符串
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
return True
except Exception as e:
logger.error(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
# 清除缓存的 Token
self._token_manager.clear_cache()
return False
def disconnect(self):
"""断开 IMAP 连接"""
if self._conn:
try:
self._conn.close()
except Exception:
pass
try:
self._conn.logout()
except Exception:
pass
self._conn = None
self._connected = False
def get_recent_emails(
    self,
    count: int = 20,
    only_unseen: bool = True,
) -> List[EmailMessage]:
    """
    Fetch the most recent messages from INBOX, newest first.

    Args:
        count: Maximum number of messages to return. Values <= 0 return an
            empty list.
        only_unseen: When True search only UNSEEN messages, otherwise ALL.

    Returns:
        Parsed messages; an empty list when nothing matches, the connection
        cannot be established, or an IMAP command fails.
    """
    # ids[-0:] would select every message and a negative count an arbitrary
    # slice, so reject non-positive counts before touching the server.
    if count <= 0:
        return []
    if not self._connected:
        if not self.connect():
            return []
    try:
        # Select the inbox without changing message flags
        self._conn.select("INBOX", readonly=True)
        # Search for candidate messages
        flag = "UNSEEN" if only_unseen else "ALL"
        status, data = self._conn.search(None, flag)
        if status != "OK" or not data or not data[0]:
            return []
        # Keep the last `count` ids and reverse them so the newest comes first
        ids = data[0].split()
        recent_ids = ids[-count:][::-1]
        emails = []
        for msg_id in recent_ids:
            try:
                email_msg = self._fetch_email(msg_id)
                if email_msg:
                    emails.append(email_msg)
            except Exception as e:
                # One unparsable message must not abort the whole batch
                logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
        return emails
    except Exception as e:
        self.record_failure(str(e))
        logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
        return []
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
    """Fetch one message by id and parse it; return None when nothing usable comes back."""
    status, data = self._conn.fetch(msg_id, "(RFC822)")
    if status != "OK" or not data or not data[0]:
        return None
    # The first tuple-shaped response item carries the raw message as its second element
    raw = next(
        (item[1] for item in data if isinstance(item, tuple) and len(item) > 1),
        b"",
    )
    return self._parse_email(raw) if raw else None
@staticmethod
def _parse_email(raw: bytes) -> EmailMessage:
    """Parse a raw message by delegating to ``IMAPOldProvider._parse_email``."""
    # Reuse the legacy provider's parser
    return IMAPOldProvider._parse_email(raw)
def test_connection(self) -> bool:
    """
    Test the IMAP connection.

    Enters this provider as a context manager, selects INBOX read-only and
    runs a search for ALL messages.

    Returns:
        True when the commands complete without raising; False when any
        exception is raised (it is logged as a warning).
    """
    try:
        with self:
            self._conn.select("INBOX", readonly=True)
            self._conn.search(None, "ALL")
            return True
    except Exception as e:
        logger.warning(f"[{self.account.email}] 新版 IMAP 连接测试失败: {e}")
        return False

View File

@@ -0,0 +1,345 @@
"""
旧版 IMAP 提供者
使用 outlook.office365.com 服务器和 login.live.com Token 端点
"""
import email
import imaplib
import logging
from email.header import decode_header
from email.utils import parsedate_to_datetime
from typing import List, Optional
from ..base import ProviderType, EmailMessage
from ..account import OutlookAccount
from ..token_manager import TokenManager
from .base import OutlookProvider, ProviderConfig
logger = logging.getLogger(__name__)
class IMAPOldProvider(OutlookProvider):
"""
旧版 IMAP 提供者
使用 outlook.office365.com:993 和 login.live.com Token 端点
"""
# IMAP 服务器配置
IMAP_HOST = "outlook.office365.com"
IMAP_PORT = 993
@property
def provider_type(self) -> ProviderType:
    """Return ``ProviderType.IMAP_OLD``, the type identifier of this provider."""
    return ProviderType.IMAP_OLD
def __init__(
    self,
    account: OutlookAccount,
    config: Optional[ProviderConfig] = None,
):
    """
    Initialise the provider; no connection is opened here.

    Args:
        account: Outlook account this provider reads mail for.
        config: Optional provider configuration, forwarded to the base class.
    """
    super().__init__(account, config)
    # IMAP connection; None until connect() creates it
    self._conn: Optional[imaplib.IMAP4_SSL] = None
    # Token manager; created on the first XOAUTH2 authentication
    self._token_manager: Optional[TokenManager] = None
def connect(self) -> bool:
    """
    Connect to the IMAP server.

    XOAUTH2 is attempted first when the account has OAuth credentials; if
    that is unavailable or fails, password login is used when a password is
    configured.

    Returns:
        True when an authenticated connection is available, False otherwise.
    """
    # Reuse an existing connection when it still answers NOOP
    if self._connected and self._conn:
        try:
            self._conn.noop()
            return True
        except Exception:
            self.disconnect()

    try:
        logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
        # Open the TLS connection
        self._conn = imaplib.IMAP4_SSL(
            self.IMAP_HOST,
            self.IMAP_PORT,
            timeout=self.config.timeout,
        )

        # Preferred path: XOAUTH2
        if self.account.has_oauth():
            if not self._authenticate_xoauth2():
                logger.warning(f"[{self.account.email}] XOAUTH2 认证失败，尝试密码认证")
            else:
                self._connected = True
                self.record_success()
                logger.info(f"[{self.account.email}] IMAP 连接成功 (XOAUTH2)")
                return True

        # Fallback path: password login
        if not self.account.password:
            raise ValueError("没有可用的认证方式")
        self._conn.login(self.account.email, self.account.password)
        self._connected = True
        self.record_success()
        logger.info(f"[{self.account.email}] IMAP 连接成功 (密码认证)")
        return True
    except Exception as e:
        self.disconnect()
        self.record_failure(str(e))
        logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
        return False
def _authenticate_xoauth2(self) -> bool:
    """
    Authenticate the open connection with SASL XOAUTH2.

    Returns:
        True when the AUTHENTICATE command succeeds; False when no access
        token could be obtained or the command raised.
    """
    # Build the token manager lazily on first use
    if not self._token_manager:
        self._token_manager = TokenManager(
            self.account,
            ProviderType.IMAP_OLD,
            self.config.proxy_url,
            self.config.timeout,
        )

    # Obtain an access token
    token = self._token_manager.get_access_token()
    if not token:
        return False

    try:
        # Build the XOAUTH2 SASL string
        auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
        self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
    except Exception as e:
        logger.debug(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
        # Drop the cached token so the next attempt fetches a fresh one
        self._token_manager.clear_cache()
        return False
    else:
        return True
def disconnect(self):
    """Close the IMAP connection, ignoring errors raised while closing."""
    conn = self._conn
    if conn:
        # Best effort: CLOSE the mailbox, then LOGOUT; either may fail on a dead socket
        for command in ("close", "logout"):
            try:
                getattr(conn, command)()
            except Exception:
                pass
        self._conn = None
    self._connected = False
def get_recent_emails(
    self,
    count: int = 20,
    only_unseen: bool = True,
) -> List[EmailMessage]:
    """
    Fetch the most recent messages from INBOX, newest first.

    Args:
        count: Maximum number of messages to return. Values <= 0 return an
            empty list.
        only_unseen: When True search only UNSEEN messages, otherwise ALL.

    Returns:
        Parsed messages; an empty list when nothing matches, the connection
        cannot be established, or an IMAP command fails.
    """
    # ids[-0:] would select every message and a negative count an arbitrary
    # slice, so reject non-positive counts before touching the server.
    if count <= 0:
        return []
    if not self._connected:
        if not self.connect():
            return []
    try:
        # Select the inbox without changing message flags
        self._conn.select("INBOX", readonly=True)
        # Search for candidate messages
        flag = "UNSEEN" if only_unseen else "ALL"
        status, data = self._conn.search(None, flag)
        if status != "OK" or not data or not data[0]:
            return []
        # Keep the last `count` ids and reverse them so the newest comes first
        ids = data[0].split()
        recent_ids = ids[-count:][::-1]
        emails = []
        for msg_id in recent_ids:
            try:
                email_msg = self._fetch_email(msg_id)
                if email_msg:
                    emails.append(email_msg)
            except Exception as e:
                # One unparsable message must not abort the whole batch
                logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
        return emails
    except Exception as e:
        self.record_failure(str(e))
        logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
        return []
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
    """
    Fetch one message and parse it.

    Args:
        msg_id: IMAP message id.

    Returns:
        The parsed EmailMessage, or None when the fetch fails or carries no
        message body.
    """
    status, data = self._conn.fetch(msg_id, "(RFC822)")
    if status != "OK" or not data or not data[0]:
        return None
    # The first tuple-shaped response item carries the raw message as its second element
    raw = next(
        (item[1] for item in data if isinstance(item, tuple) and len(item) > 1),
        b"",
    )
    return self._parse_email(raw) if raw else None
@staticmethod
def _parse_email(raw: bytes) -> EmailMessage:
    """
    Parse a raw message into an EmailMessage.

    Args:
        raw: Raw message bytes.

    Returns:
        EmailMessage built from the decoded headers and extracted body.
    """
    # Strip a leading UTF-8 BOM
    if raw.startswith(b"\xef\xbb\xbf"):
        raw = raw[3:]
    msg = email.message_from_bytes(raw)

    # Decode the headers of interest
    decode = IMAPOldProvider._decode_header
    subject = decode(msg.get("Subject", ""))
    sender = decode(msg.get("From", ""))
    to = decode(msg.get("To", ""))
    delivered_to = decode(msg.get("Delivered-To", ""))
    x_original_to = decode(msg.get("X-Original-To", ""))
    date_str = decode(msg.get("Date", ""))

    # Extract the body text
    body = IMAPOldProvider._extract_body(msg)

    # Parse the Date header; an unparsable date leaves the defaults in place
    received_at = None
    received_timestamp = 0
    if date_str:
        try:
            received_at = parsedate_to_datetime(date_str)
            received_timestamp = int(received_at.timestamp())
        except Exception:
            pass

    # Recipient list: every non-empty address header
    recipients = list(filter(None, (to, delivered_to, x_original_to)))

    return EmailMessage(
        id=msg.get("Message-ID", ""),
        subject=subject,
        sender=sender,
        recipients=recipients,
        body=body,
        received_at=received_at,
        received_timestamp=received_timestamp,
        is_read=False,  # always reported as unread
        raw_data=raw[:500],  # keep at most the first 500 bytes
    )
@staticmethod
def _decode_header(header: str) -> str:
    """Decode an RFC 2047 encoded header value into stripped text."""
    if not header:
        return ""

    def _as_text(chunk, encoding):
        # Byte chunks are decoded with their declared charset; anything that
        # goes wrong falls back to UTF-8 with replacement characters.
        if not isinstance(chunk, bytes):
            return str(chunk)
        try:
            return chunk.decode(encoding or "utf-8", errors="replace")
        except Exception:
            return chunk.decode("utf-8", errors="replace")

    pieces = [_as_text(chunk, encoding) for chunk, encoding in decode_header(header)]
    return "".join(pieces).strip()
@staticmethod
def _extract_body(msg) -> str:
    """
    Extract the message body as a single line of plain text.

    All ``text/plain`` and ``text/html`` parts are decoded and concatenated.
    Markup is stripped from every ``text/html`` part (and from any part whose
    text contains ``<html``), HTML entities are unescaped and runs of
    whitespace are collapsed.

    Args:
        msg: Parsed ``email.message.Message``.

    Returns:
        The cleaned body text; an empty string when there is no text part.
    """
    import html as html_module
    import re
    texts = []
    parts = msg.walk() if msg.is_multipart() else [msg]
    for part in parts:
        content_type = part.get_content_type()
        if content_type not in ("text/plain", "text/html"):
            continue
        payload = part.get_payload(decode=True)
        if not payload:
            continue
        charset = part.get_content_charset() or "utf-8"
        try:
            text = payload.decode(charset, errors="replace")
        except LookupError:
            # Unknown charset name: fall back to UTF-8
            text = payload.decode("utf-8", errors="replace")
        # Strip tags from HTML. The declared content type is checked as well
        # as the text, because HTML mail is often a fragment without an
        # <html> element and would otherwise keep all of its markup.
        if content_type == "text/html" or "<html" in text.lower():
            text = re.sub(r"<[^>]+>", " ", text)
        texts.append(text)
    # Merge and clean the text
    combined = " ".join(texts)
    combined = html_module.unescape(combined)
    combined = re.sub(r"\s+", " ", combined).strip()
    return combined
def test_connection(self) -> bool:
    """
    Test the IMAP connection.

    Enters this provider as a context manager, selects INBOX read-only and
    runs a search for ALL messages.

    Returns:
        True when the commands complete without raising; False when any
        exception is raised (it is logged as a warning).
    """
    try:
        with self:
            self._conn.select("INBOX", readonly=True)
            self._conn.search(None, "ALL")
            return True
    except Exception as e:
        logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}")
        return False

View File

@@ -0,0 +1,485 @@
"""
Outlook 邮箱服务主类
支持多种 IMAP/API 连接方式,自动故障切换
"""
import logging
import threading
import time
from typing import Optional, Dict, Any, List
from ..base import BaseEmailService, EmailServiceError, EmailServiceStatus, EmailServiceType
from ...config.constants import EmailServiceType as ServiceType
from ...config.settings import get_settings
from .account import OutlookAccount
from .base import ProviderType, EmailMessage
from .email_parser import EmailParser, get_email_parser
from .health_checker import HealthChecker, FailoverManager
from .providers.base import OutlookProvider, ProviderConfig
from .providers.imap_old import IMAPOldProvider
from .providers.imap_new import IMAPNewProvider
from .providers.graph_api import GraphAPIProvider
logger = logging.getLogger(__name__)
# Default provider priority: Graph API first, then new IMAP, then legacy IMAP
DEFAULT_PROVIDER_PRIORITY = [
    ProviderType.GRAPH_API,
    ProviderType.IMAP_NEW,
    ProviderType.IMAP_OLD,
]
def get_email_code_settings() -> dict:
    """Return the verification-code wait settings (``timeout`` and ``poll_interval``)."""
    current = get_settings()
    code_settings = {}
    code_settings["timeout"] = current.email_code_timeout
    code_settings["poll_interval"] = current.email_code_poll_interval
    return code_settings
class OutlookService(BaseEmailService):
"""
Outlook 邮箱服务
支持多种 IMAP/API 连接方式,自动故障切换
"""
def __init__(self, config: Dict[str, Any] = None, name: str = None):
    """
    Initialise the Outlook service.

    Args:
        config: Configuration dict. Supported keys:
            - accounts: list of Outlook account configs
            - provider_priority: ordered list of provider names
            - health_failure_threshold: consecutive-failure threshold
            - health_disable_duration: disable duration in seconds
            - timeout: request timeout
            - proxy_url: proxy URL
            A single account may instead be given inline through top-level
            ``email`` and ``password`` keys.
        name: Service name.

    Raises:
        ValueError: If ``provider_priority`` contains a value that is not a
            valid ``ProviderType``.
    """
    super().__init__(ServiceType.OUTLOOK, name)
    # Defaults; user-supplied keys override them
    default_config = {
        "accounts": [],
        "provider_priority": [p.value for p in DEFAULT_PROVIDER_PRIORITY],
        "health_failure_threshold": 5,
        "health_disable_duration": 60,
        "timeout": 30,
        "proxy_url": None,
    }
    self.config = {**default_config, **(config or {})}

    # Parse the provider priority; an explicit None is treated like an empty list
    self.provider_priority = [
        ProviderType(p) for p in (self.config.get("provider_priority") or [])
    ]
    if not self.provider_priority:
        # Copy so later mutation of this instance's list cannot alter the
        # module-level default shared by every service instance.
        self.provider_priority = list(DEFAULT_PROVIDER_PRIORITY)

    # Provider configuration; fallbacks mirror default_config so the two
    # sources of defaults cannot disagree
    self.provider_config = ProviderConfig(
        timeout=self.config.get("timeout", default_config["timeout"]),
        proxy_url=self.config.get("proxy_url"),
        health_failure_threshold=self.config.get(
            "health_failure_threshold", default_config["health_failure_threshold"]
        ),
        health_disable_duration=self.config.get(
            "health_disable_duration", default_config["health_disable_duration"]
        ),
    )

    # Default client_id, used for accounts that do not carry their own
    try:
        _default_client_id = get_settings().outlook_default_client_id
    except Exception:
        _default_client_id = "24d9a0ed-8787-4584-883c-2fd79308940a"

    # Account state
    self.accounts: List[OutlookAccount] = []
    self._current_account_index = 0
    self._account_lock = threading.Lock()

    # Two config layouts are supported: one inline account, or an "accounts" list
    if "email" in self.config and "password" in self.config:
        account_configs = [self.config]
    else:
        account_configs = self.config.get("accounts") or []
    for account_config in account_configs:
        account = OutlookAccount.from_config(account_config)
        if not account.client_id and _default_client_id:
            account.client_id = _default_client_id
        if account.validate():
            self.accounts.append(account)
    if not self.accounts:
        logger.warning("未配置有效的 Outlook 账户")

    # Health checker and failover manager
    self.health_checker = HealthChecker(
        failure_threshold=self.provider_config.health_failure_threshold,
        disable_duration=self.provider_config.health_disable_duration,
    )
    self.failover_manager = FailoverManager(
        health_checker=self.health_checker,
        priority_order=self.provider_priority,
    )
    # Email parser
    self.email_parser = get_email_parser()
    # Provider instance cache: (email, provider_type) -> OutlookProvider
    self._providers: Dict[tuple, OutlookProvider] = {}
    self._provider_lock = threading.Lock()
    # Cap on concurrent provider sessions (guards against rate limiting)
    self._imap_semaphore = threading.Semaphore(5)
    # Verification codes already handed out, per email address
    self._used_codes: Dict[str, set] = {}
def _get_provider(
    self,
    account: OutlookAccount,
    provider_type: ProviderType,
) -> OutlookProvider:
    """
    Return the cached provider for this account and type, creating it on first use.

    Args:
        account: Outlook account.
        provider_type: Provider type.

    Returns:
        The provider instance.
    """
    key = (account.email.lower(), provider_type)
    with self._provider_lock:
        if key in self._providers:
            return self._providers[key]
        created = self._create_provider(account, provider_type)
        self._providers[key] = created
        return created
def _create_provider(
    self,
    account: OutlookAccount,
    provider_type: ProviderType,
) -> OutlookProvider:
    """
    Build a new provider instance of the requested type.

    Args:
        account: Outlook account.
        provider_type: Provider type.

    Returns:
        The new provider instance.

    Raises:
        ValueError: If ``provider_type`` is not one of the known types.
    """
    known = (
        (ProviderType.IMAP_OLD, IMAPOldProvider),
        (ProviderType.IMAP_NEW, IMAPNewProvider),
        (ProviderType.GRAPH_API, GraphAPIProvider),
    )
    for candidate, provider_cls in known:
        if provider_type == candidate:
            return provider_cls(account, self.provider_config)
    raise ValueError(f"未知的提供者类型: {provider_type}")
def _get_provider_priority_for_account(self, account: OutlookAccount) -> List[ProviderType]:
    """Return the provider order that suits the account, based on whether it has OAuth."""
    # Without OAuth only legacy IMAP (password login) can work, so the
    # providers that require OAuth are left out.
    return self.provider_priority if account.has_oauth() else [ProviderType.IMAP_OLD]
def _try_providers_for_emails(
    self,
    account: OutlookAccount,
    count: int = 20,
    only_unseen: bool = True,
) -> List[EmailMessage]:
    """
    Try the providers in priority order until one returns messages.

    Args:
        account: Outlook account.
        count: Number of messages to request.
        only_unseen: Whether to request only unread messages.

    Returns:
        The messages from the first provider that returns any; an empty list
        when no provider returned messages.
    """
    errors = []
    # Provider order depends on whether the account has OAuth
    priority = self._get_provider_priority_for_account(account)
    for provider_type in priority:
        # Skip providers the health checker currently reports as unavailable
        if not self.health_checker.is_available(provider_type):
            logger.debug(
                f"[{account.email}] {provider_type.value} 不可用，跳过"
            )
            continue
        try:
            provider = self._get_provider(account, provider_type)
            with self._imap_semaphore:
                with provider:
                    emails = provider.get_recent_emails(count, only_unseen)
            if emails:
                # Messages retrieved
                self.health_checker.record_success(provider_type)
                logger.debug(
                    f"[{account.email}] {provider_type.value} 获取到 {len(emails)} 封邮件"
                )
                return emails
        except Exception as e:
            error_msg = str(e)
            errors.append(f"{provider_type.value}: {error_msg}")
            self.health_checker.record_failure(provider_type, error_msg)
            logger.warning(
                f"[{account.email}] {provider_type.value} 获取邮件失败: {e}"
            )
    if errors:
        logger.error(
            f"[{account.email}] 所有提供者都失败: {'; '.join(errors)}"
        )
    else:
        # An empty mailbox is the normal outcome of most polls; it is not a
        # failure, so it must not be reported at error level.
        logger.debug(f"[{account.email}] 所有可用提供者均未返回邮件")
    return []
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Pick the next Outlook account in round-robin order.

    Args:
        config: Unused.

    Returns:
        Dict with ``email``, ``service_id`` and an ``account`` summary.

    Raises:
        EmailServiceError: If no account is configured.
    """
    with self._account_lock:
        total = len(self.accounts)
        if total:
            # The account list can shrink after the cursor was advanced
            # (see remove_account), so wrap the cursor before indexing.
            index = self._current_account_index % total
            account = self.accounts[index]
            self._current_account_index = (index + 1) % total
        else:
            account = None
    if account is None:
        self.update_status(False, EmailServiceError("没有可用的 Outlook 账户"))
        raise EmailServiceError("没有可用的 Outlook 账户")
    email_info = {
        "email": account.email,
        "service_id": account.email,
        "account": {
            "email": account.email,
            "has_oauth": account.has_oauth()
        }
    }
    logger.info(f"选择 Outlook 账户: {account.email}")
    self.update_status(True)
    return email_info
def get_verification_code(
    self,
    email: str,
    email_id: str = None,
    timeout: int = None,
    pattern: str = None,
    otp_sent_at: Optional[float] = None,
) -> Optional[str]:
    """
    Poll the Outlook mailbox until a verification code is found.

    Args:
        email: Mailbox address.
        email_id: Unused.
        timeout: Timeout in seconds; falls back to the configured value.
        pattern: Verification-code regex (unused).
        otp_sent_at: Timestamp at which the OTP was sent.

    Returns:
        The verification code, or None when the account is unknown or the
        timeout elapses.
    """
    # Find the matching account (case-insensitive)
    account = None
    for acc in self.accounts:
        if acc.email.lower() == email.lower():
            account = acc
            break
    if not account:
        self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}"))
        return None
    # Wait settings
    code_settings = get_email_code_settings()
    actual_timeout = timeout or code_settings["timeout"]
    poll_interval = code_settings["poll_interval"]
    logger.info(
        f"[{email}] 开始获取验证码，超时 {actual_timeout}s"
        f"提供者优先级: {[p.value for p in self.provider_priority]}"
    )
    # Codes already handed out for this address
    if email not in self._used_codes:
        self._used_codes[email] = set()
    used_codes = self._used_codes[email]
    # Earliest acceptable timestamp (60 s allowance for clock skew)
    min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0
    start_time = time.time()
    poll_count = 0
    while time.time() - start_time < actual_timeout:
        poll_count += 1
        # Progressive check: only unread mail for the first 3 polls
        only_unseen = poll_count <= 3
        try:
            # Try the providers in priority order
            emails = self._try_providers_for_emails(
                account,
                count=15,
                only_unseen=only_unseen,
            )
            if emails:
                logger.debug(
                    f"[{email}] 第 {poll_count} 次轮询获取到 {len(emails)} 封邮件"
                )
                # Look for a verification code in the messages
                code = self.email_parser.find_verification_code_in_emails(
                    emails,
                    target_email=email,
                    min_timestamp=min_timestamp,
                    used_codes=used_codes,
                )
                if code:
                    used_codes.add(code)
                    elapsed = int(time.time() - start_time)
                    logger.info(
                        f"[{email}] 找到验证码: {code}"
                        f"总耗时 {elapsed}s轮询 {poll_count}"
                    )
                    self.update_status(True)
                    return code
        except Exception as e:
            logger.warning(f"[{email}] 检查出错: {e}")
        # Wait for the next poll, but never sleep past the deadline: a full
        # poll_interval here could overshoot the requested timeout.
        remaining = actual_timeout - (time.time() - start_time)
        if remaining <= 0:
            break
        time.sleep(min(poll_interval, remaining))
    logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s)，共轮询 {poll_count}")
    return None
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
    """Return one summary dict per configured Outlook account."""
    def _describe(account):
        return {
            "email": account.email,
            "id": account.email,
            "has_oauth": account.has_oauth(),
            "type": "outlook"
        }

    return list(map(_describe, self.accounts))
def delete_email(self, email_id: str) -> bool:
    """Delete a mailbox. Not supported for Outlook: logs a warning and always returns False."""
    logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
    return False
def check_health(self) -> bool:
    """
    Check whether the Outlook service is usable.

    The first configured account is probed with each provider that suits it
    until one connection test succeeds.

    Returns:
        True when any provider connects, False otherwise.
    """
    if not self.accounts:
        self.update_status(False, EmailServiceError("没有配置的账户"))
        return False
    # Probe the first account
    test_account = self.accounts[0]
    # Use the account-specific provider list, as mail fetching does, so a
    # password-only account is not probed against providers that require OAuth.
    for provider_type in self._get_provider_priority_for_account(test_account):
        try:
            provider = self._get_provider(test_account, provider_type)
            if provider.test_connection():
                self.update_status(True)
                return True
        except Exception as e:
            logger.warning(
                f"Outlook 健康检查失败 ({test_account.email}, {provider_type.value}): {e}"
            )
    self.update_status(False, EmailServiceError("健康检查失败"))
    return False
def get_provider_status(self) -> Dict[str, Any]:
    """Return the provider status reported by the failover manager."""
    return self.failover_manager.get_status()
def get_account_stats(self) -> Dict[str, Any]:
    """Return account counts, per-account details and provider status."""
    total = len(self.accounts)
    oauth_count = len([acc for acc in self.accounts if acc.has_oauth()])
    stats = {
        "total_accounts": total,
        "oauth_accounts": oauth_count,
        "password_accounts": total - oauth_count,
    }
    stats["accounts"] = [acc.to_dict() for acc in self.accounts]
    stats["provider_status"] = self.get_provider_status()
    return stats
def add_account(self, account_config: Dict[str, Any]) -> bool:
    """
    Add a new Outlook account.

    Accounts without their own client_id receive the default one, as
    accounts loaded at construction time do.

    Args:
        account_config: Account configuration dict.

    Returns:
        True when the account was added; False when validation fails or an
        error occurs.
    """
    try:
        account = OutlookAccount.from_config(account_config)
        if not account.client_id:
            # Same default-client_id rule as __init__, so accounts added at
            # runtime behave like accounts from the initial configuration.
            try:
                default_client_id = get_settings().outlook_default_client_id
            except Exception:
                default_client_id = "24d9a0ed-8787-4584-883c-2fd79308940a"
            if default_client_id:
                account.client_id = default_client_id
        if not account.validate():
            return False
        # create_email reads the list and cursor under this lock
        with self._account_lock:
            self.accounts.append(account)
        logger.info(f"添加 Outlook 账户: {account.email}")
        return True
    except Exception as e:
        logger.error(f"添加 Outlook 账户失败: {e}")
        return False
def remove_account(self, email: str) -> bool:
    """
    Remove an Outlook account.

    Args:
        email: Address of the account to remove (case-insensitive).

    Returns:
        True when an account was removed, False when none matched.
    """
    target = email.lower()
    removed = False
    with self._account_lock:
        for i, acc in enumerate(self.accounts):
            if acc.email.lower() == target:
                self.accounts.pop(i)
                # Keep the round-robin cursor on the same next account and
                # inside the shrunken list; a stale cursor would raise
                # IndexError in create_email.
                if i < self._current_account_index:
                    self._current_account_index -= 1
                if self._current_account_index >= len(self.accounts):
                    self._current_account_index = 0
                removed = True
                break
    if not removed:
        return False
    # Drop cached providers for the address: they hold the removed account
    # object and would be reused if the address were added again.
    with self._provider_lock:
        stale_keys = [key for key in self._providers if key[0] == target]
        stale_providers = [self._providers.pop(key) for key in stale_keys]
    for provider in stale_providers:
        try:
            provider.disconnect()
        except Exception as e:
            logger.debug(f"[{email}] 断开提供者连接失败: {e}")
    logger.info(f"移除 Outlook 账户: {email}")
    return True
def reset_provider_health(self):
    """Reset the health state of every provider via the health checker."""
    self.health_checker.reset_all()
    logger.info("已重置所有提供者的健康状态")
def force_provider(self, provider_type: ProviderType):
    """Force-enable the given provider and force-disable every other one."""
    checker = self.health_checker
    checker.force_enable(provider_type)
    # Disable all the other providers
    others = [pt for pt in ProviderType if pt != provider_type]
    for other in others:
        checker.force_disable(other, 60)
    logger.info(f"已强制使用提供者: {provider_type.value}")

View File

@@ -0,0 +1,239 @@
"""
Token 管理器
支持多个 Microsoft Token 端点,自动选择合适的端点
"""
import json
import logging
import threading
import time
from typing import Dict, Optional, Any
from curl_cffi import requests as _requests
from .base import ProviderType, TokenEndpoint, TokenInfo
from .account import OutlookAccount
logger = logging.getLogger(__name__)
# OAuth scope requested for each provider
PROVIDER_SCOPES = {
    ProviderType.IMAP_OLD: "",  # legacy IMAP does not need a specific scope
    ProviderType.IMAP_NEW: "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
    ProviderType.GRAPH_API: "https://graph.microsoft.com/.default",
}
# Token endpoint used by each provider
PROVIDER_TOKEN_URLS = {
    ProviderType.IMAP_OLD: TokenEndpoint.LIVE.value,
    ProviderType.IMAP_NEW: TokenEndpoint.CONSUMERS.value,
    ProviderType.GRAPH_API: TokenEndpoint.COMMON.value,
}
class TokenManager:
"""
Token 管理器
支持多端点 Token 获取和缓存
"""
# Token 缓存: key = (email, provider_type) -> TokenInfo
_token_cache: Dict[tuple, TokenInfo] = {}
_cache_lock = threading.Lock()
# 默认超时时间
DEFAULT_TIMEOUT = 30
# Token 刷新提前时间(秒)
REFRESH_BUFFER = 120
def __init__(
    self,
    account: OutlookAccount,
    provider_type: ProviderType,
    proxy_url: Optional[str] = None,
    timeout: int = DEFAULT_TIMEOUT,
):
    """
    Initialise the token manager.

    Args:
        account: Outlook account.
        provider_type: Provider type; selects the token endpoint and scope.
        proxy_url: Proxy URL (optional).
        timeout: Request timeout.
    """
    self.account = account
    self.provider_type = provider_type
    self.proxy_url = proxy_url
    self.timeout = timeout
    # Resolve endpoint and scope; unknown provider types fall back to the
    # LIVE endpoint and an empty scope
    self.token_url = PROVIDER_TOKEN_URLS.get(provider_type, TokenEndpoint.LIVE.value)
    self.scope = PROVIDER_SCOPES.get(provider_type, "")
def get_cached_token(self) -> Optional[TokenInfo]:
    """Return the cached token for this account and provider, or None when missing or about to expire."""
    key = (self.account.email.lower(), self.provider_type)
    with self._cache_lock:
        entry = self._token_cache.get(key)
        # Tokens inside the refresh buffer are treated as already expired
        if not entry or entry.is_expired(self.REFRESH_BUFFER):
            return None
        return entry
def set_cached_token(self, token: TokenInfo):
    """Store ``token`` in the cache under this manager's (lower-cased email, provider type) key."""
    cache_key = (self.account.email.lower(), self.provider_type)
    with self._cache_lock:
        self._token_cache[cache_key] = token
def clear_cache(self):
    """Remove this manager's cached token, if there is one."""
    cache_key = (self.account.email.lower(), self.provider_type)
    with self._cache_lock:
        self._token_cache.pop(cache_key, None)
def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
    """
    Return an access token, served from the cache when possible.

    Args:
        force_refresh: Skip the cache and refresh unconditionally.

    Returns:
        The access token string, or None on failure.
    """
    # Cache lookup, skipped entirely on a forced refresh
    cached = None if force_refresh else self.get_cached_token()
    if cached:
        logger.debug(f"[{self.account.email}] 使用缓存的 Token ({self.provider_type.value})")
        return cached.access_token

    # Refresh and cache the new token
    try:
        token = self._refresh_token()
        if token:
            self.set_cached_token(token)
            return token.access_token
    except Exception as e:
        logger.error(f"[{self.account.email}] 获取 Token 失败 ({self.provider_type.value}): {e}")
    return None
def _refresh_token(self) -> Optional[TokenInfo]:
    """
    Exchange the account's refresh token for a new access token.

    Returns:
        A TokenInfo on success; None when the endpoint answers with a
        non-200 status or the request/parsing fails.

    Raises:
        ValueError: If the account has no client_id or no refresh_token.
    """
    if not self.account.client_id or not self.account.refresh_token:
        raise ValueError("缺少 client_id 或 refresh_token")
    logger.debug(f"[{self.account.email}] 正在刷新 Token ({self.provider_type.value})...")
    logger.debug(f"[{self.account.email}] Token URL: {self.token_url}")

    # Request body; the scope is only sent when this provider defines one
    data = {
        "client_id": self.account.client_id,
        "refresh_token": self.account.refresh_token,
        "grant_type": "refresh_token",
    }
    if self.scope:
        data["scope"] = self.scope
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    }
    proxies = {"http": self.proxy_url, "https": self.proxy_url} if self.proxy_url else None

    try:
        resp = _requests.post(
            self.token_url,
            data=data,
            headers=headers,
            proxies=proxies,
            timeout=self.timeout,
            impersonate="chrome110",
        )
        if resp.status_code != 200:
            error_body = resp.text
            logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}")
            logger.debug(f"[{self.account.email}] 错误响应: {error_body[:500]}")
            # Classify the two failure causes recognised in the response text
            lowered = error_body.lower()
            if "service abuse" in lowered:
                logger.warning(f"[{self.account.email}] 账号可能被封禁")
            elif "invalid_grant" in lowered:
                logger.warning(f"[{self.account.email}] Refresh Token 已失效")
            return None

        # Parse the response
        response_data = resp.json()
        token = TokenInfo.from_response(response_data, self.scope)
        logger.info(
            f"[{self.account.email}] Token 刷新成功 ({self.provider_type.value}), "
            f"有效期 {int(token.expires_at - time.time())}"
        )
        return token
    except json.JSONDecodeError as e:
        logger.error(f"[{self.account.email}] JSON 解析错误: {e}")
        return None
    except Exception as e:
        logger.error(f"[{self.account.email}] 未知错误: {e}")
        return None
@classmethod
def clear_all_cache(cls):
    """Remove every entry from the class-level token cache."""
    with cls._cache_lock:
        cls._token_cache.clear()
        logger.info("已清除所有 Token 缓存")
@classmethod
def get_cache_stats(cls) -> Dict[str, Any]:
    """Return the cache size and the (email, provider) pair of every cached entry."""
    with cls._cache_lock:
        entries = []
        for key in cls._token_cache.keys():
            entries.append({
                "email": key[0],
                "provider": key[1].value,
            })
        return {
            "cache_size": len(cls._token_cache),
            "entries": entries,
        }
def create_token_manager(
    account: OutlookAccount,
    provider_type: ProviderType,
    proxy_url: Optional[str] = None,
    timeout: int = TokenManager.DEFAULT_TIMEOUT,
) -> TokenManager:
    """
    Factory function that builds a TokenManager.

    Args:
        account: Outlook account.
        provider_type: Provider type.
        proxy_url: Proxy URL.
        timeout: Request timeout.

    Returns:
        A new TokenManager instance.
    """
    return TokenManager(account, provider_type, proxy_url, timeout)

View File

@@ -734,3 +734,37 @@ async def test_cpa_connection(request: CPATestRequest):
"success": success,
"message": message
}
# ============== Outlook 设置 ==============
class OutlookSettings(BaseModel):
    """Request body for updating the Outlook settings."""
    # Default OAuth client_id; None means "leave unchanged"
    default_client_id: Optional[str] = None
@router.get("/outlook")
async def get_outlook_settings():
    """Return the current Outlook settings read from the application settings."""
    settings = get_settings()
    return {
        "default_client_id": settings.outlook_default_client_id,
        "provider_priority": settings.outlook_provider_priority,
        "health_failure_threshold": settings.outlook_health_failure_threshold,
        "health_disable_duration": settings.outlook_health_disable_duration,
    }
@router.post("/outlook")
async def update_outlook_settings(request: OutlookSettings):
    """
    Update the Outlook settings.

    Only ``default_client_id`` can be changed, and only when the request
    supplies a non-None value. The success response is returned whether or
    not anything was updated.
    """
    update_dict = {}
    if request.default_client_id is not None:
        update_dict["outlook_default_client_id"] = request.default_client_id
    if update_dict:
        update_settings(**update_dict)
    return {"success": True, "message": "Outlook 设置已更新"}

170
src/web/routes/websocket.py Normal file
View File

@@ -0,0 +1,170 @@
"""
WebSocket 路由
提供实时日志推送和任务状态更新
"""
import asyncio
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ..task_manager import task_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@router.websocket("/ws/task/{task_uuid}")
async def task_websocket(websocket: WebSocket, task_uuid: str):
    """
    Task log WebSocket.

    Message format:
    - server sends: {"type": "log", "task_uuid": "xxx", "message": "...", "timestamp": "..."}
    - server sends: {"type": "status", "task_uuid": "xxx", "status": "running|completed|failed|cancelled", ...}
    - client sends: {"type": "ping"} - heartbeat
    - client sends: {"type": "cancel"} - cancel the task

    Client messages that are not valid JSON objects are ignored.
    """
    await websocket.accept()
    # Register the connection with the task manager
    task_manager.register_websocket(task_uuid, websocket)
    logger.info(f"WebSocket 连接已建立: {task_uuid}")
    try:
        # Send the current status
        status = task_manager.get_status(task_uuid)
        if status:
            await websocket.send_json({
                "type": "status",
                "task_uuid": task_uuid,
                **status
            })
        # Send the logs the task manager reports as unsent for this socket
        history_logs = task_manager.get_unsent_logs(task_uuid, websocket)
        for log in history_logs:
            await websocket.send_json({
                "type": "log",
                "task_uuid": task_uuid,
                "message": log
            })
        # Keep the connection open and wait for client messages
        while True:
            try:
                # A receive timeout does not close the connection; it
                # triggers a heartbeat probe instead
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=30.0  # 30 second timeout
                )
                # A JSON array/string/number has no .get(); skip it rather
                # than letting AttributeError tear the connection down
                if not isinstance(data, dict):
                    continue
                # Heartbeat
                if data.get("type") == "ping":
                    await websocket.send_json({"type": "pong"})
                # Cancel request
                elif data.get("type") == "cancel":
                    task_manager.cancel_task(task_uuid)
                    await websocket.send_json({
                        "type": "status",
                        "task_uuid": task_uuid,
                        "status": "cancelling",
                        "message": "取消请求已提交"
                    })
            except asyncio.TimeoutError:
                # Timed out: send a heartbeat probe
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    # Sending failed, the connection is probably gone
                    logger.info(f"WebSocket 心跳检测失败: {task_uuid}")
                    break
            except ValueError as e:
                # Malformed JSON from the client (json.JSONDecodeError is a
                # ValueError): ignore the frame and keep the connection
                logger.warning(f"WebSocket 收到无效消息: {e}")
    except WebSocketDisconnect:
        logger.info(f"WebSocket 断开: {task_uuid}")
    except Exception as e:
        logger.error(f"WebSocket 错误: {e}")
    finally:
        task_manager.unregister_websocket(task_uuid, websocket)
@router.websocket("/ws/batch/{batch_id}")
async def batch_websocket(websocket: WebSocket, batch_id: str):
    """
    Batch task WebSocket.

    Provides real-time status updates for batch registration tasks.

    Message format:
    - server sends: {"type": "log", "batch_id": "xxx", "message": "...", "timestamp": "..."}
    - server sends: {"type": "status", "batch_id": "xxx", "status": "running|completed|cancelled", ...}
    - client sends: {"type": "ping"} - heartbeat
    - client sends: {"type": "cancel"} - cancel the batch

    Client messages that are not valid JSON objects are ignored.
    """
    await websocket.accept()
    # Register the connection with the task manager
    task_manager.register_batch_websocket(batch_id, websocket)
    logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}")
    try:
        # Send the current status
        status = task_manager.get_batch_status(batch_id)
        if status:
            await websocket.send_json({
                "type": "status",
                "batch_id": batch_id,
                **status
            })
        # Send the logs the task manager reports as unsent for this socket
        history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
        for log in history_logs:
            await websocket.send_json({
                "type": "log",
                "batch_id": batch_id,
                "message": log
            })
        # Keep the connection open and wait for client messages
        while True:
            try:
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=30.0
                )
                # A JSON array/string/number has no .get(); skip it rather
                # than letting AttributeError tear the connection down
                if not isinstance(data, dict):
                    continue
                # Heartbeat
                if data.get("type") == "ping":
                    await websocket.send_json({"type": "pong"})
                # Cancel request
                elif data.get("type") == "cancel":
                    task_manager.cancel_batch(batch_id)
                    await websocket.send_json({
                        "type": "status",
                        "batch_id": batch_id,
                        "status": "cancelling",
                        "message": "取消请求已提交"
                    })
            except asyncio.TimeoutError:
                # Timed out: send a heartbeat probe
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    logger.info(f"批量任务 WebSocket 心跳检测失败: {batch_id}")
                    break
            except ValueError as e:
                # Malformed JSON from the client (json.JSONDecodeError is a
                # ValueError): ignore the frame and keep the connection
                logger.warning(f"批量任务 WebSocket 收到无效消息: {e}")
    except WebSocketDisconnect:
        logger.info(f"批量任务 WebSocket 断开: {batch_id}")
    except Exception as e:
        logger.error(f"批量任务 WebSocket 错误: {e}")
    finally:
        task_manager.unregister_batch_websocket(batch_id, websocket)

361
src/web/task_manager.py Normal file
View File

@@ -0,0 +1,361 @@
"""
任务管理器
负责管理后台任务、日志队列和 WebSocket 推送
"""
import asyncio
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, List, Callable, Any
from collections import defaultdict
from datetime import datetime
logger = logging.getLogger(__name__)
# Global thread pool
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="reg_worker")
# Task log queues (task_uuid -> list of logs)
_log_queues: Dict[str, List[str]] = defaultdict(list)
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
# WebSocket connection registry (task_uuid -> list of websockets)
_ws_connections: Dict[str, List] = defaultdict(list)
_ws_lock = threading.Lock()
# Per-socket count of logs already sent (task_uuid -> {id(websocket): sent_count})
_ws_sent_index: Dict[str, Dict] = defaultdict(dict)
# Task status
_task_status: Dict[str, dict] = {}
# Task cancellation flags
_task_cancelled: Dict[str, bool] = {}
# Batch task status (batch_id -> dict)
_batch_status: Dict[str, dict] = {}
_batch_logs: Dict[str, List[str]] = defaultdict(list)
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
class TaskManager:
"""任务管理器"""
def __init__(self):
    """Bind the shared module-level thread pool; the event loop is set later via set_loop()."""
    self.executor = _executor
    self._loop: Optional[asyncio.AbstractEventLoop] = None
def set_loop(self, loop: asyncio.AbstractEventLoop):
    """Set the event loop (called when FastAPI starts)."""
    self._loop = loop
def get_loop(self) -> Optional[asyncio.AbstractEventLoop]:
    """Return the stored event loop, or None when none has been set."""
    return self._loop
def is_cancelled(self, task_uuid: str) -> bool:
    """Return True when the task has been marked as cancelled, False otherwise."""
    return _task_cancelled.get(task_uuid, False)
def cancel_task(self, task_uuid: str):
    """Mark the task as cancelled by setting its cancellation flag."""
    _task_cancelled[task_uuid] = True
    logger.info(f"任务 {task_uuid} 已标记为取消")
def add_log(self, task_uuid: str, log_message: str):
    """
    Append a log line for the task and push it to connected WebSockets.

    The broadcast coroutine is submitted to the stored event loop when that
    loop is running; the line is then appended to the task's log queue under
    the per-task lock.
    """
    # Schedule the WebSocket broadcast first, then append to the queue.
    # NOTE(review): run_coroutine_threadsafe only schedules the coroutine, so
    # the broadcast is not guaranteed to have run before the append below -
    # confirm the intended ordering against get_unsent_logs.
    if self._loop and self._loop.is_running():
        try:
            asyncio.run_coroutine_threadsafe(
                self._broadcast_log(task_uuid, log_message),
                self._loop
            )
        except Exception as e:
            logger.warning(f"推送日志到 WebSocket 失败: {e}")
    # Append to the queue after scheduling the broadcast
    with _log_locks[task_uuid]:
        _log_queues[task_uuid].append(log_message)
async def _broadcast_log(self, task_uuid: str, log_message: str):
    """Send one log line to every WebSocket registered for the task."""
    # Snapshot the connection list so sends happen outside the lock
    with _ws_lock:
        connections = _ws_connections.get(task_uuid, []).copy()
    for ws in connections:
        try:
            await ws.send_json({
                "type": "log",
                "task_uuid": task_uuid,
                "message": log_message,
                "timestamp": datetime.utcnow().isoformat()
            })
            # After a successful send, advance this socket's sent counter
            # (only when the socket still has an entry in the index)
            with _ws_lock:
                ws_id = id(ws)
                if task_uuid in _ws_sent_index and ws_id in _ws_sent_index[task_uuid]:
                    _ws_sent_index[task_uuid][ws_id] += 1
        except Exception as e:
            logger.warning(f"WebSocket 发送失败: {e}")
async def broadcast_status(self, task_uuid: str, status: str, **kwargs):
    """Send a status update to every WebSocket registered for the task."""
    # Snapshot the connection list so sends happen outside the lock
    with _ws_lock:
        targets = list(_ws_connections.get(task_uuid, []))
    message = {
        "type": "status",
        "task_uuid": task_uuid,
        "status": status,
        "timestamp": datetime.utcnow().isoformat(),
        **kwargs
    }
    for target in targets:
        try:
            await target.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket 发送状态失败: {e}")
def register_websocket(self, task_uuid: str, websocket):
    """Attach a WebSocket to a task's broadcast list.

    A connection that is already registered for the task is not added
    again; a warning is logged instead. For a new connection, its sent
    counter starts at the current length of the task's log queue, which
    the original comment says avoids duplicates when history is replayed.
    """
    with _ws_lock:
        conns = _ws_connections.setdefault(task_uuid, [])
        if websocket in conns:
            logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
            return
        conns.append(websocket)
        # Snapshot the queue length while holding the task's log lock.
        with _log_locks[task_uuid]:
            queued = len(_log_queues.get(task_uuid, []))
            _ws_sent_index[task_uuid][id(websocket)] = queued
        logger.info(f"WebSocket 连接已注册: {task_uuid}")
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
    """Return the task log lines not yet counted as sent to ``websocket``.

    The connection's sent counter (0 when absent) is used as the start
    offset into the task's log queue, and the counter is then moved to the
    end of the queue.

    Returns:
        The slice of the log queue from the previous counter onwards.
    """
    conn_key = id(websocket)
    # Lock order (_ws_lock, then the per-task log lock) mirrors
    # register_websocket.
    with _ws_lock:
        already_sent = _ws_sent_index.get(task_uuid, {}).get(conn_key, 0)
        with _log_locks[task_uuid]:
            history = _log_queues.get(task_uuid, [])
            pending = history[already_sent:]
            # Everything currently queued now counts as delivered.
            _ws_sent_index[task_uuid][conn_key] = len(history)
    return pending
def unregister_websocket(self, task_uuid: str, websocket):
    """Detach a WebSocket from a task and drop its sent counter.

    Unknown tasks or connections are ignored; the deregistration is
    logged either way.
    """
    with _ws_lock:
        conns = _ws_connections.get(task_uuid, [])
        if websocket in conns:
            conns.remove(websocket)
        # Forget how many logs this connection had received.
        _ws_sent_index.get(task_uuid, {}).pop(id(websocket), None)
        logger.info(f"WebSocket 连接已注销: {task_uuid}")
def get_logs(self, task_uuid: str) -> List[str]:
    """Return a copy of all log lines queued for ``task_uuid``.

    The copy is taken inside the ``_log_locks[task_uuid]`` context; an
    unknown task yields an empty list.
    """
    with _log_locks[task_uuid]:
        snapshot = list(_log_queues.get(task_uuid, []))
    return snapshot
def update_status(self, task_uuid: str, status: str, **kwargs):
    """Set a task's ``status`` and merge extra fields into its record.

    A record is created for the task if none exists yet.
    """
    record = _task_status.setdefault(task_uuid, {})
    record["status"] = status
    record.update(kwargs)
def get_status(self, task_uuid: str) -> Optional[dict]:
    """Return the status record stored for ``task_uuid``, or ``None``."""
    record = _task_status.get(task_uuid)
    return record
def cleanup_task(self, task_uuid: str):
    """Drop the cancel flag for ``task_uuid``, if one exists.

    Only the cancel flag is removed. Per the original note, the log queue
    is deliberately kept for a while so it can still be queried later.
    """
    _task_cancelled.pop(task_uuid, None)
# ============== 批量任务管理 ==============
def init_batch(self, batch_id: str, total: int):
    """Create (or reset) the status record of a batch job.

    The record starts as ``running`` with all counters at zero and
    ``finished`` set to ``False``.

    Args:
        batch_id: Identifier of the batch job.
        total: Number of items the batch is expected to process.
    """
    initial_state = dict(
        status="running",
        total=total,
        completed=0,
        success=0,
        failed=0,
        skipped=0,
        current_index=0,
        finished=False,
    )
    _batch_status[batch_id] = initial_state
    logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
def add_batch_log(self, batch_id: str, log_message: str):
    """Record a batch log line and push it to connected WebSockets.

    The broadcast is handed to the stored event loop through
    ``asyncio.run_coroutine_threadsafe``, and afterwards the message is
    appended to ``_batch_logs[batch_id]`` inside the
    ``_batch_locks[batch_id]`` context. Scheduling the broadcast is
    best-effort: a failure is logged and the message is still stored.

    Args:
        batch_id: Identifier of the batch job the log line belongs to.
        log_message: Text of the log line.
    """
    # Read the loop once so a concurrent set_loop() cannot swap it between
    # the is_running() check and the scheduling call.
    loop = self._loop
    if loop and loop.is_running():
        coro = self._broadcast_batch_log(batch_id, log_message)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            # The coroutine will never be run; close it so it is not
            # reported as "never awaited".
            coro.close()
            logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
    # Append after the broadcast has been scheduled (original ordering).
    with _batch_locks[batch_id]:
        _batch_logs[batch_id].append(log_message)
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
    """Send one batch log line to every WebSocket watching the batch.

    Batch connections are stored under the key ``batch_<batch_id>``. The
    connection list is copied under ``_ws_lock`` and the sends run outside
    the lock. After each successful send, the connection's counter in
    ``_ws_sent_index`` is incremented if it still exists. A failed send is
    logged and the remaining connections are still attempted.
    """
    channel = f"batch_{batch_id}"
    with _ws_lock:
        targets = list(_ws_connections.get(channel, []))
    for conn in targets:
        try:
            payload = {
                "type": "log",
                "batch_id": batch_id,
                "message": log_message,
                "timestamp": datetime.utcnow().isoformat(),
            }
            await conn.send_json(payload)
            # Only a delivered message advances the sent counter.
            with _ws_lock:
                conn_key = id(conn)
                counters = _ws_sent_index.get(channel)
                if counters is not None and conn_key in counters:
                    counters[conn_key] += 1
        except Exception as exc:
            logger.warning(f"WebSocket 发送批量日志失败: {exc}")
def update_batch_status(self, batch_id: str, **kwargs):
    """Merge fields into a batch's status record and broadcast the result.

    An unknown ``batch_id`` is logged as a warning and nothing else
    happens. Otherwise the keyword arguments are merged into the record
    and, when an event loop is set and running, a status broadcast is
    scheduled on it. Scheduling is best-effort: a failure is only logged.

    Args:
        batch_id: Identifier of the batch job.
        **kwargs: Status fields to set or overwrite.
    """
    if batch_id not in _batch_status:
        logger.warning(f"批量任务 {batch_id} 不存在")
        return
    _batch_status[batch_id].update(kwargs)
    # Read the loop once so a concurrent set_loop() cannot swap it between
    # the is_running() check and the scheduling call.
    loop = self._loop
    if loop and loop.is_running():
        coro = self._broadcast_batch_status(batch_id)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            # The coroutine will never be run; close it so it is not
            # reported as "never awaited".
            coro.close()
            logger.warning(f"广播批量状态失败: {e}")
async def _broadcast_batch_status(self, batch_id: str):
    """Send the batch's current status record to every watching WebSocket.

    Each payload carries ``type``, ``batch_id`` and a fresh ``timestamp``;
    the fields of the status record are merged in last, so they win on key
    clashes. A failed send is logged and the remaining connections are
    still attempted.
    """
    with _ws_lock:
        targets = list(_ws_connections.get(f"batch_{batch_id}", []))
    current = _batch_status.get(batch_id, {})
    for conn in targets:
        try:
            payload = {
                "type": "status",
                "batch_id": batch_id,
                "timestamp": datetime.utcnow().isoformat(),
            }
            payload.update(current)
            await conn.send_json(payload)
        except Exception as exc:
            logger.warning(f"WebSocket 发送批量状态失败: {exc}")
def get_batch_status(self, batch_id: str) -> Optional[dict]:
    """Return the status record stored for ``batch_id``, or ``None``."""
    record = _batch_status.get(batch_id)
    return record
def get_batch_logs(self, batch_id: str) -> List[str]:
    """Return a copy of all log lines stored for ``batch_id``.

    The copy is taken inside the ``_batch_locks[batch_id]`` context; an
    unknown batch yields an empty list.
    """
    with _batch_locks[batch_id]:
        snapshot = list(_batch_logs.get(batch_id, []))
    return snapshot
def is_batch_cancelled(self, batch_id: str) -> bool:
    """Report whether the batch's status record has ``cancelled`` set.

    Unknown batches and records without the flag yield ``False``.
    """
    return _batch_status.get(batch_id, {}).get("cancelled", False)
def cancel_batch(self, batch_id: str):
    """Mark a known batch as cancelled and move it to ``cancelling``.

    Does nothing (and logs nothing) when ``batch_id`` has no status record.
    """
    if batch_id not in _batch_status:
        return
    record = _batch_status[batch_id]
    record["cancelled"] = True
    record["status"] = "cancelling"
    logger.info(f"批量任务 {batch_id} 已标记为取消")
def register_batch_websocket(self, batch_id: str, websocket):
    """Attach a WebSocket to a batch job's broadcast list.

    Batch connections are stored under the key ``batch_<batch_id>``. A
    connection that is already registered is not added again; a warning is
    logged instead. For a new connection, its sent counter starts at the
    current number of stored batch log lines.
    """
    channel = f"batch_{batch_id}"
    with _ws_lock:
        conns = _ws_connections.setdefault(channel, [])
        if websocket in conns:
            logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
            return
        conns.append(websocket)
        # Snapshot the log count while holding the batch's log lock.
        with _batch_locks[batch_id]:
            stored = len(_batch_logs.get(batch_id, []))
            _ws_sent_index[channel][id(websocket)] = stored
        logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
    """Return the batch log lines not yet counted as sent to ``websocket``.

    The connection's sent counter (0 when absent) is used as the start
    offset into the batch's log list, and the counter is then moved to the
    end of the list.

    Returns:
        The slice of the batch log list from the previous counter onwards.
    """
    channel = f"batch_{batch_id}"
    conn_key = id(websocket)
    # Lock order (_ws_lock, then the per-batch log lock) mirrors
    # register_batch_websocket.
    with _ws_lock:
        already_sent = _ws_sent_index.get(channel, {}).get(conn_key, 0)
        with _batch_locks[batch_id]:
            history = _batch_logs.get(batch_id, [])
            pending = history[already_sent:]
            # Everything currently stored now counts as delivered.
            _ws_sent_index[channel][conn_key] = len(history)
    return pending
def unregister_batch_websocket(self, batch_id: str, websocket):
    """Detach a WebSocket from a batch job and drop its sent counter.

    Unknown batches or connections are ignored; the deregistration is
    logged either way.
    """
    channel = f"batch_{batch_id}"
    with _ws_lock:
        conns = _ws_connections.get(channel, [])
        if websocket in conns:
            conns.remove(websocket)
        # Forget how many logs this connection had received.
        _ws_sent_index.get(channel, {}).pop(id(websocket), None)
        logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")
def create_log_callback(self, task_uuid: str) -> Callable[[str], None]:
    """Build a one-argument logger bound to ``task_uuid``.

    Returns:
        A function that forwards its message to ``self.add_log`` for the
        bound task and returns nothing.
    """
    def _forward(msg: str):
        self.add_log(task_uuid, msg)

    return _forward
def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]:
    """Build a zero-argument cancel check bound to ``task_uuid``.

    Returns:
        A function that returns ``self.is_cancelled(task_uuid)`` each time
        it is called, so later cancellations are observed.
    """
    def _check() -> bool:
        cancelled = self.is_cancelled(task_uuid)
        return cancelled

    return _check
# Global instance: a single TaskManager created at import time.
task_manager = TaskManager()