diff --git a/src/config/constants.py b/src/config/constants.py index d3553d0..79ead1c 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -332,4 +332,38 @@ TIME_CONSTANTS = { "HOUR": 3600, "DAY": 86400, "WEEK": 604800, -} \ No newline at end of file +} + + +# ============================================================================ +# Microsoft/Outlook 相关常量 +# ============================================================================ + +# Microsoft OAuth2 Token 端点 +MICROSOFT_TOKEN_ENDPOINTS = { + # 旧版 IMAP 使用的端点 + "LIVE": "https://login.live.com/oauth20_token.srf", + # 新版 IMAP 使用的端点(需要特定 scope) + "CONSUMERS": "https://login.microsoftonline.com/consumers/oauth2/v2.0/token", + # Graph API 使用的端点 + "COMMON": "https://login.microsoftonline.com/common/oauth2/v2.0/token", +} + +# IMAP 服务器配置 +OUTLOOK_IMAP_SERVERS = { + "OLD": "outlook.office365.com", # 旧版 IMAP + "NEW": "outlook.live.com", # 新版 IMAP +} + +# Microsoft OAuth2 Scopes +MICROSOFT_SCOPES = { + # 旧版 IMAP 不需要特定 scope + "IMAP_OLD": "", + # 新版 IMAP 需要的 scope + "IMAP_NEW": "https://outlook.office.com/IMAP.AccessAsUser.All offline_access", + # Graph API 需要的 scope + "GRAPH_API": "https://graph.microsoft.com/.default", +} + +# Outlook 提供者默认优先级 +OUTLOOK_PROVIDER_PRIORITY = ["imap_new", "imap_old", "graph_api"] \ No newline at end of file diff --git a/src/config/settings.py b/src/config/settings.py index 7055c0a..45801ad 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -4,7 +4,7 @@ """ import os -from typing import Optional, Dict, Any, Type +from typing import Optional, Dict, Any, Type, List from enum import Enum from pydantic import BaseModel, field_validator from pydantic.types import SecretStr @@ -297,6 +297,32 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = { category=SettingCategory.EMAIL, description="验证码轮询间隔(秒)" ), + + # Outlook 配置 + "outlook_provider_priority": SettingDefinition( + db_key="outlook.provider_priority", + default_value=["graph_api", "imap_new", 
"imap_old"], + category=SettingCategory.EMAIL, + description="Outlook 提供者优先级" + ), + "outlook_health_failure_threshold": SettingDefinition( + db_key="outlook.health_failure_threshold", + default_value=5, + category=SettingCategory.EMAIL, + description="Outlook 提供者连续失败次数阈值" + ), + "outlook_health_disable_duration": SettingDefinition( + db_key="outlook.health_disable_duration", + default_value=60, + category=SettingCategory.EMAIL, + description="Outlook 提供者禁用时长(秒)" + ), + "outlook_default_client_id": SettingDefinition( + db_key="outlook.default_client_id", + default_value="24d9a0ed-8787-4584-883c-2fd79308940a", + category=SettingCategory.EMAIL, + description="Outlook OAuth 默认 Client ID" + ), } # 属性名到数据库键名的映射(用于向后兼容) @@ -320,6 +346,9 @@ SETTING_TYPES: Dict[str, Type] = { "cpa_enabled": bool, "email_code_timeout": int, "email_code_poll_interval": int, + "outlook_provider_priority": list, + "outlook_health_failure_threshold": int, + "outlook_health_disable_duration": int, } # 需要作为 SecretStr 处理的字段 @@ -346,6 +375,11 @@ def _convert_value(attr_name: str, value: str) -> Any: return value import json return json.loads(value) if value else {} + elif target_type == list: + if isinstance(value, list): + return value + import json + return json.loads(value) if value else [] else: return value @@ -356,7 +390,7 @@ def _value_to_string(value: Any) -> str: return value.get_secret_value() elif isinstance(value, bool): return "true" if value else "false" - elif isinstance(value, dict): + elif isinstance(value, (dict, list)): import json return json.dumps(value) elif value is None: @@ -533,6 +567,12 @@ class Settings(BaseModel): email_code_timeout: int = 120 email_code_poll_interval: int = 3 + # Outlook 配置 + outlook_provider_priority: List[str] = ["graph_api", "imap_new", "imap_old"] + outlook_health_failure_threshold: int = 5 + outlook_health_disable_duration: int = 60 + outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a" + # 全局配置实例 _settings: Optional[Settings] = 
None diff --git a/src/services/__init__.py b/src/services/__init__.py index 4d1cb9c..144805c 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -19,14 +19,43 @@ EmailServiceFactory.register(EmailServiceType.TEMPMAIL, TempmailService) EmailServiceFactory.register(EmailServiceType.OUTLOOK, OutlookService) EmailServiceFactory.register(EmailServiceType.CUSTOM_DOMAIN, CustomDomainEmailService) +# 导出 Outlook 模块的额外内容 +from .outlook.base import ( + ProviderType, + EmailMessage, + TokenInfo, + ProviderHealth, + ProviderStatus, +) +from .outlook.account import OutlookAccount +from .outlook.providers import ( + OutlookProvider, + IMAPOldProvider, + IMAPNewProvider, + GraphAPIProvider, +) + __all__ = [ + # 基类 'BaseEmailService', 'EmailServiceError', 'EmailServiceStatus', 'EmailServiceFactory', 'create_email_service', 'EmailServiceType', + # 服务类 'TempmailService', 'OutlookService', 'CustomDomainEmailService', + # Outlook 模块 + 'ProviderType', + 'EmailMessage', + 'TokenInfo', + 'ProviderHealth', + 'ProviderStatus', + 'OutlookAccount', + 'OutlookProvider', + 'IMAPOldProvider', + 'IMAPNewProvider', + 'GraphAPIProvider', ] \ No newline at end of file diff --git a/src/services/outlook/__init__.py b/src/services/outlook/__init__.py new file mode 100644 index 0000000..fbdd660 --- /dev/null +++ b/src/services/outlook/__init__.py @@ -0,0 +1,8 @@ +""" +Outlook 邮箱服务模块 +支持多种 IMAP/API 连接方式,自动故障切换 +""" + +from .service import OutlookService + +__all__ = ['OutlookService'] diff --git a/src/services/outlook/account.py b/src/services/outlook/account.py new file mode 100644 index 0000000..6f427d5 --- /dev/null +++ b/src/services/outlook/account.py @@ -0,0 +1,51 @@ +""" +Outlook 账户数据类 +""" + +from dataclasses import dataclass +from typing import Dict, Any, Optional + + +@dataclass +class OutlookAccount: + """Outlook 账户信息""" + email: str + password: str = "" + client_id: str = "" + refresh_token: str = "" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> 
"OutlookAccount": + """从配置创建账户""" + return cls( + email=config.get("email", ""), + password=config.get("password", ""), + client_id=config.get("client_id", ""), + refresh_token=config.get("refresh_token", "") + ) + + def has_oauth(self) -> bool: + """是否支持 OAuth2""" + return bool(self.client_id and self.refresh_token) + + def validate(self) -> bool: + """验证账户信息是否有效""" + return bool(self.email and self.password) or self.has_oauth() + + def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]: + """转换为字典""" + result = { + "email": self.email, + "has_oauth": self.has_oauth(), + } + if include_sensitive: + result.update({ + "password": self.password, + "client_id": self.client_id, + "refresh_token": self.refresh_token[:20] + "..." if self.refresh_token else "", + }) + return result + + def __str__(self) -> str: + """字符串表示""" + return f"OutlookAccount({self.email})" diff --git a/src/services/outlook/base.py b/src/services/outlook/base.py new file mode 100644 index 0000000..335b11e --- /dev/null +++ b/src/services/outlook/base.py @@ -0,0 +1,153 @@ +""" +Outlook 服务基础定义 +包含枚举类型和数据类 +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Optional, Dict, Any, List + + +class ProviderType(str, Enum): + """Outlook 提供者类型""" + IMAP_OLD = "imap_old" # 旧版 IMAP (outlook.office365.com) + IMAP_NEW = "imap_new" # 新版 IMAP (outlook.live.com) + GRAPH_API = "graph_api" # Microsoft Graph API + + +class TokenEndpoint(str, Enum): + """Token 端点""" + LIVE = "https://login.live.com/oauth20_token.srf" + CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token" + COMMON = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + + +class IMAPServer(str, Enum): + """IMAP 服务器""" + OLD = "outlook.office365.com" + NEW = "outlook.live.com" + + +class ProviderStatus(str, Enum): + """提供者状态""" + HEALTHY = "healthy" # 健康 + DEGRADED = "degraded" # 降级 + DISABLED = "disabled" # 禁用 + + +@dataclass +class 
EmailMessage: + """邮件消息数据类""" + id: str # 消息 ID + subject: str # 主题 + sender: str # 发件人 + recipients: List[str] = field(default_factory=list) # 收件人列表 + body: str = "" # 正文内容 + body_preview: str = "" # 正文预览 + received_at: Optional[datetime] = None # 接收时间 + received_timestamp: int = 0 # 接收时间戳 + is_read: bool = False # 是否已读 + has_attachments: bool = False # 是否有附件 + raw_data: Optional[bytes] = None # 原始数据(用于调试) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "id": self.id, + "subject": self.subject, + "sender": self.sender, + "recipients": self.recipients, + "body": self.body, + "body_preview": self.body_preview, + "received_at": self.received_at.isoformat() if self.received_at else None, + "received_timestamp": self.received_timestamp, + "is_read": self.is_read, + "has_attachments": self.has_attachments, + } + + +@dataclass +class TokenInfo: + """Token 信息数据类""" + access_token: str + expires_at: float # 过期时间戳 + token_type: str = "Bearer" + scope: str = "" + refresh_token: Optional[str] = None + + def is_expired(self, buffer_seconds: int = 120) -> bool: + """检查 Token 是否已过期""" + import time + return time.time() >= (self.expires_at - buffer_seconds) + + @classmethod + def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo": + """从 API 响应创建""" + import time + return cls( + access_token=data.get("access_token", ""), + expires_at=time.time() + data.get("expires_in", 3600), + token_type=data.get("token_type", "Bearer"), + scope=scope or data.get("scope", ""), + refresh_token=data.get("refresh_token"), + ) + + +@dataclass +class ProviderHealth: + """提供者健康状态""" + provider_type: ProviderType + status: ProviderStatus = ProviderStatus.HEALTHY + failure_count: int = 0 # 连续失败次数 + last_success: Optional[datetime] = None # 最后成功时间 + last_failure: Optional[datetime] = None # 最后失败时间 + last_error: str = "" # 最后错误信息 + disabled_until: Optional[datetime] = None # 禁用截止时间 + + def record_success(self): + """记录成功""" + self.status = ProviderStatus.HEALTHY 
+ self.failure_count = 0 + self.last_success = datetime.now() + self.disabled_until = None + + def record_failure(self, error: str): + """记录失败""" + self.failure_count += 1 + self.last_failure = datetime.now() + self.last_error = error + + def should_disable(self, threshold: int = 3) -> bool: + """判断是否应该禁用""" + return self.failure_count >= threshold + + def is_disabled(self) -> bool: + """检查是否被禁用""" + if self.disabled_until and datetime.now() < self.disabled_until: + return True + return False + + def disable(self, duration_seconds: int = 300): + """禁用提供者""" + from datetime import timedelta + self.status = ProviderStatus.DISABLED + self.disabled_until = datetime.now() + timedelta(seconds=duration_seconds) + + def enable(self): + """启用提供者""" + self.status = ProviderStatus.HEALTHY + self.disabled_until = None + self.failure_count = 0 + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "provider_type": self.provider_type.value, + "status": self.status.value, + "failure_count": self.failure_count, + "last_success": self.last_success.isoformat() if self.last_success else None, + "last_failure": self.last_failure.isoformat() if self.last_failure else None, + "last_error": self.last_error, + "disabled_until": self.disabled_until.isoformat() if self.disabled_until else None, + } diff --git a/src/services/outlook/email_parser.py b/src/services/outlook/email_parser.py new file mode 100644 index 0000000..84d5228 --- /dev/null +++ b/src/services/outlook/email_parser.py @@ -0,0 +1,228 @@ +""" +邮件解析和验证码提取 +""" + +import logging +import re +from typing import Optional, List, Dict, Any + +from ...config.constants import ( + OTP_CODE_SIMPLE_PATTERN, + OTP_CODE_SEMANTIC_PATTERN, + OPENAI_EMAIL_SENDERS, + OPENAI_VERIFICATION_KEYWORDS, +) +from .base import EmailMessage + + +logger = logging.getLogger(__name__) + + +class EmailParser: + """ + 邮件解析器 + 用于识别 OpenAI 验证邮件并提取验证码 + """ + + def __init__(self): + # 编译正则表达式 + self._simple_pattern = 
re.compile(OTP_CODE_SIMPLE_PATTERN) + self._semantic_pattern = re.compile(OTP_CODE_SEMANTIC_PATTERN, re.IGNORECASE) + + def is_openai_verification_email( + self, + email: EmailMessage, + target_email: Optional[str] = None, + ) -> bool: + """ + 判断是否为 OpenAI 验证邮件 + + Args: + email: 邮件对象 + target_email: 目标邮箱地址(用于验证收件人) + + Returns: + 是否为 OpenAI 验证邮件 + """ + sender = email.sender.lower() + + # 1. 发件人必须是 OpenAI + if not any(s in sender for s in OPENAI_EMAIL_SENDERS): + logger.debug(f"邮件发件人非 OpenAI: {sender}") + return False + + # 2. 主题或正文包含验证关键词 + subject = email.subject.lower() + body = email.body.lower() + combined = f"{subject} {body}" + + if not any(kw in combined for kw in OPENAI_VERIFICATION_KEYWORDS): + logger.debug(f"邮件未包含验证关键词: {subject[:50]}") + return False + + # 3. 收件人检查已移除:别名邮件的 IMAP 头中收件人可能不匹配,只靠发件人+关键词判断 + logger.debug(f"识别为 OpenAI 验证邮件: {subject[:50]}") + return True + + def extract_verification_code( + self, + email: EmailMessage, + ) -> Optional[str]: + """ + 从邮件中提取验证码 + + 优先级: + 1. 从主题提取(6位数字) + 2. 从正文用语义正则提取(如 "code is 123456") + 3. 兜底:任意 6 位数字 + + Args: + email: 邮件对象 + + Returns: + 验证码字符串,如果未找到返回 None + """ + # 1. 主题优先 + code = self._extract_from_subject(email.subject) + if code: + logger.debug(f"从主题提取验证码: {code}") + return code + + # 2. 正文语义匹配 + code = self._extract_semantic(email.body) + if code: + logger.debug(f"从正文语义提取验证码: {code}") + return code + + # 3. 
兜底:正文任意 6 位数字 + code = self._extract_simple(email.body) + if code: + logger.debug(f"从正文兜底提取验证码: {code}") + return code + + return None + + def _extract_from_subject(self, subject: str) -> Optional[str]: + """从主题提取验证码""" + match = self._simple_pattern.search(subject) + if match: + return match.group(1) + return None + + def _extract_semantic(self, body: str) -> Optional[str]: + """语义匹配提取验证码""" + match = self._semantic_pattern.search(body) + if match: + return match.group(1) + return None + + def _extract_simple(self, body: str) -> Optional[str]: + """简单匹配提取验证码""" + match = self._simple_pattern.search(body) + if match: + return match.group(1) + return None + + def find_verification_code_in_emails( + self, + emails: List[EmailMessage], + target_email: Optional[str] = None, + min_timestamp: int = 0, + used_codes: Optional[set] = None, + ) -> Optional[str]: + """ + 从邮件列表中查找验证码 + + Args: + emails: 邮件列表 + target_email: 目标邮箱地址 + min_timestamp: 最小时间戳(用于过滤旧邮件) + used_codes: 已使用的验证码集合(用于去重) + + Returns: + 验证码字符串,如果未找到返回 None + """ + used_codes = used_codes or set() + + for email in emails: + # 时间戳过滤 + if min_timestamp > 0 and email.received_timestamp > 0: + if email.received_timestamp < min_timestamp: + logger.debug(f"跳过旧邮件: {email.subject[:50]}") + continue + + # 检查是否是 OpenAI 验证邮件 + if not self.is_openai_verification_email(email, target_email): + continue + + # 提取验证码 + code = self.extract_verification_code(email) + if code: + # 去重检查 + if code in used_codes: + logger.debug(f"跳过已使用的验证码: {code}") + continue + + logger.info( + f"[{target_email or 'unknown'}] 找到验证码: {code}, " + f"邮件主题: {email.subject[:30]}" + ) + return code + + return None + + def filter_emails_by_sender( + self, + emails: List[EmailMessage], + sender_patterns: List[str], + ) -> List[EmailMessage]: + """ + 按发件人过滤邮件 + + Args: + emails: 邮件列表 + sender_patterns: 发件人匹配模式列表 + + Returns: + 过滤后的邮件列表 + """ + filtered = [] + for email in emails: + sender = email.sender.lower() + if any(pattern.lower() in sender for 
pattern in sender_patterns): + filtered.append(email) + return filtered + + def filter_emails_by_subject( + self, + emails: List[EmailMessage], + keywords: List[str], + ) -> List[EmailMessage]: + """ + 按主题关键词过滤邮件 + + Args: + emails: 邮件列表 + keywords: 关键词列表 + + Returns: + 过滤后的邮件列表 + """ + filtered = [] + for email in emails: + subject = email.subject.lower() + if any(kw.lower() in subject for kw in keywords): + filtered.append(email) + return filtered + + +# 全局解析器实例 +_parser: Optional[EmailParser] = None + + +def get_email_parser() -> EmailParser: + """获取全局邮件解析器实例""" + global _parser + if _parser is None: + _parser = EmailParser() + return _parser diff --git a/src/services/outlook/health_checker.py b/src/services/outlook/health_checker.py new file mode 100644 index 0000000..c68ed4e --- /dev/null +++ b/src/services/outlook/health_checker.py @@ -0,0 +1,312 @@ +""" +健康检查和故障切换管理 +""" + +import logging +import threading +import time +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any + +from .base import ProviderType, ProviderHealth, ProviderStatus +from .providers.base import OutlookProvider + + +logger = logging.getLogger(__name__) + + +class HealthChecker: + """ + 健康检查管理器 + 跟踪各提供者的健康状态,管理故障切换 + """ + + def __init__( + self, + failure_threshold: int = 3, + disable_duration: int = 300, + recovery_check_interval: int = 60, + ): + """ + 初始化健康检查器 + + Args: + failure_threshold: 连续失败次数阈值,超过后禁用 + disable_duration: 禁用时长(秒) + recovery_check_interval: 恢复检查间隔(秒) + """ + self.failure_threshold = failure_threshold + self.disable_duration = disable_duration + self.recovery_check_interval = recovery_check_interval + + # 提供者健康状态: ProviderType -> ProviderHealth + self._health_status: Dict[ProviderType, ProviderHealth] = {} + self._lock = threading.Lock() + + # 初始化所有提供者的健康状态 + for provider_type in ProviderType: + self._health_status[provider_type] = ProviderHealth( + provider_type=provider_type + ) + + def get_health(self, provider_type: ProviderType) 
-> ProviderHealth: + """获取提供者的健康状态""" + with self._lock: + return self._health_status.get(provider_type, ProviderHealth(provider_type=provider_type)) + + def record_success(self, provider_type: ProviderType): + """记录成功操作""" + with self._lock: + health = self._health_status.get(provider_type) + if health: + health.record_success() + logger.debug(f"{provider_type.value} 记录成功") + + def record_failure(self, provider_type: ProviderType, error: str): + """记录失败操作""" + with self._lock: + health = self._health_status.get(provider_type) + if health: + health.record_failure(error) + + # 检查是否需要禁用 + if health.should_disable(self.failure_threshold): + health.disable(self.disable_duration) + logger.warning( + f"{provider_type.value} 已禁用 {self.disable_duration} 秒," + f"原因: {error}" + ) + + def is_available(self, provider_type: ProviderType) -> bool: + """ + 检查提供者是否可用 + + Args: + provider_type: 提供者类型 + + Returns: + 是否可用 + """ + health = self.get_health(provider_type) + + # 检查是否被禁用 + if health.is_disabled(): + remaining = (health.disabled_until - datetime.now()).total_seconds() + logger.debug( + f"{provider_type.value} 已被禁用,剩余 {int(remaining)} 秒" + ) + return False + + return health.status != ProviderStatus.DISABLED + + def get_available_providers( + self, + priority_order: Optional[List[ProviderType]] = None, + ) -> List[ProviderType]: + """ + 获取可用的提供者列表 + + Args: + priority_order: 优先级顺序,默认为 [IMAP_NEW, IMAP_OLD, GRAPH_API] + + Returns: + 可用的提供者列表 + """ + if priority_order is None: + priority_order = [ + ProviderType.IMAP_NEW, + ProviderType.IMAP_OLD, + ProviderType.GRAPH_API, + ] + + available = [] + for provider_type in priority_order: + if self.is_available(provider_type): + available.append(provider_type) + + return available + + def get_next_available_provider( + self, + priority_order: Optional[List[ProviderType]] = None, + ) -> Optional[ProviderType]: + """ + 获取下一个可用的提供者 + + Args: + priority_order: 优先级顺序 + + Returns: + 可用的提供者类型,如果没有返回 None + """ + available = 
self.get_available_providers(priority_order) + return available[0] if available else None + + def force_disable(self, provider_type: ProviderType, duration: Optional[int] = None): + """ + 强制禁用提供者 + + Args: + provider_type: 提供者类型 + duration: 禁用时长(秒),默认使用配置值 + """ + with self._lock: + health = self._health_status.get(provider_type) + if health: + health.disable(duration or self.disable_duration) + logger.warning(f"{provider_type.value} 已强制禁用") + + def force_enable(self, provider_type: ProviderType): + """ + 强制启用提供者 + + Args: + provider_type: 提供者类型 + """ + with self._lock: + health = self._health_status.get(provider_type) + if health: + health.enable() + logger.info(f"{provider_type.value} 已启用") + + def get_all_health_status(self) -> Dict[str, Any]: + """ + 获取所有提供者的健康状态 + + Returns: + 健康状态字典 + """ + with self._lock: + return { + provider_type.value: health.to_dict() + for provider_type, health in self._health_status.items() + } + + def check_and_recover(self): + """ + 检查并恢复被禁用的提供者 + + 如果禁用时间已过,自动恢复提供者 + """ + with self._lock: + for provider_type, health in self._health_status.items(): + if health.is_disabled(): + # 检查是否可以恢复 + if health.disabled_until and datetime.now() >= health.disabled_until: + health.enable() + logger.info(f"{provider_type.value} 已自动恢复") + + def reset_all(self): + """重置所有提供者的健康状态""" + with self._lock: + for provider_type in ProviderType: + self._health_status[provider_type] = ProviderHealth( + provider_type=provider_type + ) + logger.info("已重置所有提供者的健康状态") + + +class FailoverManager: + """ + 故障切换管理器 + 管理提供者之间的自动切换 + """ + + def __init__( + self, + health_checker: HealthChecker, + priority_order: Optional[List[ProviderType]] = None, + ): + """ + 初始化故障切换管理器 + + Args: + health_checker: 健康检查器 + priority_order: 提供者优先级顺序 + """ + self.health_checker = health_checker + self.priority_order = priority_order or [ + ProviderType.IMAP_NEW, + ProviderType.IMAP_OLD, + ProviderType.GRAPH_API, + ] + + # 当前使用的提供者索引 + self._current_index = 0 + self._lock = 
threading.Lock() + + def get_current_provider(self) -> Optional[ProviderType]: + """ + 获取当前提供者 + + Returns: + 当前提供者类型,如果没有可用的返回 None + """ + available = self.health_checker.get_available_providers(self.priority_order) + if not available: + return None + + with self._lock: + # 尝试使用当前索引 + if self._current_index < len(available): + return available[self._current_index] + return available[0] + + def switch_to_next(self) -> Optional[ProviderType]: + """ + 切换到下一个提供者 + + Returns: + 下一个提供者类型,如果没有可用的返回 None + """ + available = self.health_checker.get_available_providers(self.priority_order) + if not available: + return None + + with self._lock: + self._current_index = (self._current_index + 1) % len(available) + next_provider = available[self._current_index] + logger.info(f"切换到提供者: {next_provider.value}") + return next_provider + + def on_provider_success(self, provider_type: ProviderType): + """ + 提供者成功时调用 + + Args: + provider_type: 提供者类型 + """ + self.health_checker.record_success(provider_type) + + # 重置索引到成功的提供者 + with self._lock: + available = self.health_checker.get_available_providers(self.priority_order) + if provider_type in available: + self._current_index = available.index(provider_type) + + def on_provider_failure(self, provider_type: ProviderType, error: str): + """ + 提供者失败时调用 + + Args: + provider_type: 提供者类型 + error: 错误信息 + """ + self.health_checker.record_failure(provider_type, error) + + def get_status(self) -> Dict[str, Any]: + """ + 获取故障切换状态 + + Returns: + 状态字典 + """ + current = self.get_current_provider() + return { + "current_provider": current.value if current else None, + "priority_order": [p.value for p in self.priority_order], + "available_providers": [ + p.value for p in self.health_checker.get_available_providers(self.priority_order) + ], + "health_status": self.health_checker.get_all_health_status(), + } diff --git a/src/services/outlook/providers/__init__.py b/src/services/outlook/providers/__init__.py new file mode 100644 index 0000000..d6fe6a1 
--- /dev/null +++ b/src/services/outlook/providers/__init__.py @@ -0,0 +1,29 @@ +""" +Outlook 提供者模块 +""" + +from .base import OutlookProvider, ProviderConfig +from .imap_old import IMAPOldProvider +from .imap_new import IMAPNewProvider +from .graph_api import GraphAPIProvider + +__all__ = [ + 'OutlookProvider', + 'ProviderConfig', + 'IMAPOldProvider', + 'IMAPNewProvider', + 'GraphAPIProvider', +] + + +# 提供者注册表 +PROVIDER_REGISTRY = { + 'imap_old': IMAPOldProvider, + 'imap_new': IMAPNewProvider, + 'graph_api': GraphAPIProvider, +} + + +def get_provider_class(provider_type: str): + """获取提供者类""" + return PROVIDER_REGISTRY.get(provider_type) diff --git a/src/services/outlook/providers/base.py b/src/services/outlook/providers/base.py new file mode 100644 index 0000000..0d6c072 --- /dev/null +++ b/src/services/outlook/providers/base.py @@ -0,0 +1,180 @@ +""" +Outlook 提供者抽象基类 +""" + +import abc +import logging +from dataclasses import dataclass +from typing import Dict, Any, List, Optional + +from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus +from ..account import OutlookAccount + + +logger = logging.getLogger(__name__) + + +@dataclass +class ProviderConfig: + """提供者配置""" + timeout: int = 30 + max_retries: int = 3 + proxy_url: Optional[str] = None + + # 健康检查配置 + health_failure_threshold: int = 3 + health_disable_duration: int = 300 # 秒 + + +class OutlookProvider(abc.ABC): + """ + Outlook 提供者抽象基类 + 定义所有提供者必须实现的接口 + """ + + def __init__( + self, + account: OutlookAccount, + config: Optional[ProviderConfig] = None, + ): + """ + 初始化提供者 + + Args: + account: Outlook 账户 + config: 提供者配置 + """ + self.account = account + self.config = config or ProviderConfig() + + # 健康状态 + self._health = ProviderHealth(provider_type=self.provider_type) + + # 连接状态 + self._connected = False + self._last_error: Optional[str] = None + + @property + @abc.abstractmethod + def provider_type(self) -> ProviderType: + """获取提供者类型""" + pass + + @property + def health(self) -> 
ProviderHealth: + """获取健康状态""" + return self._health + + @property + def is_healthy(self) -> bool: + """检查是否健康""" + return ( + self._health.status == ProviderStatus.HEALTHY + and not self._health.is_disabled() + ) + + @property + def is_connected(self) -> bool: + """检查是否已连接""" + return self._connected + + @abc.abstractmethod + def connect(self) -> bool: + """ + 连接到服务 + + Returns: + 是否连接成功 + """ + pass + + @abc.abstractmethod + def disconnect(self): + """断开连接""" + pass + + @abc.abstractmethod + def get_recent_emails( + self, + count: int = 20, + only_unseen: bool = True, + ) -> List[EmailMessage]: + """ + 获取最近的邮件 + + Args: + count: 获取数量 + only_unseen: 是否只获取未读 + + Returns: + 邮件列表 + """ + pass + + @abc.abstractmethod + def test_connection(self) -> bool: + """ + 测试连接是否正常 + + Returns: + 连接是否正常 + """ + pass + + def record_success(self): + """记录成功操作""" + self._health.record_success() + self._last_error = None + logger.debug(f"[{self.account.email}] {self.provider_type.value} 操作成功") + + def record_failure(self, error: str): + """记录失败操作""" + self._health.record_failure(error) + self._last_error = error + + # 检查是否需要禁用 + if self._health.should_disable(self.config.health_failure_threshold): + self._health.disable(self.config.health_disable_duration) + logger.warning( + f"[{self.account.email}] {self.provider_type.value} 已禁用 " + f"{self.config.health_disable_duration} 秒,原因: {error}" + ) + else: + logger.warning( + f"[{self.account.email}] {self.provider_type.value} 操作失败 " + f"({self._health.failure_count}/{self.config.health_failure_threshold}): {error}" + ) + + def check_health(self) -> bool: + """ + 检查健康状态 + + Returns: + 是否健康可用 + """ + # 检查是否被禁用 + if self._health.is_disabled(): + logger.debug( + f"[{self.account.email}] {self.provider_type.value} 已被禁用," + f"将在 {self._health.disabled_until} 后恢复" + ) + return False + + return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED) + + def __enter__(self): + """上下文管理器入口""" + self.connect() + return self + + def 
__exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + self.disconnect() + return False + + def __str__(self) -> str: + """字符串表示""" + return f"{self.__class__.__name__}({self.account.email})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/src/services/outlook/providers/graph_api.py b/src/services/outlook/providers/graph_api.py new file mode 100644 index 0000000..13af5d8 --- /dev/null +++ b/src/services/outlook/providers/graph_api.py @@ -0,0 +1,249 @@ +""" +Graph API 提供者 +使用 Microsoft Graph REST API +""" + +import json +import logging +from typing import List, Optional +from datetime import datetime + +from curl_cffi import requests as _requests + +from ..base import ProviderType, EmailMessage +from ..account import OutlookAccount +from ..token_manager import TokenManager +from .base import OutlookProvider, ProviderConfig + + +logger = logging.getLogger(__name__) + + +class GraphAPIProvider(OutlookProvider): + """ + Graph API 提供者 + 使用 Microsoft Graph REST API 获取邮件 + 需要 graph.microsoft.com/.default scope + """ + + # Graph API 端点 + GRAPH_API_BASE = "https://graph.microsoft.com/v1.0" + MESSAGES_ENDPOINT = "/me/mailFolders/inbox/messages" + + @property + def provider_type(self) -> ProviderType: + return ProviderType.GRAPH_API + + def __init__( + self, + account: OutlookAccount, + config: Optional[ProviderConfig] = None, + ): + super().__init__(account, config) + + # Token 管理器 + self._token_manager: Optional[TokenManager] = None + + # 注意:Graph API 必须使用 OAuth2 + if not account.has_oauth(): + logger.warning( + f"[{self.account.email}] Graph API 提供者需要 OAuth2 配置 " + f"(client_id + refresh_token)" + ) + + def connect(self) -> bool: + """ + 验证连接(获取 Token) + + Returns: + 是否连接成功 + """ + if not self.account.has_oauth(): + error = "Graph API 需要 OAuth2 配置" + self.record_failure(error) + logger.error(f"[{self.account.email}] {error}") + return False + + if not self._token_manager: + self._token_manager = TokenManager( + self.account, + 
ProviderType.GRAPH_API, + self.config.proxy_url, + self.config.timeout, + ) + + # 尝试获取 Token + token = self._token_manager.get_access_token() + if token: + self._connected = True + self.record_success() + logger.info(f"[{self.account.email}] Graph API 连接成功") + return True + + return False + + def disconnect(self): + """断开连接(清除状态)""" + self._connected = False + + def get_recent_emails( + self, + count: int = 20, + only_unseen: bool = True, + ) -> List[EmailMessage]: + """ + 获取最近的邮件 + + Args: + count: 获取数量 + only_unseen: 是否只获取未读 + + Returns: + 邮件列表 + """ + if not self._connected: + if not self.connect(): + return [] + + try: + # 获取 Access Token + token = self._token_manager.get_access_token() + if not token: + self.record_failure("无法获取 Access Token") + return [] + + # 构建 API 请求 + url = f"{self.GRAPH_API_BASE}{self.MESSAGES_ENDPOINT}" + + params = { + "$top": count, + "$select": "id,subject,from,toRecipients,receivedDateTime,isRead,hasAttachments,bodyPreview,body", + "$orderby": "receivedDateTime desc", + } + + # 只获取未读邮件 + if only_unseen: + params["$filter"] = "isRead eq false" + + # 构建代理配置 + proxies = None + if self.config.proxy_url: + proxies = {"http": self.config.proxy_url, "https": self.config.proxy_url} + + # 发送请求(curl_cffi 自动对 params 进行 URL 编码) + resp = _requests.get( + url, + params=params, + headers={ + "Authorization": f"Bearer {token}", + "Accept": "application/json", + "Prefer": "outlook.body-content-type='text'", + }, + proxies=proxies, + timeout=self.config.timeout, + impersonate="chrome110", + ) + + if resp.status_code == 401: + # Token 失效,清除缓存 + if self._token_manager: + self._token_manager.clear_cache() + self.record_failure(f"HTTP 401: Token 失效") + logger.error(f"[{self.account.email}] Graph API Token 失效") + return [] + + if resp.status_code != 200: + error_body = resp.text[:200] + self.record_failure(f"HTTP {resp.status_code}: {error_body}") + logger.error(f"[{self.account.email}] Graph API 请求失败: HTTP {resp.status_code}") + return [] + + data = 
resp.json() + + # 解析邮件 + messages = data.get("value", []) + emails = [] + + for msg in messages: + try: + email_msg = self._parse_graph_message(msg) + if email_msg: + emails.append(email_msg) + except Exception as e: + logger.warning(f"[{self.account.email}] 解析 Graph API 邮件失败: {e}") + + self.record_success() + return emails + + except Exception as e: + self.record_failure(str(e)) + logger.error(f"[{self.account.email}] Graph API 获取邮件失败: {e}") + return [] + + def _parse_graph_message(self, msg: dict) -> Optional[EmailMessage]: + """ + 解析 Graph API 消息 + + Args: + msg: Graph API 消息对象 + + Returns: + EmailMessage 对象 + """ + # 解析发件人 + from_info = msg.get("from", {}) + sender_info = from_info.get("emailAddress", {}) + sender = sender_info.get("address", "") + + # 解析收件人 + recipients = [] + for recipient in msg.get("toRecipients", []): + addr_info = recipient.get("emailAddress", {}) + addr = addr_info.get("address", "") + if addr: + recipients.append(addr) + + # 解析日期 + received_at = None + received_timestamp = 0 + try: + date_str = msg.get("receivedDateTime", "") + if date_str: + # ISO 8601 格式 + received_at = datetime.fromisoformat(date_str.replace("Z", "+00:00")) + received_timestamp = int(received_at.timestamp()) + except Exception: + pass + + # 获取正文 + body_info = msg.get("body", {}) + body = body_info.get("content", "") + body_preview = msg.get("bodyPreview", "") + + return EmailMessage( + id=msg.get("id", ""), + subject=msg.get("subject", ""), + sender=sender, + recipients=recipients, + body=body, + body_preview=body_preview, + received_at=received_at, + received_timestamp=received_timestamp, + is_read=msg.get("isRead", False), + has_attachments=msg.get("hasAttachments", False), + ) + + def test_connection(self) -> bool: + """ + 测试 Graph API 连接 + + Returns: + 连接是否正常 + """ + try: + # 尝试获取一封邮件来测试连接 + emails = self.get_recent_emails(count=1, only_unseen=False) + return True + except Exception as e: + logger.warning(f"[{self.account.email}] Graph API 连接测试失败: {e}") + return 
False diff --git a/src/services/outlook/providers/imap_new.py b/src/services/outlook/providers/imap_new.py new file mode 100644 index 0000000..5daa2f3 --- /dev/null +++ b/src/services/outlook/providers/imap_new.py @@ -0,0 +1,231 @@ +""" +新版 IMAP 提供者 +使用 outlook.live.com 服务器和 login.microsoftonline.com/consumers Token 端点 +""" + +import email +import imaplib +import logging +from email.header import decode_header +from email.utils import parsedate_to_datetime +from typing import List, Optional + +from ..base import ProviderType, EmailMessage +from ..account import OutlookAccount +from ..token_manager import TokenManager +from .base import OutlookProvider, ProviderConfig +from .imap_old import IMAPOldProvider + + +logger = logging.getLogger(__name__) + + +class IMAPNewProvider(OutlookProvider): + """ + 新版 IMAP 提供者 + 使用 outlook.live.com:993 和 login.microsoftonline.com/consumers Token 端点 + 需要 IMAP.AccessAsUser.All scope + """ + + # IMAP 服务器配置 + IMAP_HOST = "outlook.live.com" + IMAP_PORT = 993 + + @property + def provider_type(self) -> ProviderType: + return ProviderType.IMAP_NEW + + def __init__( + self, + account: OutlookAccount, + config: Optional[ProviderConfig] = None, + ): + super().__init__(account, config) + + # IMAP 连接 + self._conn: Optional[imaplib.IMAP4_SSL] = None + + # Token 管理器 + self._token_manager: Optional[TokenManager] = None + + # 注意:新版 IMAP 必须使用 OAuth2 + if not account.has_oauth(): + logger.warning( + f"[{self.account.email}] 新版 IMAP 提供者需要 OAuth2 配置 " + f"(client_id + refresh_token)" + ) + + def connect(self) -> bool: + """ + 连接到 IMAP 服务器 + + Returns: + 是否连接成功 + """ + if self._connected and self._conn: + try: + self._conn.noop() + return True + except Exception: + self.disconnect() + + # 新版 IMAP 必须使用 OAuth2,无 OAuth 时静默跳过,不记录健康失败 + if not self.account.has_oauth(): + logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW(无 OAuth)") + return False + + try: + logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...") + + # 创建连接 + self._conn = 
imaplib.IMAP4_SSL( + self.IMAP_HOST, + self.IMAP_PORT, + timeout=self.config.timeout, + ) + + # XOAUTH2 认证 + if self._authenticate_xoauth2(): + self._connected = True + self.record_success() + logger.info(f"[{self.account.email}] 新版 IMAP 连接成功 (XOAUTH2)") + return True + + return False + + except Exception as e: + self.disconnect() + self.record_failure(str(e)) + logger.error(f"[{self.account.email}] 新版 IMAP 连接失败: {e}") + return False + + def _authenticate_xoauth2(self) -> bool: + """ + 使用 XOAUTH2 认证 + + Returns: + 是否认证成功 + """ + if not self._token_manager: + self._token_manager = TokenManager( + self.account, + ProviderType.IMAP_NEW, + self.config.proxy_url, + self.config.timeout, + ) + + # 获取 Access Token + token = self._token_manager.get_access_token() + if not token: + logger.error(f"[{self.account.email}] 获取 IMAP Token 失败") + return False + + try: + # 构建 XOAUTH2 认证字符串 + auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01" + self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8")) + return True + except Exception as e: + logger.error(f"[{self.account.email}] XOAUTH2 认证异常: {e}") + # 清除缓存的 Token + self._token_manager.clear_cache() + return False + + def disconnect(self): + """断开 IMAP 连接""" + if self._conn: + try: + self._conn.close() + except Exception: + pass + try: + self._conn.logout() + except Exception: + pass + self._conn = None + + self._connected = False + + def get_recent_emails( + self, + count: int = 20, + only_unseen: bool = True, + ) -> List[EmailMessage]: + """ + 获取最近的邮件 + + Args: + count: 获取数量 + only_unseen: 是否只获取未读 + + Returns: + 邮件列表 + """ + if not self._connected: + if not self.connect(): + return [] + + try: + # 选择收件箱 + self._conn.select("INBOX", readonly=True) + + # 搜索邮件 + flag = "UNSEEN" if only_unseen else "ALL" + status, data = self._conn.search(None, flag) + + if status != "OK" or not data or not data[0]: + return [] + + # 获取最新的邮件 ID + ids = data[0].split() + recent_ids = ids[-count:][::-1] + + emails = 
[] + for msg_id in recent_ids: + try: + email_msg = self._fetch_email(msg_id) + if email_msg: + emails.append(email_msg) + except Exception as e: + logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}") + + return emails + + except Exception as e: + self.record_failure(str(e)) + logger.error(f"[{self.account.email}] 获取邮件失败: {e}") + return [] + + def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]: + """获取并解析单封邮件""" + status, data = self._conn.fetch(msg_id, "(RFC822)") + if status != "OK" or not data or not data[0]: + return None + + raw = b"" + for part in data: + if isinstance(part, tuple) and len(part) > 1: + raw = part[1] + break + + if not raw: + return None + + return self._parse_email(raw) + + @staticmethod + def _parse_email(raw: bytes) -> EmailMessage: + """解析原始邮件""" + # 使用旧版提供者的解析方法 + return IMAPOldProvider._parse_email(raw) + + def test_connection(self) -> bool: + """测试 IMAP 连接""" + try: + with self: + self._conn.select("INBOX", readonly=True) + self._conn.search(None, "ALL") + return True + except Exception as e: + logger.warning(f"[{self.account.email}] 新版 IMAP 连接测试失败: {e}") + return False diff --git a/src/services/outlook/providers/imap_old.py b/src/services/outlook/providers/imap_old.py new file mode 100644 index 0000000..e46f3ed --- /dev/null +++ b/src/services/outlook/providers/imap_old.py @@ -0,0 +1,345 @@ +""" +旧版 IMAP 提供者 +使用 outlook.office365.com 服务器和 login.live.com Token 端点 +""" + +import email +import imaplib +import logging +from email.header import decode_header +from email.utils import parsedate_to_datetime +from typing import List, Optional + +from ..base import ProviderType, EmailMessage +from ..account import OutlookAccount +from ..token_manager import TokenManager +from .base import OutlookProvider, ProviderConfig + + +logger = logging.getLogger(__name__) + + +class IMAPOldProvider(OutlookProvider): + """ + 旧版 IMAP 提供者 + 使用 outlook.office365.com:993 和 login.live.com Token 端点 + """ + + # IMAP 服务器配置 + IMAP_HOST = 
"outlook.office365.com" + IMAP_PORT = 993 + + @property + def provider_type(self) -> ProviderType: + return ProviderType.IMAP_OLD + + def __init__( + self, + account: OutlookAccount, + config: Optional[ProviderConfig] = None, + ): + super().__init__(account, config) + + # IMAP 连接 + self._conn: Optional[imaplib.IMAP4_SSL] = None + + # Token 管理器 + self._token_manager: Optional[TokenManager] = None + + def connect(self) -> bool: + """ + 连接到 IMAP 服务器 + + Returns: + 是否连接成功 + """ + if self._connected and self._conn: + # 检查现有连接 + try: + self._conn.noop() + return True + except Exception: + self.disconnect() + + try: + logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...") + + # 创建连接 + self._conn = imaplib.IMAP4_SSL( + self.IMAP_HOST, + self.IMAP_PORT, + timeout=self.config.timeout, + ) + + # 尝试 XOAUTH2 认证 + if self.account.has_oauth(): + if self._authenticate_xoauth2(): + self._connected = True + self.record_success() + logger.info(f"[{self.account.email}] IMAP 连接成功 (XOAUTH2)") + return True + else: + logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试密码认证") + + # 密码认证 + if self.account.password: + self._conn.login(self.account.email, self.account.password) + self._connected = True + self.record_success() + logger.info(f"[{self.account.email}] IMAP 连接成功 (密码认证)") + return True + + raise ValueError("没有可用的认证方式") + + except Exception as e: + self.disconnect() + self.record_failure(str(e)) + logger.error(f"[{self.account.email}] IMAP 连接失败: {e}") + return False + + def _authenticate_xoauth2(self) -> bool: + """ + 使用 XOAUTH2 认证 + + Returns: + 是否认证成功 + """ + if not self._token_manager: + self._token_manager = TokenManager( + self.account, + ProviderType.IMAP_OLD, + self.config.proxy_url, + self.config.timeout, + ) + + # 获取 Access Token + token = self._token_manager.get_access_token() + if not token: + return False + + try: + # 构建 XOAUTH2 认证字符串 + auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01" + self._conn.authenticate("XOAUTH2", 
lambda _: auth_string.encode("utf-8")) + return True + except Exception as e: + logger.debug(f"[{self.account.email}] XOAUTH2 认证异常: {e}") + # 清除缓存的 Token + self._token_manager.clear_cache() + return False + + def disconnect(self): + """断开 IMAP 连接""" + if self._conn: + try: + self._conn.close() + except Exception: + pass + try: + self._conn.logout() + except Exception: + pass + self._conn = None + + self._connected = False + + def get_recent_emails( + self, + count: int = 20, + only_unseen: bool = True, + ) -> List[EmailMessage]: + """ + 获取最近的邮件 + + Args: + count: 获取数量 + only_unseen: 是否只获取未读 + + Returns: + 邮件列表 + """ + if not self._connected: + if not self.connect(): + return [] + + try: + # 选择收件箱 + self._conn.select("INBOX", readonly=True) + + # 搜索邮件 + flag = "UNSEEN" if only_unseen else "ALL" + status, data = self._conn.search(None, flag) + + if status != "OK" or not data or not data[0]: + return [] + + # 获取最新的邮件 ID + ids = data[0].split() + recent_ids = ids[-count:][::-1] # 倒序,最新的在前 + + emails = [] + for msg_id in recent_ids: + try: + email_msg = self._fetch_email(msg_id) + if email_msg: + emails.append(email_msg) + except Exception as e: + logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}") + + return emails + + except Exception as e: + self.record_failure(str(e)) + logger.error(f"[{self.account.email}] 获取邮件失败: {e}") + return [] + + def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]: + """ + 获取并解析单封邮件 + + Args: + msg_id: 邮件 ID + + Returns: + EmailMessage 对象,失败返回 None + """ + status, data = self._conn.fetch(msg_id, "(RFC822)") + if status != "OK" or not data or not data[0]: + return None + + # 获取原始邮件内容 + raw = b"" + for part in data: + if isinstance(part, tuple) and len(part) > 1: + raw = part[1] + break + + if not raw: + return None + + return self._parse_email(raw) + + @staticmethod + def _parse_email(raw: bytes) -> EmailMessage: + """ + 解析原始邮件 + + Args: + raw: 原始邮件数据 + + Returns: + EmailMessage 对象 + """ + # 移除 BOM + if 
raw.startswith(b"\xef\xbb\xbf"): + raw = raw[3:] + + msg = email.message_from_bytes(raw) + + # 解析邮件头 + subject = IMAPOldProvider._decode_header(msg.get("Subject", "")) + sender = IMAPOldProvider._decode_header(msg.get("From", "")) + to = IMAPOldProvider._decode_header(msg.get("To", "")) + delivered_to = IMAPOldProvider._decode_header(msg.get("Delivered-To", "")) + x_original_to = IMAPOldProvider._decode_header(msg.get("X-Original-To", "")) + date_str = IMAPOldProvider._decode_header(msg.get("Date", "")) + + # 提取正文 + body = IMAPOldProvider._extract_body(msg) + + # 解析日期 + received_timestamp = 0 + received_at = None + try: + if date_str: + received_at = parsedate_to_datetime(date_str) + received_timestamp = int(received_at.timestamp()) + except Exception: + pass + + # 构建收件人列表 + recipients = [r for r in [to, delivered_to, x_original_to] if r] + + return EmailMessage( + id=msg.get("Message-ID", ""), + subject=subject, + sender=sender, + recipients=recipients, + body=body, + received_at=received_at, + received_timestamp=received_timestamp, + is_read=False, # 搜索的是未读邮件 + raw_data=raw[:500] if len(raw) > 500 else raw, + ) + + @staticmethod + def _decode_header(header: str) -> str: + """解码邮件头""" + if not header: + return "" + + parts = [] + for chunk, encoding in decode_header(header): + if isinstance(chunk, bytes): + try: + decoded = chunk.decode(encoding or "utf-8", errors="replace") + parts.append(decoded) + except Exception: + parts.append(chunk.decode("utf-8", errors="replace")) + else: + parts.append(str(chunk)) + + return "".join(parts).strip() + + @staticmethod + def _extract_body(msg) -> str: + """提取邮件正文""" + import html as html_module + import re + + texts = [] + parts = msg.walk() if msg.is_multipart() else [msg] + + for part in parts: + content_type = part.get_content_type() + if content_type not in ("text/plain", "text/html"): + continue + + payload = part.get_payload(decode=True) + if not payload: + continue + + charset = part.get_content_charset() or "utf-8" + 
try: + text = payload.decode(charset, errors="replace") + except LookupError: + text = payload.decode("utf-8", errors="replace") + + # 如果是 HTML,移除标签 + if "]+>", " ", text) + + texts.append(text) + + # 合并并清理文本 + combined = " ".join(texts) + combined = html_module.unescape(combined) + combined = re.sub(r"\s+", " ", combined).strip() + + return combined + + def test_connection(self) -> bool: + """ + 测试 IMAP 连接 + + Returns: + 连接是否正常 + """ + try: + with self: + self._conn.select("INBOX", readonly=True) + self._conn.search(None, "ALL") + return True + except Exception as e: + logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}") + return False diff --git a/src/services/outlook/service.py b/src/services/outlook/service.py new file mode 100644 index 0000000..b700a29 --- /dev/null +++ b/src/services/outlook/service.py @@ -0,0 +1,485 @@ +""" +Outlook 邮箱服务主类 +支持多种 IMAP/API 连接方式,自动故障切换 +""" + +import logging +import threading +import time +from typing import Optional, Dict, Any, List + +from ..base import BaseEmailService, EmailServiceError, EmailServiceStatus, EmailServiceType +from ...config.constants import EmailServiceType as ServiceType +from ...config.settings import get_settings +from .account import OutlookAccount +from .base import ProviderType, EmailMessage +from .email_parser import EmailParser, get_email_parser +from .health_checker import HealthChecker, FailoverManager +from .providers.base import OutlookProvider, ProviderConfig +from .providers.imap_old import IMAPOldProvider +from .providers.imap_new import IMAPNewProvider +from .providers.graph_api import GraphAPIProvider + + +logger = logging.getLogger(__name__) + + +# 默认提供者优先级 +DEFAULT_PROVIDER_PRIORITY = [ + ProviderType.GRAPH_API, + ProviderType.IMAP_NEW, + ProviderType.IMAP_OLD, +] + + +def get_email_code_settings() -> dict: + """获取验证码等待配置""" + settings = get_settings() + return { + "timeout": settings.email_code_timeout, + "poll_interval": settings.email_code_poll_interval, + } + + +class 
OutlookService(BaseEmailService): + """ + Outlook 邮箱服务 + 支持多种 IMAP/API 连接方式,自动故障切换 + """ + + def __init__(self, config: Dict[str, Any] = None, name: str = None): + """ + 初始化 Outlook 服务 + + Args: + config: 配置字典,支持以下键: + - accounts: Outlook 账户列表 + - provider_priority: 提供者优先级列表 + - health_failure_threshold: 连续失败次数阈值 + - health_disable_duration: 禁用时长(秒) + - timeout: 请求超时时间 + - proxy_url: 代理 URL + name: 服务名称 + """ + super().__init__(ServiceType.OUTLOOK, name) + + # 默认配置 + default_config = { + "accounts": [], + "provider_priority": [p.value for p in DEFAULT_PROVIDER_PRIORITY], + "health_failure_threshold": 5, + "health_disable_duration": 60, + "timeout": 30, + "proxy_url": None, + } + + self.config = {**default_config, **(config or {})} + + # 解析提供者优先级 + self.provider_priority = [ + ProviderType(p) for p in self.config.get("provider_priority", []) + ] + if not self.provider_priority: + self.provider_priority = DEFAULT_PROVIDER_PRIORITY + + # 提供者配置 + self.provider_config = ProviderConfig( + timeout=self.config.get("timeout", 30), + proxy_url=self.config.get("proxy_url"), + health_failure_threshold=self.config.get("health_failure_threshold", 3), + health_disable_duration=self.config.get("health_disable_duration", 300), + ) + + # 获取默认 client_id(供无 client_id 的账户使用) + try: + _default_client_id = get_settings().outlook_default_client_id + except Exception: + _default_client_id = "24d9a0ed-8787-4584-883c-2fd79308940a" + + # 解析账户 + self.accounts: List[OutlookAccount] = [] + self._current_account_index = 0 + self._account_lock = threading.Lock() + + # 支持两种配置格式 + if "email" in self.config and "password" in self.config: + account = OutlookAccount.from_config(self.config) + if not account.client_id and _default_client_id: + account.client_id = _default_client_id + if account.validate(): + self.accounts.append(account) + else: + for account_config in self.config.get("accounts", []): + account = OutlookAccount.from_config(account_config) + if not account.client_id and 
_default_client_id: + account.client_id = _default_client_id + if account.validate(): + self.accounts.append(account) + + if not self.accounts: + logger.warning("未配置有效的 Outlook 账户") + + # 健康检查器和故障切换管理器 + self.health_checker = HealthChecker( + failure_threshold=self.provider_config.health_failure_threshold, + disable_duration=self.provider_config.health_disable_duration, + ) + self.failover_manager = FailoverManager( + health_checker=self.health_checker, + priority_order=self.provider_priority, + ) + + # 邮件解析器 + self.email_parser = get_email_parser() + + # 提供者实例缓存: (email, provider_type) -> OutlookProvider + self._providers: Dict[tuple, OutlookProvider] = {} + self._provider_lock = threading.Lock() + + # IMAP 连接限制(防止限流) + self._imap_semaphore = threading.Semaphore(5) + + # 验证码去重机制 + self._used_codes: Dict[str, set] = {} + + def _get_provider( + self, + account: OutlookAccount, + provider_type: ProviderType, + ) -> OutlookProvider: + """ + 获取或创建提供者实例 + + Args: + account: Outlook 账户 + provider_type: 提供者类型 + + Returns: + 提供者实例 + """ + cache_key = (account.email.lower(), provider_type) + + with self._provider_lock: + if cache_key not in self._providers: + provider = self._create_provider(account, provider_type) + self._providers[cache_key] = provider + + return self._providers[cache_key] + + def _create_provider( + self, + account: OutlookAccount, + provider_type: ProviderType, + ) -> OutlookProvider: + """ + 创建提供者实例 + + Args: + account: Outlook 账户 + provider_type: 提供者类型 + + Returns: + 提供者实例 + """ + if provider_type == ProviderType.IMAP_OLD: + return IMAPOldProvider(account, self.provider_config) + elif provider_type == ProviderType.IMAP_NEW: + return IMAPNewProvider(account, self.provider_config) + elif provider_type == ProviderType.GRAPH_API: + return GraphAPIProvider(account, self.provider_config) + else: + raise ValueError(f"未知的提供者类型: {provider_type}") + + def _get_provider_priority_for_account(self, account: OutlookAccount) -> List[ProviderType]: + """根据账户是否有 
OAuth,返回适合的提供者优先级列表""" + if account.has_oauth(): + return self.provider_priority + else: + # 无 OAuth,直接走旧版 IMAP(密码认证),跳过需要 OAuth 的提供者 + return [ProviderType.IMAP_OLD] + + def _try_providers_for_emails( + self, + account: OutlookAccount, + count: int = 20, + only_unseen: bool = True, + ) -> List[EmailMessage]: + """ + 尝试多个提供者获取邮件 + + Args: + account: Outlook 账户 + count: 获取数量 + only_unseen: 是否只获取未读 + + Returns: + 邮件列表 + """ + errors = [] + + # 根据账户类型选择合适的提供者优先级 + priority = self._get_provider_priority_for_account(account) + + # 按优先级尝试各提供者 + for provider_type in priority: + # 检查提供者是否可用 + if not self.health_checker.is_available(provider_type): + logger.debug( + f"[{account.email}] {provider_type.value} 不可用,跳过" + ) + continue + + try: + provider = self._get_provider(account, provider_type) + + with self._imap_semaphore: + with provider: + emails = provider.get_recent_emails(count, only_unseen) + + if emails: + # 成功获取邮件 + self.health_checker.record_success(provider_type) + logger.debug( + f"[{account.email}] {provider_type.value} 获取到 {len(emails)} 封邮件" + ) + return emails + + except Exception as e: + error_msg = str(e) + errors.append(f"{provider_type.value}: {error_msg}") + self.health_checker.record_failure(provider_type, error_msg) + logger.warning( + f"[{account.email}] {provider_type.value} 获取邮件失败: {e}" + ) + + logger.error( + f"[{account.email}] 所有提供者都失败: {'; '.join(errors)}" + ) + return [] + + def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]: + """ + 选择可用的 Outlook 账户 + + Args: + config: 配置参数(未使用) + + Returns: + 包含邮箱信息的字典 + """ + if not self.accounts: + self.update_status(False, EmailServiceError("没有可用的 Outlook 账户")) + raise EmailServiceError("没有可用的 Outlook 账户") + + # 轮询选择账户 + with self._account_lock: + account = self.accounts[self._current_account_index] + self._current_account_index = (self._current_account_index + 1) % len(self.accounts) + + email_info = { + "email": account.email, + "service_id": account.email, + "account": { + "email": 
account.email, + "has_oauth": account.has_oauth() + } + } + + logger.info(f"选择 Outlook 账户: {account.email}") + self.update_status(True) + return email_info + + def get_verification_code( + self, + email: str, + email_id: str = None, + timeout: int = None, + pattern: str = None, + otp_sent_at: Optional[float] = None, + ) -> Optional[str]: + """ + 从 Outlook 邮箱获取验证码 + + Args: + email: 邮箱地址 + email_id: 未使用 + timeout: 超时时间(秒) + pattern: 验证码正则表达式(未使用) + otp_sent_at: OTP 发送时间戳 + + Returns: + 验证码字符串 + """ + # 查找对应的账户 + account = None + for acc in self.accounts: + if acc.email.lower() == email.lower(): + account = acc + break + + if not account: + self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}")) + return None + + # 获取验证码等待配置 + code_settings = get_email_code_settings() + actual_timeout = timeout or code_settings["timeout"] + poll_interval = code_settings["poll_interval"] + + logger.info( + f"[{email}] 开始获取验证码,超时 {actual_timeout}s," + f"提供者优先级: {[p.value for p in self.provider_priority]}" + ) + + # 初始化验证码去重集合 + if email not in self._used_codes: + self._used_codes[email] = set() + used_codes = self._used_codes[email] + + # 计算最小时间戳(留出 60 秒时钟偏差) + min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0 + + start_time = time.time() + poll_count = 0 + + while time.time() - start_time < actual_timeout: + poll_count += 1 + + # 渐进式邮件检查:前 3 次只检查未读 + only_unseen = poll_count <= 3 + + try: + # 尝试多个提供者获取邮件 + emails = self._try_providers_for_emails( + account, + count=15, + only_unseen=only_unseen, + ) + + if emails: + logger.debug( + f"[{email}] 第 {poll_count} 次轮询获取到 {len(emails)} 封邮件" + ) + + # 从邮件中查找验证码 + code = self.email_parser.find_verification_code_in_emails( + emails, + target_email=email, + min_timestamp=min_timestamp, + used_codes=used_codes, + ) + + if code: + used_codes.add(code) + elapsed = int(time.time() - start_time) + logger.info( + f"[{email}] 找到验证码: {code}," + f"总耗时 {elapsed}s,轮询 {poll_count} 次" + ) + self.update_status(True) + return code + + 
except Exception as e: + logger.warning(f"[{email}] 检查出错: {e}") + + # 等待下次轮询 + time.sleep(poll_interval) + + elapsed = int(time.time() - start_time) + logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次") + return None + + def list_emails(self, **kwargs) -> List[Dict[str, Any]]: + """列出所有可用的 Outlook 账户""" + return [ + { + "email": account.email, + "id": account.email, + "has_oauth": account.has_oauth(), + "type": "outlook" + } + for account in self.accounts + ] + + def delete_email(self, email_id: str) -> bool: + """删除邮箱(Outlook 不支持删除账户)""" + logger.warning(f"Outlook 服务不支持删除账户: {email_id}") + return False + + def check_health(self) -> bool: + """检查 Outlook 服务是否可用""" + if not self.accounts: + self.update_status(False, EmailServiceError("没有配置的账户")) + return False + + # 测试第一个账户的连接 + test_account = self.accounts[0] + + # 尝试任一提供者连接 + for provider_type in self.provider_priority: + try: + provider = self._get_provider(test_account, provider_type) + if provider.test_connection(): + self.update_status(True) + return True + except Exception as e: + logger.warning( + f"Outlook 健康检查失败 ({test_account.email}, {provider_type.value}): {e}" + ) + + self.update_status(False, EmailServiceError("健康检查失败")) + return False + + def get_provider_status(self) -> Dict[str, Any]: + """获取提供者状态""" + return self.failover_manager.get_status() + + def get_account_stats(self) -> Dict[str, Any]: + """获取账户统计信息""" + total = len(self.accounts) + oauth_count = sum(1 for acc in self.accounts if acc.has_oauth()) + + return { + "total_accounts": total, + "oauth_accounts": oauth_count, + "password_accounts": total - oauth_count, + "accounts": [acc.to_dict() for acc in self.accounts], + "provider_status": self.get_provider_status(), + } + + def add_account(self, account_config: Dict[str, Any]) -> bool: + """添加新的 Outlook 账户""" + try: + account = OutlookAccount.from_config(account_config) + if not account.validate(): + return False + + self.accounts.append(account) + logger.info(f"添加 
Outlook 账户: {account.email}") + return True + except Exception as e: + logger.error(f"添加 Outlook 账户失败: {e}") + return False + + def remove_account(self, email: str) -> bool: + """移除 Outlook 账户""" + for i, acc in enumerate(self.accounts): + if acc.email.lower() == email.lower(): + self.accounts.pop(i) + logger.info(f"移除 Outlook 账户: {email}") + return True + return False + + def reset_provider_health(self): + """重置所有提供者的健康状态""" + self.health_checker.reset_all() + logger.info("已重置所有提供者的健康状态") + + def force_provider(self, provider_type: ProviderType): + """强制使用指定的提供者""" + self.health_checker.force_enable(provider_type) + # 禁用其他提供者 + for pt in ProviderType: + if pt != provider_type: + self.health_checker.force_disable(pt, 60) + logger.info(f"已强制使用提供者: {provider_type.value}") diff --git a/src/services/outlook/token_manager.py b/src/services/outlook/token_manager.py new file mode 100644 index 0000000..77e54f2 --- /dev/null +++ b/src/services/outlook/token_manager.py @@ -0,0 +1,239 @@ +""" +Token 管理器 +支持多个 Microsoft Token 端点,自动选择合适的端点 +""" + +import json +import logging +import threading +import time +from typing import Dict, Optional, Any + +from curl_cffi import requests as _requests + +from .base import ProviderType, TokenEndpoint, TokenInfo +from .account import OutlookAccount + + +logger = logging.getLogger(__name__) + + +# 各提供者的 Scope 配置 +PROVIDER_SCOPES = { + ProviderType.IMAP_OLD: "", # 旧版 IMAP 不需要特定 scope + ProviderType.IMAP_NEW: "https://outlook.office.com/IMAP.AccessAsUser.All offline_access", + ProviderType.GRAPH_API: "https://graph.microsoft.com/.default", +} + +# 各提供者的 Token 端点 +PROVIDER_TOKEN_URLS = { + ProviderType.IMAP_OLD: TokenEndpoint.LIVE.value, + ProviderType.IMAP_NEW: TokenEndpoint.CONSUMERS.value, + ProviderType.GRAPH_API: TokenEndpoint.COMMON.value, +} + + +class TokenManager: + """ + Token 管理器 + 支持多端点 Token 获取和缓存 + """ + + # Token 缓存: key = (email, provider_type) -> TokenInfo + _token_cache: Dict[tuple, TokenInfo] = {} + _cache_lock = 
threading.Lock() + + # 默认超时时间 + DEFAULT_TIMEOUT = 30 + # Token 刷新提前时间(秒) + REFRESH_BUFFER = 120 + + def __init__( + self, + account: OutlookAccount, + provider_type: ProviderType, + proxy_url: Optional[str] = None, + timeout: int = DEFAULT_TIMEOUT, + ): + """ + 初始化 Token 管理器 + + Args: + account: Outlook 账户 + provider_type: 提供者类型 + proxy_url: 代理 URL(可选) + timeout: 请求超时时间 + """ + self.account = account + self.provider_type = provider_type + self.proxy_url = proxy_url + self.timeout = timeout + + # 获取端点和 Scope + self.token_url = PROVIDER_TOKEN_URLS.get(provider_type, TokenEndpoint.LIVE.value) + self.scope = PROVIDER_SCOPES.get(provider_type, "") + + def get_cached_token(self) -> Optional[TokenInfo]: + """获取缓存的 Token""" + cache_key = (self.account.email.lower(), self.provider_type) + with self._cache_lock: + token = self._token_cache.get(cache_key) + if token and not token.is_expired(self.REFRESH_BUFFER): + return token + return None + + def set_cached_token(self, token: TokenInfo): + """缓存 Token""" + cache_key = (self.account.email.lower(), self.provider_type) + with self._cache_lock: + self._token_cache[cache_key] = token + + def clear_cache(self): + """清除缓存""" + cache_key = (self.account.email.lower(), self.provider_type) + with self._cache_lock: + self._token_cache.pop(cache_key, None) + + def get_access_token(self, force_refresh: bool = False) -> Optional[str]: + """ + 获取 Access Token + + Args: + force_refresh: 是否强制刷新 + + Returns: + Access Token 字符串,失败返回 None + """ + # 检查缓存 + if not force_refresh: + cached = self.get_cached_token() + if cached: + logger.debug(f"[{self.account.email}] 使用缓存的 Token ({self.provider_type.value})") + return cached.access_token + + # 刷新 Token + try: + token = self._refresh_token() + if token: + self.set_cached_token(token) + return token.access_token + except Exception as e: + logger.error(f"[{self.account.email}] 获取 Token 失败 ({self.provider_type.value}): {e}") + + return None + + def _refresh_token(self) -> Optional[TokenInfo]: + """ + 
刷新 Token + + Returns: + TokenInfo 对象,失败返回 None + """ + if not self.account.client_id or not self.account.refresh_token: + raise ValueError("缺少 client_id 或 refresh_token") + + logger.debug(f"[{self.account.email}] 正在刷新 Token ({self.provider_type.value})...") + logger.debug(f"[{self.account.email}] Token URL: {self.token_url}") + + # 构建请求体 + data = { + "client_id": self.account.client_id, + "refresh_token": self.account.refresh_token, + "grant_type": "refresh_token", + } + + # 添加 Scope(如果需要) + if self.scope: + data["scope"] = self.scope + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + proxies = None + if self.proxy_url: + proxies = {"http": self.proxy_url, "https": self.proxy_url} + + try: + resp = _requests.post( + self.token_url, + data=data, + headers=headers, + proxies=proxies, + timeout=self.timeout, + impersonate="chrome110", + ) + + if resp.status_code != 200: + error_body = resp.text + logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}") + logger.debug(f"[{self.account.email}] 错误响应: {error_body[:500]}") + + if "service abuse" in error_body.lower(): + logger.warning(f"[{self.account.email}] 账号可能被封禁") + elif "invalid_grant" in error_body.lower(): + logger.warning(f"[{self.account.email}] Refresh Token 已失效") + + return None + + response_data = resp.json() + + # 解析响应 + token = TokenInfo.from_response(response_data, self.scope) + logger.info( + f"[{self.account.email}] Token 刷新成功 ({self.provider_type.value}), " + f"有效期 {int(token.expires_at - time.time())} 秒" + ) + return token + + except json.JSONDecodeError as e: + logger.error(f"[{self.account.email}] JSON 解析错误: {e}") + return None + + except Exception as e: + logger.error(f"[{self.account.email}] 未知错误: {e}") + return None + + @classmethod + def clear_all_cache(cls): + """清除所有 Token 缓存""" + with cls._cache_lock: + cls._token_cache.clear() + logger.info("已清除所有 Token 缓存") + + @classmethod + def get_cache_stats(cls) -> Dict[str, 
Any]: + """获取缓存统计""" + with cls._cache_lock: + return { + "cache_size": len(cls._token_cache), + "entries": [ + { + "email": key[0], + "provider": key[1].value, + } + for key in cls._token_cache.keys() + ], + } + + +def create_token_manager( + account: OutlookAccount, + provider_type: ProviderType, + proxy_url: Optional[str] = None, + timeout: int = TokenManager.DEFAULT_TIMEOUT, +) -> TokenManager: + """ + 创建 Token 管理器的工厂函数 + + Args: + account: Outlook 账户 + provider_type: 提供者类型 + proxy_url: 代理 URL + timeout: 超时时间 + + Returns: + TokenManager 实例 + """ + return TokenManager(account, provider_type, proxy_url, timeout) diff --git a/src/services/outlook.py b/src/services/outlook_legacy.py similarity index 100% rename from src/services/outlook.py rename to src/services/outlook_legacy.py diff --git a/src/web/routes/settings.py b/src/web/routes/settings.py index 3f8f6ea..c5fc99f 100644 --- a/src/web/routes/settings.py +++ b/src/web/routes/settings.py @@ -734,3 +734,37 @@ async def test_cpa_connection(request: CPATestRequest): "success": success, "message": message } + + +# ============== Outlook 设置 ============== + +class OutlookSettings(BaseModel): + """Outlook 设置""" + default_client_id: Optional[str] = None + + +@router.get("/outlook") +async def get_outlook_settings(): + """获取 Outlook 设置""" + settings = get_settings() + + return { + "default_client_id": settings.outlook_default_client_id, + "provider_priority": settings.outlook_provider_priority, + "health_failure_threshold": settings.outlook_health_failure_threshold, + "health_disable_duration": settings.outlook_health_disable_duration, + } + + +@router.post("/outlook") +async def update_outlook_settings(request: OutlookSettings): + """更新 Outlook 设置""" + update_dict = {} + + if request.default_client_id is not None: + update_dict["outlook_default_client_id"] = request.default_client_id + + if update_dict: + update_settings(**update_dict) + + return {"success": True, "message": "Outlook 设置已更新"} diff --git 
@router.websocket("/ws/task/{task_uuid}")
async def task_websocket(websocket: WebSocket, task_uuid: str):
    """
    Task log WebSocket.

    Accepts the connection, registers it with the task manager, sends the
    current task status (if any) and the logs returned by
    ``task_manager.get_unsent_logs``, then serves client messages until the
    peer disconnects or a heartbeat cannot be delivered.

    Message formats:
    - server sends: {"type": "log", "task_uuid": "xxx", "message": "..."}
    - server sends: {"type": "status", "task_uuid": "xxx", "status": "...", ...}
    - server sends: {"type": "ping"} after 30 s without client traffic
    - client sends: {"type": "ping"}   -> answered with {"type": "pong"}
    - client sends: {"type": "cancel"} -> requests task cancellation

    Malformed client frames (invalid JSON, or JSON that is not an object)
    are logged and ignored instead of tearing the connection down.
    """
    await websocket.accept()

    # Register first so that logs produced from now on are pushed live.
    task_manager.register_websocket(task_uuid, websocket)
    logger.info(f"WebSocket 连接已建立: {task_uuid}")

    try:
        # Send the current status snapshot, if the task is known.
        status = task_manager.get_status(task_uuid)
        if status:
            await websocket.send_json({
                "type": "status",
                "task_uuid": task_uuid,
                **status
            })

        # Send whatever the manager reports as not yet delivered to this socket.
        for log in task_manager.get_unsent_logs(task_uuid, websocket):
            await websocket.send_json({
                "type": "log",
                "task_uuid": task_uuid,
                "message": log
            })

        # Keep the connection open and serve client messages.
        while True:
            try:
                # A receive timeout is not a disconnect; it only triggers a
                # server-side heartbeat.
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=30.0
                )
            except asyncio.TimeoutError:
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    # The heartbeat could not be sent: treat the peer as gone.
                    logger.info(f"WebSocket 心跳检测失败: {task_uuid}")
                    break
                continue
            except ValueError:
                # json.JSONDecodeError is a ValueError: the frame was not JSON.
                logger.warning(f"WebSocket 收到无效消息,已忽略: {task_uuid}")
                continue

            if not isinstance(data, dict):
                # Valid JSON but not an object, so there is no "type" to dispatch on.
                logger.warning(f"WebSocket 收到无效消息,已忽略: {task_uuid}")
                continue

            message_type = data.get("type")

            if message_type == "ping":
                await websocket.send_json({"type": "pong"})

            elif message_type == "cancel":
                task_manager.cancel_task(task_uuid)
                await websocket.send_json({
                    "type": "status",
                    "task_uuid": task_uuid,
                    "status": "cancelling",
                    "message": "取消请求已提交"
                })

    except WebSocketDisconnect:
        logger.info(f"WebSocket 断开: {task_uuid}")

    except Exception as e:
        # logger.exception keeps the traceback for unexpected failures.
        logger.exception(f"WebSocket 错误: {e}")

    finally:
        # Always drop the registration so broadcasts stop targeting this socket.
        task_manager.unregister_websocket(task_uuid, websocket)
"""
Task manager.

Manages background tasks, their log queues, and WebSocket push.
"""

import asyncio
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, List, Callable, Any
from collections import defaultdict
from datetime import datetime, timezone

logger = logging.getLogger(__name__)

# Global thread pool
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="reg_worker")

# Task log queues (task_uuid -> list of logs) and the lock guarding each queue
_log_queues: Dict[str, List[str]] = defaultdict(list)
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)

# WebSocket connections (connection key -> list of websockets).
# Task sockets are keyed by task_uuid, batch sockets by "batch_<batch_id>".
_ws_connections: Dict[str, List] = defaultdict(list)
_ws_lock = threading.Lock()

# Per-socket count of already-delivered logs (connection key -> {id(websocket): count})
_ws_sent_index: Dict[str, Dict] = defaultdict(dict)

# Task status (task_uuid -> dict)
_task_status: Dict[str, dict] = {}

# Task cancellation flags
_task_cancelled: Dict[str, bool] = {}

# Batch status (batch_id -> dict), batch logs, and the lock guarding each log list
_batch_status: Dict[str, dict] = {}
_batch_logs: Dict[str, List[str]] = defaultdict(list)
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    ``datetime.utcnow()`` is deprecated since Python 3.12. Dropping the
    tzinfo keeps the emitted text identical to what ``utcnow().isoformat()``
    produced (no UTC offset suffix), so the wire format does not change.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


class TaskManager:
    """Task manager.

    All bookkeeping lives in the module-level dictionaries above, so every
    instance operates on the same shared state.
    """

    def __init__(self):
        self.executor = _executor
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Store the event loop used to schedule WebSocket broadcasts."""
        self._loop = loop

    def get_loop(self) -> Optional[asyncio.AbstractEventLoop]:
        """Return the stored event loop, or None if none was set."""
        return self._loop

    # ============== Shared internal helpers ==============

    def _schedule(self, failure_prefix: str, coro_func, *args) -> None:
        """Schedule ``coro_func(*args)`` on the stored loop, if it is running.

        Does nothing when no loop is set or the loop is not running.
        Scheduling failures are logged as ``"<failure_prefix>: <error>"``.
        """
        loop = self._loop
        if not (loop and loop.is_running()):
            return
        coro = coro_func(*args)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            # Close the coroutine so it does not linger as "never awaited".
            coro.close()
            logger.warning(f"{failure_prefix}: {e}")

    async def _broadcast(self, conn_key: str, message: dict, failure_prefix: str,
                         *, count_as_sent: bool = False) -> None:
        """Send ``message`` to every socket registered under ``conn_key``.

        A failed send is logged as ``"<failure_prefix>: <error>"`` and does
        not stop delivery to the remaining sockets. With ``count_as_sent``
        the socket's sent-index is incremented after each successful send,
        provided the socket still has an index entry.
        """
        # Snapshot under the lock; never hold the lock across an await.
        with _ws_lock:
            connections = list(_ws_connections.get(conn_key, ()))

        for ws in connections:
            try:
                await ws.send_json(message)
            except Exception as e:
                logger.warning(f"{failure_prefix}: {e}")
                continue
            if count_as_sent:
                with _ws_lock:
                    per_socket = _ws_sent_index.get(conn_key)
                    if per_socket is not None and id(ws) in per_socket:
                        per_socket[id(ws)] += 1

    def _register(self, conn_key: str, websocket, log_key: str,
                  log_store: Dict[str, List[str]],
                  lock_store: Dict[str, threading.Lock]) -> bool:
        """Register ``websocket`` under ``conn_key``.

        Records the current number of stored logs as the socket's sent-index.
        Returns True when newly registered, False when it was already present.
        """
        with _ws_lock:
            sockets = _ws_connections[conn_key]
            if websocket in sockets:
                return False
            sockets.append(websocket)
            with lock_store[log_key]:
                _ws_sent_index[conn_key][id(websocket)] = len(log_store.get(log_key, ()))
        return True

    def _take_unsent(self, conn_key: str, websocket, log_key: str,
                     log_store: Dict[str, List[str]],
                     lock_store: Dict[str, threading.Lock]) -> List[str]:
        """Return the logs past the socket's sent-index and advance the index.

        The read and the update both happen under ``_ws_lock`` (same lock
        order as ``_register``), so they cannot interleave with broadcast
        increments or with unregistration. The index is only advanced for a
        socket that still has an entry; an unknown socket gets all logs and
        no entry is created for it.
        """
        ws_id = id(websocket)
        with _ws_lock:
            per_socket = _ws_sent_index.get(conn_key, {})
            sent_count = per_socket.get(ws_id, 0)
            with lock_store[log_key]:
                all_logs = log_store.get(log_key, [])
                unsent_logs = all_logs[sent_count:]
                if ws_id in per_socket:
                    per_socket[ws_id] = len(all_logs)
        return unsent_logs

    def _unregister(self, conn_key: str, websocket) -> None:
        """Remove ``websocket`` and its sent-index from ``conn_key``.

        Containers left empty are deleted so finished tasks do not
        accumulate empty entries.
        """
        with _ws_lock:
            sockets = _ws_connections.get(conn_key)
            if sockets is not None:
                try:
                    sockets.remove(websocket)
                except ValueError:
                    pass
                if not sockets:
                    del _ws_connections[conn_key]
            per_socket = _ws_sent_index.get(conn_key)
            if per_socket is not None:
                per_socket.pop(id(websocket), None)
                if not per_socket:
                    del _ws_sent_index[conn_key]

    # ============== Single-task management ==============

    def is_cancelled(self, task_uuid: str) -> bool:
        """Return True if the task has been flagged as cancelled."""
        return _task_cancelled.get(task_uuid, False)

    def cancel_task(self, task_uuid: str):
        """Flag the task as cancelled."""
        _task_cancelled[task_uuid] = True
        logger.info(f"任务 {task_uuid} 已标记为取消")

    def add_log(self, task_uuid: str, log_message: str):
        """Append a log line to the task queue and push it to its WebSockets.

        The broadcast is scheduled on the event loop first (when one is
        running) and the line is appended to the queue afterwards.
        """
        self._schedule("推送日志到 WebSocket 失败",
                       self._broadcast_log, task_uuid, log_message)

        with _log_locks[task_uuid]:
            _log_queues[task_uuid].append(log_message)

    async def _broadcast_log(self, task_uuid: str, log_message: str):
        """Send one log line to every WebSocket registered for the task."""
        await self._broadcast(
            task_uuid,
            {
                "type": "log",
                "task_uuid": task_uuid,
                "message": log_message,
                "timestamp": _utc_timestamp()
            },
            "WebSocket 发送失败",
            count_as_sent=True,
        )

    async def broadcast_status(self, task_uuid: str, status: str, **kwargs):
        """Send a status update to every WebSocket registered for the task.

        Extra keyword arguments are merged into the message last, so they
        override the standard fields on a key clash.
        """
        message = {
            "type": "status",
            "task_uuid": task_uuid,
            "status": status,
            "timestamp": _utc_timestamp(),
            **kwargs
        }
        await self._broadcast(task_uuid, message, "WebSocket 发送状态失败")

    def register_websocket(self, task_uuid: str, websocket):
        """Register a WebSocket for the task; a repeated registration is skipped."""
        if self._register(task_uuid, websocket, task_uuid, _log_queues, _log_locks):
            logger.info(f"WebSocket 连接已注册: {task_uuid}")
        else:
            logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")

    def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
        """Return the task logs not yet delivered to this WebSocket."""
        return self._take_unsent(task_uuid, websocket, task_uuid, _log_queues, _log_locks)

    def unregister_websocket(self, task_uuid: str, websocket):
        """Unregister a WebSocket from the task."""
        self._unregister(task_uuid, websocket)
        logger.info(f"WebSocket 连接已注销: {task_uuid}")

    def get_logs(self, task_uuid: str) -> List[str]:
        """Return a copy of all logs stored for the task."""
        with _log_locks[task_uuid]:
            return _log_queues.get(task_uuid, []).copy()

    def update_status(self, task_uuid: str, status: str, **kwargs):
        """Set the task status and merge any extra fields into its record."""
        record = _task_status.setdefault(task_uuid, {})
        record["status"] = status
        record.update(kwargs)

    def get_status(self, task_uuid: str) -> Optional[dict]:
        """Return the task's status record, or None if the task is unknown."""
        return _task_status.get(task_uuid)

    def cleanup_task(self, task_uuid: str):
        """Clear the task's cancellation flag.

        Logs and status are kept so they can still be queried afterwards.
        """
        _task_cancelled.pop(task_uuid, None)

    # ============== Batch management ==============

    def init_batch(self, batch_id: str, total: int):
        """Create the status record for a batch of ``total`` tasks."""
        _batch_status[batch_id] = {
            "status": "running",
            "total": total,
            "completed": 0,
            "success": 0,
            "failed": 0,
            "skipped": 0,
            "current_index": 0,
            "finished": False
        }
        logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")

    def add_batch_log(self, batch_id: str, log_message: str):
        """Append a log line to the batch and push it to its WebSockets."""
        self._schedule("推送批量日志到 WebSocket 失败",
                       self._broadcast_batch_log, batch_id, log_message)

        with _batch_locks[batch_id]:
            _batch_logs[batch_id].append(log_message)

    async def _broadcast_batch_log(self, batch_id: str, log_message: str):
        """Send one log line to every WebSocket registered for the batch."""
        await self._broadcast(
            f"batch_{batch_id}",
            {
                "type": "log",
                "batch_id": batch_id,
                "message": log_message,
                "timestamp": _utc_timestamp()
            },
            "WebSocket 发送批量日志失败",
            count_as_sent=True,
        )

    def update_batch_status(self, batch_id: str, **kwargs):
        """Merge fields into the batch status record and broadcast the result.

        Unknown batch ids are logged and ignored.
        """
        if batch_id not in _batch_status:
            logger.warning(f"批量任务 {batch_id} 不存在")
            return

        _batch_status[batch_id].update(kwargs)

        self._schedule("广播批量状态失败", self._broadcast_batch_status, batch_id)

    async def _broadcast_batch_status(self, batch_id: str):
        """Send the batch status record to every WebSocket registered for it."""
        status = _batch_status.get(batch_id, {})
        # Status fields are merged last, matching the original key precedence.
        message = {
            "type": "status",
            "batch_id": batch_id,
            "timestamp": _utc_timestamp(),
            **status
        }
        await self._broadcast(f"batch_{batch_id}", message, "WebSocket 发送批量状态失败")

    def get_batch_status(self, batch_id: str) -> Optional[dict]:
        """Return the batch status record, or None if the batch is unknown."""
        return _batch_status.get(batch_id)

    def get_batch_logs(self, batch_id: str) -> List[str]:
        """Return a copy of all logs stored for the batch."""
        with _batch_locks[batch_id]:
            return _batch_logs.get(batch_id, []).copy()

    def is_batch_cancelled(self, batch_id: str) -> bool:
        """Return True if the batch has been flagged as cancelled."""
        status = _batch_status.get(batch_id, {})
        return status.get("cancelled", False)

    def cancel_batch(self, batch_id: str):
        """Flag the batch as cancelled; unknown batch ids are logged and ignored."""
        record = _batch_status.get(batch_id)
        if record is None:
            logger.warning(f"批量任务 {batch_id} 不存在")
            return
        record["cancelled"] = True
        record["status"] = "cancelling"
        logger.info(f"批量任务 {batch_id} 已标记为取消")

    def register_batch_websocket(self, batch_id: str, websocket):
        """Register a WebSocket for the batch; a repeated registration is skipped."""
        if self._register(f"batch_{batch_id}", websocket, batch_id, _batch_logs, _batch_locks):
            logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
        else:
            logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")

    def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
        """Return the batch logs not yet delivered to this WebSocket."""
        return self._take_unsent(f"batch_{batch_id}", websocket, batch_id,
                                 _batch_logs, _batch_locks)

    def unregister_batch_websocket(self, batch_id: str, websocket):
        """Unregister a WebSocket from the batch."""
        self._unregister(f"batch_{batch_id}", websocket)
        logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")

    def create_log_callback(self, task_uuid: str) -> Callable[[str], None]:
        """Return a callable that appends a message to the task's log."""
        def callback(msg: str):
            self.add_log(task_uuid, msg)
        return callback

    def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]:
        """Return a callable reporting whether the task has been cancelled."""
        def callback() -> bool:
            return self.is_cancelled(task_uuid)
        return callback


# Global instance
task_manager = TaskManager()
document.getElementById('outlook-default-client-id').value + }; + try { + await api.post('/settings/outlook', data); + toast.success('Outlook 设置已保存'); + } catch (error) { + toast.error('保存失败: ' + error.message); + } +} + // 测试 CPA 连接 async function handleTestCpa() { const apiUrl = document.getElementById('cpa-api-url').value; diff --git a/templates/settings.html b/templates/settings.html index d2fbfd8..eec2a26 100644 --- a/templates/settings.html +++ b/templates/settings.html @@ -36,6 +36,7 @@