mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-15 04:18:51 +08:00
feat(app): 重构outlook邮箱服务
This commit is contained in:
@@ -332,4 +332,38 @@ TIME_CONSTANTS = {
|
||||
"HOUR": 3600,
|
||||
"DAY": 86400,
|
||||
"WEEK": 604800,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Microsoft/Outlook 相关常量
|
||||
# ============================================================================
|
||||
|
||||
# Microsoft OAuth2 Token 端点
|
||||
MICROSOFT_TOKEN_ENDPOINTS = {
|
||||
# 旧版 IMAP 使用的端点
|
||||
"LIVE": "https://login.live.com/oauth20_token.srf",
|
||||
# 新版 IMAP 使用的端点(需要特定 scope)
|
||||
"CONSUMERS": "https://login.microsoftonline.com/consumers/oauth2/v2.0/token",
|
||||
# Graph API 使用的端点
|
||||
"COMMON": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
}
|
||||
|
||||
# IMAP 服务器配置
|
||||
OUTLOOK_IMAP_SERVERS = {
|
||||
"OLD": "outlook.office365.com", # 旧版 IMAP
|
||||
"NEW": "outlook.live.com", # 新版 IMAP
|
||||
}
|
||||
|
||||
# Microsoft OAuth2 Scopes
|
||||
MICROSOFT_SCOPES = {
|
||||
# 旧版 IMAP 不需要特定 scope
|
||||
"IMAP_OLD": "",
|
||||
# 新版 IMAP 需要的 scope
|
||||
"IMAP_NEW": "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
|
||||
# Graph API 需要的 scope
|
||||
"GRAPH_API": "https://graph.microsoft.com/.default",
|
||||
}
|
||||
|
||||
# Outlook 提供者默认优先级
|
||||
OUTLOOK_PROVIDER_PRIORITY = ["imap_new", "imap_old", "graph_api"]
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any, Type
|
||||
from typing import Optional, Dict, Any, Type, List
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic.types import SecretStr
|
||||
@@ -297,6 +297,32 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
|
||||
category=SettingCategory.EMAIL,
|
||||
description="验证码轮询间隔(秒)"
|
||||
),
|
||||
|
||||
# Outlook 配置
|
||||
"outlook_provider_priority": SettingDefinition(
|
||||
db_key="outlook.provider_priority",
|
||||
default_value=["graph_api", "imap_new", "imap_old"],
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook 提供者优先级"
|
||||
),
|
||||
"outlook_health_failure_threshold": SettingDefinition(
|
||||
db_key="outlook.health_failure_threshold",
|
||||
default_value=5,
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook 提供者连续失败次数阈值"
|
||||
),
|
||||
"outlook_health_disable_duration": SettingDefinition(
|
||||
db_key="outlook.health_disable_duration",
|
||||
default_value=60,
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook 提供者禁用时长(秒)"
|
||||
),
|
||||
"outlook_default_client_id": SettingDefinition(
|
||||
db_key="outlook.default_client_id",
|
||||
default_value="24d9a0ed-8787-4584-883c-2fd79308940a",
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook OAuth 默认 Client ID"
|
||||
),
|
||||
}
|
||||
|
||||
# 属性名到数据库键名的映射(用于向后兼容)
|
||||
@@ -320,6 +346,9 @@ SETTING_TYPES: Dict[str, Type] = {
|
||||
"cpa_enabled": bool,
|
||||
"email_code_timeout": int,
|
||||
"email_code_poll_interval": int,
|
||||
"outlook_provider_priority": list,
|
||||
"outlook_health_failure_threshold": int,
|
||||
"outlook_health_disable_duration": int,
|
||||
}
|
||||
|
||||
# 需要作为 SecretStr 处理的字段
|
||||
@@ -346,6 +375,11 @@ def _convert_value(attr_name: str, value: str) -> Any:
|
||||
return value
|
||||
import json
|
||||
return json.loads(value) if value else {}
|
||||
elif target_type == list:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
import json
|
||||
return json.loads(value) if value else []
|
||||
else:
|
||||
return value
|
||||
|
||||
@@ -356,7 +390,7 @@ def _value_to_string(value: Any) -> str:
|
||||
return value.get_secret_value()
|
||||
elif isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
elif isinstance(value, dict):
|
||||
elif isinstance(value, (dict, list)):
|
||||
import json
|
||||
return json.dumps(value)
|
||||
elif value is None:
|
||||
@@ -533,6 +567,12 @@ class Settings(BaseModel):
|
||||
email_code_timeout: int = 120
|
||||
email_code_poll_interval: int = 3
|
||||
|
||||
# Outlook 配置
|
||||
outlook_provider_priority: List[str] = ["graph_api", "imap_new", "imap_old"]
|
||||
outlook_health_failure_threshold: int = 5
|
||||
outlook_health_disable_duration: int = 60
|
||||
outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a"
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
_settings: Optional[Settings] = None
|
||||
|
||||
@@ -19,14 +19,43 @@ EmailServiceFactory.register(EmailServiceType.TEMPMAIL, TempmailService)
|
||||
EmailServiceFactory.register(EmailServiceType.OUTLOOK, OutlookService)
|
||||
EmailServiceFactory.register(EmailServiceType.CUSTOM_DOMAIN, CustomDomainEmailService)
|
||||
|
||||
# 导出 Outlook 模块的额外内容
|
||||
from .outlook.base import (
|
||||
ProviderType,
|
||||
EmailMessage,
|
||||
TokenInfo,
|
||||
ProviderHealth,
|
||||
ProviderStatus,
|
||||
)
|
||||
from .outlook.account import OutlookAccount
|
||||
from .outlook.providers import (
|
||||
OutlookProvider,
|
||||
IMAPOldProvider,
|
||||
IMAPNewProvider,
|
||||
GraphAPIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 基类
|
||||
'BaseEmailService',
|
||||
'EmailServiceError',
|
||||
'EmailServiceStatus',
|
||||
'EmailServiceFactory',
|
||||
'create_email_service',
|
||||
'EmailServiceType',
|
||||
# 服务类
|
||||
'TempmailService',
|
||||
'OutlookService',
|
||||
'CustomDomainEmailService',
|
||||
# Outlook 模块
|
||||
'ProviderType',
|
||||
'EmailMessage',
|
||||
'TokenInfo',
|
||||
'ProviderHealth',
|
||||
'ProviderStatus',
|
||||
'OutlookAccount',
|
||||
'OutlookProvider',
|
||||
'IMAPOldProvider',
|
||||
'IMAPNewProvider',
|
||||
'GraphAPIProvider',
|
||||
]
|
||||
8
src/services/outlook/__init__.py
Normal file
8
src/services/outlook/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Outlook 邮箱服务模块
|
||||
支持多种 IMAP/API 连接方式,自动故障切换
|
||||
"""
|
||||
|
||||
from .service import OutlookService
|
||||
|
||||
__all__ = ['OutlookService']
|
||||
51
src/services/outlook/account.py
Normal file
51
src/services/outlook/account.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Outlook 账户数据类
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutlookAccount:
|
||||
"""Outlook 账户信息"""
|
||||
email: str
|
||||
password: str = ""
|
||||
client_id: str = ""
|
||||
refresh_token: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount":
|
||||
"""从配置创建账户"""
|
||||
return cls(
|
||||
email=config.get("email", ""),
|
||||
password=config.get("password", ""),
|
||||
client_id=config.get("client_id", ""),
|
||||
refresh_token=config.get("refresh_token", "")
|
||||
)
|
||||
|
||||
def has_oauth(self) -> bool:
|
||||
"""是否支持 OAuth2"""
|
||||
return bool(self.client_id and self.refresh_token)
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证账户信息是否有效"""
|
||||
return bool(self.email and self.password) or self.has_oauth()
|
||||
|
||||
def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"email": self.email,
|
||||
"has_oauth": self.has_oauth(),
|
||||
}
|
||||
if include_sensitive:
|
||||
result.update({
|
||||
"password": self.password,
|
||||
"client_id": self.client_id,
|
||||
"refresh_token": self.refresh_token[:20] + "..." if self.refresh_token else "",
|
||||
})
|
||||
return result
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"OutlookAccount({self.email})"
|
||||
153
src/services/outlook/base.py
Normal file
153
src/services/outlook/base.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Outlook 服务基础定义
|
||||
包含枚举类型和数据类
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
"""Outlook 提供者类型"""
|
||||
IMAP_OLD = "imap_old" # 旧版 IMAP (outlook.office365.com)
|
||||
IMAP_NEW = "imap_new" # 新版 IMAP (outlook.live.com)
|
||||
GRAPH_API = "graph_api" # Microsoft Graph API
|
||||
|
||||
|
||||
class TokenEndpoint(str, Enum):
|
||||
"""Token 端点"""
|
||||
LIVE = "https://login.live.com/oauth20_token.srf"
|
||||
CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
|
||||
COMMON = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
|
||||
class IMAPServer(str, Enum):
|
||||
"""IMAP 服务器"""
|
||||
OLD = "outlook.office365.com"
|
||||
NEW = "outlook.live.com"
|
||||
|
||||
|
||||
class ProviderStatus(str, Enum):
|
||||
"""提供者状态"""
|
||||
HEALTHY = "healthy" # 健康
|
||||
DEGRADED = "degraded" # 降级
|
||||
DISABLED = "disabled" # 禁用
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailMessage:
|
||||
"""邮件消息数据类"""
|
||||
id: str # 消息 ID
|
||||
subject: str # 主题
|
||||
sender: str # 发件人
|
||||
recipients: List[str] = field(default_factory=list) # 收件人列表
|
||||
body: str = "" # 正文内容
|
||||
body_preview: str = "" # 正文预览
|
||||
received_at: Optional[datetime] = None # 接收时间
|
||||
received_timestamp: int = 0 # 接收时间戳
|
||||
is_read: bool = False # 是否已读
|
||||
has_attachments: bool = False # 是否有附件
|
||||
raw_data: Optional[bytes] = None # 原始数据(用于调试)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"subject": self.subject,
|
||||
"sender": self.sender,
|
||||
"recipients": self.recipients,
|
||||
"body": self.body,
|
||||
"body_preview": self.body_preview,
|
||||
"received_at": self.received_at.isoformat() if self.received_at else None,
|
||||
"received_timestamp": self.received_timestamp,
|
||||
"is_read": self.is_read,
|
||||
"has_attachments": self.has_attachments,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenInfo:
|
||||
"""Token 信息数据类"""
|
||||
access_token: str
|
||||
expires_at: float # 过期时间戳
|
||||
token_type: str = "Bearer"
|
||||
scope: str = ""
|
||||
refresh_token: Optional[str] = None
|
||||
|
||||
def is_expired(self, buffer_seconds: int = 120) -> bool:
|
||||
"""检查 Token 是否已过期"""
|
||||
import time
|
||||
return time.time() >= (self.expires_at - buffer_seconds)
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo":
|
||||
"""从 API 响应创建"""
|
||||
import time
|
||||
return cls(
|
||||
access_token=data.get("access_token", ""),
|
||||
expires_at=time.time() + data.get("expires_in", 3600),
|
||||
token_type=data.get("token_type", "Bearer"),
|
||||
scope=scope or data.get("scope", ""),
|
||||
refresh_token=data.get("refresh_token"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderHealth:
|
||||
"""提供者健康状态"""
|
||||
provider_type: ProviderType
|
||||
status: ProviderStatus = ProviderStatus.HEALTHY
|
||||
failure_count: int = 0 # 连续失败次数
|
||||
last_success: Optional[datetime] = None # 最后成功时间
|
||||
last_failure: Optional[datetime] = None # 最后失败时间
|
||||
last_error: str = "" # 最后错误信息
|
||||
disabled_until: Optional[datetime] = None # 禁用截止时间
|
||||
|
||||
def record_success(self):
|
||||
"""记录成功"""
|
||||
self.status = ProviderStatus.HEALTHY
|
||||
self.failure_count = 0
|
||||
self.last_success = datetime.now()
|
||||
self.disabled_until = None
|
||||
|
||||
def record_failure(self, error: str):
|
||||
"""记录失败"""
|
||||
self.failure_count += 1
|
||||
self.last_failure = datetime.now()
|
||||
self.last_error = error
|
||||
|
||||
def should_disable(self, threshold: int = 3) -> bool:
|
||||
"""判断是否应该禁用"""
|
||||
return self.failure_count >= threshold
|
||||
|
||||
def is_disabled(self) -> bool:
|
||||
"""检查是否被禁用"""
|
||||
if self.disabled_until and datetime.now() < self.disabled_until:
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self, duration_seconds: int = 300):
|
||||
"""禁用提供者"""
|
||||
from datetime import timedelta
|
||||
self.status = ProviderStatus.DISABLED
|
||||
self.disabled_until = datetime.now() + timedelta(seconds=duration_seconds)
|
||||
|
||||
def enable(self):
|
||||
"""启用提供者"""
|
||||
self.status = ProviderStatus.HEALTHY
|
||||
self.disabled_until = None
|
||||
self.failure_count = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"provider_type": self.provider_type.value,
|
||||
"status": self.status.value,
|
||||
"failure_count": self.failure_count,
|
||||
"last_success": self.last_success.isoformat() if self.last_success else None,
|
||||
"last_failure": self.last_failure.isoformat() if self.last_failure else None,
|
||||
"last_error": self.last_error,
|
||||
"disabled_until": self.disabled_until.isoformat() if self.disabled_until else None,
|
||||
}
|
||||
228
src/services/outlook/email_parser.py
Normal file
228
src/services/outlook/email_parser.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
邮件解析和验证码提取
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from ...config.constants import (
|
||||
OTP_CODE_SIMPLE_PATTERN,
|
||||
OTP_CODE_SEMANTIC_PATTERN,
|
||||
OPENAI_EMAIL_SENDERS,
|
||||
OPENAI_VERIFICATION_KEYWORDS,
|
||||
)
|
||||
from .base import EmailMessage
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailParser:
|
||||
"""
|
||||
邮件解析器
|
||||
用于识别 OpenAI 验证邮件并提取验证码
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 编译正则表达式
|
||||
self._simple_pattern = re.compile(OTP_CODE_SIMPLE_PATTERN)
|
||||
self._semantic_pattern = re.compile(OTP_CODE_SEMANTIC_PATTERN, re.IGNORECASE)
|
||||
|
||||
def is_openai_verification_email(
|
||||
self,
|
||||
email: EmailMessage,
|
||||
target_email: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
判断是否为 OpenAI 验证邮件
|
||||
|
||||
Args:
|
||||
email: 邮件对象
|
||||
target_email: 目标邮箱地址(用于验证收件人)
|
||||
|
||||
Returns:
|
||||
是否为 OpenAI 验证邮件
|
||||
"""
|
||||
sender = email.sender.lower()
|
||||
|
||||
# 1. 发件人必须是 OpenAI
|
||||
if not any(s in sender for s in OPENAI_EMAIL_SENDERS):
|
||||
logger.debug(f"邮件发件人非 OpenAI: {sender}")
|
||||
return False
|
||||
|
||||
# 2. 主题或正文包含验证关键词
|
||||
subject = email.subject.lower()
|
||||
body = email.body.lower()
|
||||
combined = f"{subject} {body}"
|
||||
|
||||
if not any(kw in combined for kw in OPENAI_VERIFICATION_KEYWORDS):
|
||||
logger.debug(f"邮件未包含验证关键词: {subject[:50]}")
|
||||
return False
|
||||
|
||||
# 3. 收件人检查已移除:别名邮件的 IMAP 头中收件人可能不匹配,只靠发件人+关键词判断
|
||||
logger.debug(f"识别为 OpenAI 验证邮件: {subject[:50]}")
|
||||
return True
|
||||
|
||||
def extract_verification_code(
|
||||
self,
|
||||
email: EmailMessage,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从邮件中提取验证码
|
||||
|
||||
优先级:
|
||||
1. 从主题提取(6位数字)
|
||||
2. 从正文用语义正则提取(如 "code is 123456")
|
||||
3. 兜底:任意 6 位数字
|
||||
|
||||
Args:
|
||||
email: 邮件对象
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果未找到返回 None
|
||||
"""
|
||||
# 1. 主题优先
|
||||
code = self._extract_from_subject(email.subject)
|
||||
if code:
|
||||
logger.debug(f"从主题提取验证码: {code}")
|
||||
return code
|
||||
|
||||
# 2. 正文语义匹配
|
||||
code = self._extract_semantic(email.body)
|
||||
if code:
|
||||
logger.debug(f"从正文语义提取验证码: {code}")
|
||||
return code
|
||||
|
||||
# 3. 兜底:正文任意 6 位数字
|
||||
code = self._extract_simple(email.body)
|
||||
if code:
|
||||
logger.debug(f"从正文兜底提取验证码: {code}")
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
def _extract_from_subject(self, subject: str) -> Optional[str]:
|
||||
"""从主题提取验证码"""
|
||||
match = self._simple_pattern.search(subject)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def _extract_semantic(self, body: str) -> Optional[str]:
|
||||
"""语义匹配提取验证码"""
|
||||
match = self._semantic_pattern.search(body)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def _extract_simple(self, body: str) -> Optional[str]:
|
||||
"""简单匹配提取验证码"""
|
||||
match = self._simple_pattern.search(body)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def find_verification_code_in_emails(
|
||||
self,
|
||||
emails: List[EmailMessage],
|
||||
target_email: Optional[str] = None,
|
||||
min_timestamp: int = 0,
|
||||
used_codes: Optional[set] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从邮件列表中查找验证码
|
||||
|
||||
Args:
|
||||
emails: 邮件列表
|
||||
target_email: 目标邮箱地址
|
||||
min_timestamp: 最小时间戳(用于过滤旧邮件)
|
||||
used_codes: 已使用的验证码集合(用于去重)
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果未找到返回 None
|
||||
"""
|
||||
used_codes = used_codes or set()
|
||||
|
||||
for email in emails:
|
||||
# 时间戳过滤
|
||||
if min_timestamp > 0 and email.received_timestamp > 0:
|
||||
if email.received_timestamp < min_timestamp:
|
||||
logger.debug(f"跳过旧邮件: {email.subject[:50]}")
|
||||
continue
|
||||
|
||||
# 检查是否是 OpenAI 验证邮件
|
||||
if not self.is_openai_verification_email(email, target_email):
|
||||
continue
|
||||
|
||||
# 提取验证码
|
||||
code = self.extract_verification_code(email)
|
||||
if code:
|
||||
# 去重检查
|
||||
if code in used_codes:
|
||||
logger.debug(f"跳过已使用的验证码: {code}")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[{target_email or 'unknown'}] 找到验证码: {code}, "
|
||||
f"邮件主题: {email.subject[:30]}"
|
||||
)
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
def filter_emails_by_sender(
|
||||
self,
|
||||
emails: List[EmailMessage],
|
||||
sender_patterns: List[str],
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
按发件人过滤邮件
|
||||
|
||||
Args:
|
||||
emails: 邮件列表
|
||||
sender_patterns: 发件人匹配模式列表
|
||||
|
||||
Returns:
|
||||
过滤后的邮件列表
|
||||
"""
|
||||
filtered = []
|
||||
for email in emails:
|
||||
sender = email.sender.lower()
|
||||
if any(pattern.lower() in sender for pattern in sender_patterns):
|
||||
filtered.append(email)
|
||||
return filtered
|
||||
|
||||
def filter_emails_by_subject(
|
||||
self,
|
||||
emails: List[EmailMessage],
|
||||
keywords: List[str],
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
按主题关键词过滤邮件
|
||||
|
||||
Args:
|
||||
emails: 邮件列表
|
||||
keywords: 关键词列表
|
||||
|
||||
Returns:
|
||||
过滤后的邮件列表
|
||||
"""
|
||||
filtered = []
|
||||
for email in emails:
|
||||
subject = email.subject.lower()
|
||||
if any(kw.lower() in subject for kw in keywords):
|
||||
filtered.append(email)
|
||||
return filtered
|
||||
|
||||
|
||||
# 全局解析器实例
|
||||
_parser: Optional[EmailParser] = None
|
||||
|
||||
|
||||
def get_email_parser() -> EmailParser:
|
||||
"""获取全局邮件解析器实例"""
|
||||
global _parser
|
||||
if _parser is None:
|
||||
_parser = EmailParser()
|
||||
return _parser
|
||||
312
src/services/outlook/health_checker.py
Normal file
312
src/services/outlook/health_checker.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
健康检查和故障切换管理
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from .base import ProviderType, ProviderHealth, ProviderStatus
|
||||
from .providers.base import OutlookProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthChecker:
|
||||
"""
|
||||
健康检查管理器
|
||||
跟踪各提供者的健康状态,管理故障切换
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 3,
|
||||
disable_duration: int = 300,
|
||||
recovery_check_interval: int = 60,
|
||||
):
|
||||
"""
|
||||
初始化健康检查器
|
||||
|
||||
Args:
|
||||
failure_threshold: 连续失败次数阈值,超过后禁用
|
||||
disable_duration: 禁用时长(秒)
|
||||
recovery_check_interval: 恢复检查间隔(秒)
|
||||
"""
|
||||
self.failure_threshold = failure_threshold
|
||||
self.disable_duration = disable_duration
|
||||
self.recovery_check_interval = recovery_check_interval
|
||||
|
||||
# 提供者健康状态: ProviderType -> ProviderHealth
|
||||
self._health_status: Dict[ProviderType, ProviderHealth] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 初始化所有提供者的健康状态
|
||||
for provider_type in ProviderType:
|
||||
self._health_status[provider_type] = ProviderHealth(
|
||||
provider_type=provider_type
|
||||
)
|
||||
|
||||
def get_health(self, provider_type: ProviderType) -> ProviderHealth:
|
||||
"""获取提供者的健康状态"""
|
||||
with self._lock:
|
||||
return self._health_status.get(provider_type, ProviderHealth(provider_type=provider_type))
|
||||
|
||||
def record_success(self, provider_type: ProviderType):
|
||||
"""记录成功操作"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.record_success()
|
||||
logger.debug(f"{provider_type.value} 记录成功")
|
||||
|
||||
def record_failure(self, provider_type: ProviderType, error: str):
|
||||
"""记录失败操作"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.record_failure(error)
|
||||
|
||||
# 检查是否需要禁用
|
||||
if health.should_disable(self.failure_threshold):
|
||||
health.disable(self.disable_duration)
|
||||
logger.warning(
|
||||
f"{provider_type.value} 已禁用 {self.disable_duration} 秒,"
|
||||
f"原因: {error}"
|
||||
)
|
||||
|
||||
def is_available(self, provider_type: ProviderType) -> bool:
|
||||
"""
|
||||
检查提供者是否可用
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
|
||||
Returns:
|
||||
是否可用
|
||||
"""
|
||||
health = self.get_health(provider_type)
|
||||
|
||||
# 检查是否被禁用
|
||||
if health.is_disabled():
|
||||
remaining = (health.disabled_until - datetime.now()).total_seconds()
|
||||
logger.debug(
|
||||
f"{provider_type.value} 已被禁用,剩余 {int(remaining)} 秒"
|
||||
)
|
||||
return False
|
||||
|
||||
return health.status != ProviderStatus.DISABLED
|
||||
|
||||
def get_available_providers(
|
||||
self,
|
||||
priority_order: Optional[List[ProviderType]] = None,
|
||||
) -> List[ProviderType]:
|
||||
"""
|
||||
获取可用的提供者列表
|
||||
|
||||
Args:
|
||||
priority_order: 优先级顺序,默认为 [IMAP_NEW, IMAP_OLD, GRAPH_API]
|
||||
|
||||
Returns:
|
||||
可用的提供者列表
|
||||
"""
|
||||
if priority_order is None:
|
||||
priority_order = [
|
||||
ProviderType.IMAP_NEW,
|
||||
ProviderType.IMAP_OLD,
|
||||
ProviderType.GRAPH_API,
|
||||
]
|
||||
|
||||
available = []
|
||||
for provider_type in priority_order:
|
||||
if self.is_available(provider_type):
|
||||
available.append(provider_type)
|
||||
|
||||
return available
|
||||
|
||||
def get_next_available_provider(
|
||||
self,
|
||||
priority_order: Optional[List[ProviderType]] = None,
|
||||
) -> Optional[ProviderType]:
|
||||
"""
|
||||
获取下一个可用的提供者
|
||||
|
||||
Args:
|
||||
priority_order: 优先级顺序
|
||||
|
||||
Returns:
|
||||
可用的提供者类型,如果没有返回 None
|
||||
"""
|
||||
available = self.get_available_providers(priority_order)
|
||||
return available[0] if available else None
|
||||
|
||||
def force_disable(self, provider_type: ProviderType, duration: Optional[int] = None):
|
||||
"""
|
||||
强制禁用提供者
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
duration: 禁用时长(秒),默认使用配置值
|
||||
"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.disable(duration or self.disable_duration)
|
||||
logger.warning(f"{provider_type.value} 已强制禁用")
|
||||
|
||||
def force_enable(self, provider_type: ProviderType):
|
||||
"""
|
||||
强制启用提供者
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.enable()
|
||||
logger.info(f"{provider_type.value} 已启用")
|
||||
|
||||
def get_all_health_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取所有提供者的健康状态
|
||||
|
||||
Returns:
|
||||
健康状态字典
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
provider_type.value: health.to_dict()
|
||||
for provider_type, health in self._health_status.items()
|
||||
}
|
||||
|
||||
def check_and_recover(self):
|
||||
"""
|
||||
检查并恢复被禁用的提供者
|
||||
|
||||
如果禁用时间已过,自动恢复提供者
|
||||
"""
|
||||
with self._lock:
|
||||
for provider_type, health in self._health_status.items():
|
||||
if health.is_disabled():
|
||||
# 检查是否可以恢复
|
||||
if health.disabled_until and datetime.now() >= health.disabled_until:
|
||||
health.enable()
|
||||
logger.info(f"{provider_type.value} 已自动恢复")
|
||||
|
||||
def reset_all(self):
|
||||
"""重置所有提供者的健康状态"""
|
||||
with self._lock:
|
||||
for provider_type in ProviderType:
|
||||
self._health_status[provider_type] = ProviderHealth(
|
||||
provider_type=provider_type
|
||||
)
|
||||
logger.info("已重置所有提供者的健康状态")
|
||||
|
||||
|
||||
class FailoverManager:
|
||||
"""
|
||||
故障切换管理器
|
||||
管理提供者之间的自动切换
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
health_checker: HealthChecker,
|
||||
priority_order: Optional[List[ProviderType]] = None,
|
||||
):
|
||||
"""
|
||||
初始化故障切换管理器
|
||||
|
||||
Args:
|
||||
health_checker: 健康检查器
|
||||
priority_order: 提供者优先级顺序
|
||||
"""
|
||||
self.health_checker = health_checker
|
||||
self.priority_order = priority_order or [
|
||||
ProviderType.IMAP_NEW,
|
||||
ProviderType.IMAP_OLD,
|
||||
ProviderType.GRAPH_API,
|
||||
]
|
||||
|
||||
# 当前使用的提供者索引
|
||||
self._current_index = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_current_provider(self) -> Optional[ProviderType]:
|
||||
"""
|
||||
获取当前提供者
|
||||
|
||||
Returns:
|
||||
当前提供者类型,如果没有可用的返回 None
|
||||
"""
|
||||
available = self.health_checker.get_available_providers(self.priority_order)
|
||||
if not available:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
# 尝试使用当前索引
|
||||
if self._current_index < len(available):
|
||||
return available[self._current_index]
|
||||
return available[0]
|
||||
|
||||
def switch_to_next(self) -> Optional[ProviderType]:
|
||||
"""
|
||||
切换到下一个提供者
|
||||
|
||||
Returns:
|
||||
下一个提供者类型,如果没有可用的返回 None
|
||||
"""
|
||||
available = self.health_checker.get_available_providers(self.priority_order)
|
||||
if not available:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
self._current_index = (self._current_index + 1) % len(available)
|
||||
next_provider = available[self._current_index]
|
||||
logger.info(f"切换到提供者: {next_provider.value}")
|
||||
return next_provider
|
||||
|
||||
def on_provider_success(self, provider_type: ProviderType):
|
||||
"""
|
||||
提供者成功时调用
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
"""
|
||||
self.health_checker.record_success(provider_type)
|
||||
|
||||
# 重置索引到成功的提供者
|
||||
with self._lock:
|
||||
available = self.health_checker.get_available_providers(self.priority_order)
|
||||
if provider_type in available:
|
||||
self._current_index = available.index(provider_type)
|
||||
|
||||
def on_provider_failure(self, provider_type: ProviderType, error: str):
|
||||
"""
|
||||
提供者失败时调用
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
error: 错误信息
|
||||
"""
|
||||
self.health_checker.record_failure(provider_type, error)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取故障切换状态
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
current = self.get_current_provider()
|
||||
return {
|
||||
"current_provider": current.value if current else None,
|
||||
"priority_order": [p.value for p in self.priority_order],
|
||||
"available_providers": [
|
||||
p.value for p in self.health_checker.get_available_providers(self.priority_order)
|
||||
],
|
||||
"health_status": self.health_checker.get_all_health_status(),
|
||||
}
|
||||
29
src/services/outlook/providers/__init__.py
Normal file
29
src/services/outlook/providers/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Outlook 提供者模块
|
||||
"""
|
||||
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
from .imap_old import IMAPOldProvider
|
||||
from .imap_new import IMAPNewProvider
|
||||
from .graph_api import GraphAPIProvider
|
||||
|
||||
__all__ = [
|
||||
'OutlookProvider',
|
||||
'ProviderConfig',
|
||||
'IMAPOldProvider',
|
||||
'IMAPNewProvider',
|
||||
'GraphAPIProvider',
|
||||
]
|
||||
|
||||
|
||||
# 提供者注册表
|
||||
PROVIDER_REGISTRY = {
|
||||
'imap_old': IMAPOldProvider,
|
||||
'imap_new': IMAPNewProvider,
|
||||
'graph_api': GraphAPIProvider,
|
||||
}
|
||||
|
||||
|
||||
def get_provider_class(provider_type: str):
|
||||
"""获取提供者类"""
|
||||
return PROVIDER_REGISTRY.get(provider_type)
|
||||
180
src/services/outlook/providers/base.py
Normal file
180
src/services/outlook/providers/base.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Outlook 提供者抽象基类
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus
|
||||
from ..account import OutlookAccount
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""提供者配置"""
|
||||
timeout: int = 30
|
||||
max_retries: int = 3
|
||||
proxy_url: Optional[str] = None
|
||||
|
||||
# 健康检查配置
|
||||
health_failure_threshold: int = 3
|
||||
health_disable_duration: int = 300 # 秒
|
||||
|
||||
|
||||
class OutlookProvider(abc.ABC):
|
||||
"""
|
||||
Outlook 提供者抽象基类
|
||||
定义所有提供者必须实现的接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化提供者
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
config: 提供者配置
|
||||
"""
|
||||
self.account = account
|
||||
self.config = config or ProviderConfig()
|
||||
|
||||
# 健康状态
|
||||
self._health = ProviderHealth(provider_type=self.provider_type)
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._last_error: Optional[str] = None
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def provider_type(self) -> ProviderType:
|
||||
"""获取提供者类型"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def health(self) -> ProviderHealth:
|
||||
"""获取健康状态"""
|
||||
return self._health
|
||||
|
||||
@property
|
||||
def is_healthy(self) -> bool:
|
||||
"""检查是否健康"""
|
||||
return (
|
||||
self._health.status == ProviderStatus.HEALTHY
|
||||
and not self._health.is_disabled()
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到服务
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
测试连接是否正常
|
||||
|
||||
Returns:
|
||||
连接是否正常
|
||||
"""
|
||||
pass
|
||||
|
||||
def record_success(self):
|
||||
"""记录成功操作"""
|
||||
self._health.record_success()
|
||||
self._last_error = None
|
||||
logger.debug(f"[{self.account.email}] {self.provider_type.value} 操作成功")
|
||||
|
||||
def record_failure(self, error: str):
|
||||
"""记录失败操作"""
|
||||
self._health.record_failure(error)
|
||||
self._last_error = error
|
||||
|
||||
# 检查是否需要禁用
|
||||
if self._health.should_disable(self.config.health_failure_threshold):
|
||||
self._health.disable(self.config.health_disable_duration)
|
||||
logger.warning(
|
||||
f"[{self.account.email}] {self.provider_type.value} 已禁用 "
|
||||
f"{self.config.health_disable_duration} 秒,原因: {error}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.account.email}] {self.provider_type.value} 操作失败 "
|
||||
f"({self._health.failure_count}/{self.config.health_failure_threshold}): {error}"
|
||||
)
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""
|
||||
检查健康状态
|
||||
|
||||
Returns:
|
||||
是否健康可用
|
||||
"""
|
||||
# 检查是否被禁用
|
||||
if self._health.is_disabled():
|
||||
logger.debug(
|
||||
f"[{self.account.email}] {self.provider_type.value} 已被禁用,"
|
||||
f"将在 {self._health.disabled_until} 后恢复"
|
||||
)
|
||||
return False
|
||||
|
||||
return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED)
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""上下文管理器出口"""
|
||||
self.disconnect()
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"{self.__class__.__name__}({self.account.email})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
249
src/services/outlook/providers/graph_api.py
Normal file
249
src/services/outlook/providers/graph_api.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
Graph API 提供者
|
||||
使用 Microsoft Graph REST API
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from curl_cffi import requests as _requests
|
||||
|
||||
from ..base import ProviderType, EmailMessage
|
||||
from ..account import OutlookAccount
|
||||
from ..token_manager import TokenManager
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphAPIProvider(OutlookProvider):
|
||||
"""
|
||||
Graph API 提供者
|
||||
使用 Microsoft Graph REST API 获取邮件
|
||||
需要 graph.microsoft.com/.default scope
|
||||
"""
|
||||
|
||||
# Graph API 端点
|
||||
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
||||
MESSAGES_ENDPOINT = "/me/mailFolders/inbox/messages"
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ProviderType:
|
||||
return ProviderType.GRAPH_API
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
super().__init__(account, config)
|
||||
|
||||
# Token 管理器
|
||||
self._token_manager: Optional[TokenManager] = None
|
||||
|
||||
# 注意:Graph API 必须使用 OAuth2
|
||||
if not account.has_oauth():
|
||||
logger.warning(
|
||||
f"[{self.account.email}] Graph API 提供者需要 OAuth2 配置 "
|
||||
f"(client_id + refresh_token)"
|
||||
)
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
验证连接(获取 Token)
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if not self.account.has_oauth():
|
||||
error = "Graph API 需要 OAuth2 配置"
|
||||
self.record_failure(error)
|
||||
logger.error(f"[{self.account.email}] {error}")
|
||||
return False
|
||||
|
||||
if not self._token_manager:
|
||||
self._token_manager = TokenManager(
|
||||
self.account,
|
||||
ProviderType.GRAPH_API,
|
||||
self.config.proxy_url,
|
||||
self.config.timeout,
|
||||
)
|
||||
|
||||
# 尝试获取 Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if token:
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] Graph API 连接成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接(清除状态)"""
|
||||
self._connected = False
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
if not self._connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 获取 Access Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if not token:
|
||||
self.record_failure("无法获取 Access Token")
|
||||
return []
|
||||
|
||||
# 构建 API 请求
|
||||
url = f"{self.GRAPH_API_BASE}{self.MESSAGES_ENDPOINT}"
|
||||
|
||||
params = {
|
||||
"$top": count,
|
||||
"$select": "id,subject,from,toRecipients,receivedDateTime,isRead,hasAttachments,bodyPreview,body",
|
||||
"$orderby": "receivedDateTime desc",
|
||||
}
|
||||
|
||||
# 只获取未读邮件
|
||||
if only_unseen:
|
||||
params["$filter"] = "isRead eq false"
|
||||
|
||||
# 构建代理配置
|
||||
proxies = None
|
||||
if self.config.proxy_url:
|
||||
proxies = {"http": self.config.proxy_url, "https": self.config.proxy_url}
|
||||
|
||||
# 发送请求(curl_cffi 自动对 params 进行 URL 编码)
|
||||
resp = _requests.get(
|
||||
url,
|
||||
params=params,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/json",
|
||||
"Prefer": "outlook.body-content-type='text'",
|
||||
},
|
||||
proxies=proxies,
|
||||
timeout=self.config.timeout,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
# Token 失效,清除缓存
|
||||
if self._token_manager:
|
||||
self._token_manager.clear_cache()
|
||||
self.record_failure(f"HTTP 401: Token 失效")
|
||||
logger.error(f"[{self.account.email}] Graph API Token 失效")
|
||||
return []
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_body = resp.text[:200]
|
||||
self.record_failure(f"HTTP {resp.status_code}: {error_body}")
|
||||
logger.error(f"[{self.account.email}] Graph API 请求失败: HTTP {resp.status_code}")
|
||||
return []
|
||||
|
||||
data = resp.json()
|
||||
|
||||
# 解析邮件
|
||||
messages = data.get("value", [])
|
||||
emails = []
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
email_msg = self._parse_graph_message(msg)
|
||||
if email_msg:
|
||||
emails.append(email_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 解析 Graph API 邮件失败: {e}")
|
||||
|
||||
self.record_success()
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] Graph API 获取邮件失败: {e}")
|
||||
return []
|
||||
|
||||
def _parse_graph_message(self, msg: dict) -> Optional[EmailMessage]:
|
||||
"""
|
||||
解析 Graph API 消息
|
||||
|
||||
Args:
|
||||
msg: Graph API 消息对象
|
||||
|
||||
Returns:
|
||||
EmailMessage 对象
|
||||
"""
|
||||
# 解析发件人
|
||||
from_info = msg.get("from", {})
|
||||
sender_info = from_info.get("emailAddress", {})
|
||||
sender = sender_info.get("address", "")
|
||||
|
||||
# 解析收件人
|
||||
recipients = []
|
||||
for recipient in msg.get("toRecipients", []):
|
||||
addr_info = recipient.get("emailAddress", {})
|
||||
addr = addr_info.get("address", "")
|
||||
if addr:
|
||||
recipients.append(addr)
|
||||
|
||||
# 解析日期
|
||||
received_at = None
|
||||
received_timestamp = 0
|
||||
try:
|
||||
date_str = msg.get("receivedDateTime", "")
|
||||
if date_str:
|
||||
# ISO 8601 格式
|
||||
received_at = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
received_timestamp = int(received_at.timestamp())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 获取正文
|
||||
body_info = msg.get("body", {})
|
||||
body = body_info.get("content", "")
|
||||
body_preview = msg.get("bodyPreview", "")
|
||||
|
||||
return EmailMessage(
|
||||
id=msg.get("id", ""),
|
||||
subject=msg.get("subject", ""),
|
||||
sender=sender,
|
||||
recipients=recipients,
|
||||
body=body,
|
||||
body_preview=body_preview,
|
||||
received_at=received_at,
|
||||
received_timestamp=received_timestamp,
|
||||
is_read=msg.get("isRead", False),
|
||||
has_attachments=msg.get("hasAttachments", False),
|
||||
)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
测试 Graph API 连接
|
||||
|
||||
Returns:
|
||||
连接是否正常
|
||||
"""
|
||||
try:
|
||||
# 尝试获取一封邮件来测试连接
|
||||
emails = self.get_recent_emails(count=1, only_unseen=False)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] Graph API 连接测试失败: {e}")
|
||||
return False
|
||||
231
src/services/outlook/providers/imap_new.py
Normal file
231
src/services/outlook/providers/imap_new.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
新版 IMAP 提供者
|
||||
使用 outlook.live.com 服务器和 login.microsoftonline.com/consumers Token 端点
|
||||
"""
|
||||
|
||||
import email
|
||||
import imaplib
|
||||
import logging
|
||||
from email.header import decode_header
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from ..base import ProviderType, EmailMessage
|
||||
from ..account import OutlookAccount
|
||||
from ..token_manager import TokenManager
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
from .imap_old import IMAPOldProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMAPNewProvider(OutlookProvider):
|
||||
"""
|
||||
新版 IMAP 提供者
|
||||
使用 outlook.live.com:993 和 login.microsoftonline.com/consumers Token 端点
|
||||
需要 IMAP.AccessAsUser.All scope
|
||||
"""
|
||||
|
||||
# IMAP 服务器配置
|
||||
IMAP_HOST = "outlook.live.com"
|
||||
IMAP_PORT = 993
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ProviderType:
|
||||
return ProviderType.IMAP_NEW
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
super().__init__(account, config)
|
||||
|
||||
# IMAP 连接
|
||||
self._conn: Optional[imaplib.IMAP4_SSL] = None
|
||||
|
||||
# Token 管理器
|
||||
self._token_manager: Optional[TokenManager] = None
|
||||
|
||||
# 注意:新版 IMAP 必须使用 OAuth2
|
||||
if not account.has_oauth():
|
||||
logger.warning(
|
||||
f"[{self.account.email}] 新版 IMAP 提供者需要 OAuth2 配置 "
|
||||
f"(client_id + refresh_token)"
|
||||
)
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到 IMAP 服务器
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if self._connected and self._conn:
|
||||
try:
|
||||
self._conn.noop()
|
||||
return True
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
|
||||
# 新版 IMAP 必须使用 OAuth2,无 OAuth 时静默跳过,不记录健康失败
|
||||
if not self.account.has_oauth():
|
||||
logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW(无 OAuth)")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
|
||||
|
||||
# 创建连接
|
||||
self._conn = imaplib.IMAP4_SSL(
|
||||
self.IMAP_HOST,
|
||||
self.IMAP_PORT,
|
||||
timeout=self.config.timeout,
|
||||
)
|
||||
|
||||
# XOAUTH2 认证
|
||||
if self._authenticate_xoauth2():
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] 新版 IMAP 连接成功 (XOAUTH2)")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.disconnect()
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] 新版 IMAP 连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _authenticate_xoauth2(self) -> bool:
|
||||
"""
|
||||
使用 XOAUTH2 认证
|
||||
|
||||
Returns:
|
||||
是否认证成功
|
||||
"""
|
||||
if not self._token_manager:
|
||||
self._token_manager = TokenManager(
|
||||
self.account,
|
||||
ProviderType.IMAP_NEW,
|
||||
self.config.proxy_url,
|
||||
self.config.timeout,
|
||||
)
|
||||
|
||||
# 获取 Access Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if not token:
|
||||
logger.error(f"[{self.account.email}] 获取 IMAP Token 失败")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 构建 XOAUTH2 认证字符串
|
||||
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
|
||||
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
|
||||
# 清除缓存的 Token
|
||||
self._token_manager.clear_cache()
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开 IMAP 连接"""
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
|
||||
self._connected = False
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
if not self._connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 选择收件箱
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
|
||||
# 搜索邮件
|
||||
flag = "UNSEEN" if only_unseen else "ALL"
|
||||
status, data = self._conn.search(None, flag)
|
||||
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return []
|
||||
|
||||
# 获取最新的邮件 ID
|
||||
ids = data[0].split()
|
||||
recent_ids = ids[-count:][::-1]
|
||||
|
||||
emails = []
|
||||
for msg_id in recent_ids:
|
||||
try:
|
||||
email_msg = self._fetch_email(msg_id)
|
||||
if email_msg:
|
||||
emails.append(email_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
|
||||
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
|
||||
return []
|
||||
|
||||
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
|
||||
"""获取并解析单封邮件"""
|
||||
status, data = self._conn.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return None
|
||||
|
||||
raw = b""
|
||||
for part in data:
|
||||
if isinstance(part, tuple) and len(part) > 1:
|
||||
raw = part[1]
|
||||
break
|
||||
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
return self._parse_email(raw)
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(raw: bytes) -> EmailMessage:
|
||||
"""解析原始邮件"""
|
||||
# 使用旧版提供者的解析方法
|
||||
return IMAPOldProvider._parse_email(raw)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""测试 IMAP 连接"""
|
||||
try:
|
||||
with self:
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
self._conn.search(None, "ALL")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 新版 IMAP 连接测试失败: {e}")
|
||||
return False
|
||||
345
src/services/outlook/providers/imap_old.py
Normal file
345
src/services/outlook/providers/imap_old.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
旧版 IMAP 提供者
|
||||
使用 outlook.office365.com 服务器和 login.live.com Token 端点
|
||||
"""
|
||||
|
||||
import email
|
||||
import imaplib
|
||||
import logging
|
||||
from email.header import decode_header
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from ..base import ProviderType, EmailMessage
|
||||
from ..account import OutlookAccount
|
||||
from ..token_manager import TokenManager
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMAPOldProvider(OutlookProvider):
|
||||
"""
|
||||
旧版 IMAP 提供者
|
||||
使用 outlook.office365.com:993 和 login.live.com Token 端点
|
||||
"""
|
||||
|
||||
# IMAP 服务器配置
|
||||
IMAP_HOST = "outlook.office365.com"
|
||||
IMAP_PORT = 993
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ProviderType:
|
||||
return ProviderType.IMAP_OLD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
super().__init__(account, config)
|
||||
|
||||
# IMAP 连接
|
||||
self._conn: Optional[imaplib.IMAP4_SSL] = None
|
||||
|
||||
# Token 管理器
|
||||
self._token_manager: Optional[TokenManager] = None
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到 IMAP 服务器
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if self._connected and self._conn:
|
||||
# 检查现有连接
|
||||
try:
|
||||
self._conn.noop()
|
||||
return True
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
|
||||
try:
|
||||
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
|
||||
|
||||
# 创建连接
|
||||
self._conn = imaplib.IMAP4_SSL(
|
||||
self.IMAP_HOST,
|
||||
self.IMAP_PORT,
|
||||
timeout=self.config.timeout,
|
||||
)
|
||||
|
||||
# 尝试 XOAUTH2 认证
|
||||
if self.account.has_oauth():
|
||||
if self._authenticate_xoauth2():
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] IMAP 连接成功 (XOAUTH2)")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试密码认证")
|
||||
|
||||
# 密码认证
|
||||
if self.account.password:
|
||||
self._conn.login(self.account.email, self.account.password)
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] IMAP 连接成功 (密码认证)")
|
||||
return True
|
||||
|
||||
raise ValueError("没有可用的认证方式")
|
||||
|
||||
except Exception as e:
|
||||
self.disconnect()
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _authenticate_xoauth2(self) -> bool:
|
||||
"""
|
||||
使用 XOAUTH2 认证
|
||||
|
||||
Returns:
|
||||
是否认证成功
|
||||
"""
|
||||
if not self._token_manager:
|
||||
self._token_manager = TokenManager(
|
||||
self.account,
|
||||
ProviderType.IMAP_OLD,
|
||||
self.config.proxy_url,
|
||||
self.config.timeout,
|
||||
)
|
||||
|
||||
# 获取 Access Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 构建 XOAUTH2 认证字符串
|
||||
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
|
||||
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
|
||||
# 清除缓存的 Token
|
||||
self._token_manager.clear_cache()
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开 IMAP 连接"""
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
|
||||
self._connected = False
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
if not self._connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 选择收件箱
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
|
||||
# 搜索邮件
|
||||
flag = "UNSEEN" if only_unseen else "ALL"
|
||||
status, data = self._conn.search(None, flag)
|
||||
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return []
|
||||
|
||||
# 获取最新的邮件 ID
|
||||
ids = data[0].split()
|
||||
recent_ids = ids[-count:][::-1] # 倒序,最新的在前
|
||||
|
||||
emails = []
|
||||
for msg_id in recent_ids:
|
||||
try:
|
||||
email_msg = self._fetch_email(msg_id)
|
||||
if email_msg:
|
||||
emails.append(email_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
|
||||
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
|
||||
return []
|
||||
|
||||
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
|
||||
"""
|
||||
获取并解析单封邮件
|
||||
|
||||
Args:
|
||||
msg_id: 邮件 ID
|
||||
|
||||
Returns:
|
||||
EmailMessage 对象,失败返回 None
|
||||
"""
|
||||
status, data = self._conn.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return None
|
||||
|
||||
# 获取原始邮件内容
|
||||
raw = b""
|
||||
for part in data:
|
||||
if isinstance(part, tuple) and len(part) > 1:
|
||||
raw = part[1]
|
||||
break
|
||||
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
return self._parse_email(raw)
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(raw: bytes) -> EmailMessage:
|
||||
"""
|
||||
解析原始邮件
|
||||
|
||||
Args:
|
||||
raw: 原始邮件数据
|
||||
|
||||
Returns:
|
||||
EmailMessage 对象
|
||||
"""
|
||||
# 移除 BOM
|
||||
if raw.startswith(b"\xef\xbb\xbf"):
|
||||
raw = raw[3:]
|
||||
|
||||
msg = email.message_from_bytes(raw)
|
||||
|
||||
# 解析邮件头
|
||||
subject = IMAPOldProvider._decode_header(msg.get("Subject", ""))
|
||||
sender = IMAPOldProvider._decode_header(msg.get("From", ""))
|
||||
to = IMAPOldProvider._decode_header(msg.get("To", ""))
|
||||
delivered_to = IMAPOldProvider._decode_header(msg.get("Delivered-To", ""))
|
||||
x_original_to = IMAPOldProvider._decode_header(msg.get("X-Original-To", ""))
|
||||
date_str = IMAPOldProvider._decode_header(msg.get("Date", ""))
|
||||
|
||||
# 提取正文
|
||||
body = IMAPOldProvider._extract_body(msg)
|
||||
|
||||
# 解析日期
|
||||
received_timestamp = 0
|
||||
received_at = None
|
||||
try:
|
||||
if date_str:
|
||||
received_at = parsedate_to_datetime(date_str)
|
||||
received_timestamp = int(received_at.timestamp())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 构建收件人列表
|
||||
recipients = [r for r in [to, delivered_to, x_original_to] if r]
|
||||
|
||||
return EmailMessage(
|
||||
id=msg.get("Message-ID", ""),
|
||||
subject=subject,
|
||||
sender=sender,
|
||||
recipients=recipients,
|
||||
body=body,
|
||||
received_at=received_at,
|
||||
received_timestamp=received_timestamp,
|
||||
is_read=False, # 搜索的是未读邮件
|
||||
raw_data=raw[:500] if len(raw) > 500 else raw,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _decode_header(header: str) -> str:
|
||||
"""解码邮件头"""
|
||||
if not header:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for chunk, encoding in decode_header(header):
|
||||
if isinstance(chunk, bytes):
|
||||
try:
|
||||
decoded = chunk.decode(encoding or "utf-8", errors="replace")
|
||||
parts.append(decoded)
|
||||
except Exception:
|
||||
parts.append(chunk.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
parts.append(str(chunk))
|
||||
|
||||
return "".join(parts).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_body(msg) -> str:
|
||||
"""提取邮件正文"""
|
||||
import html as html_module
|
||||
import re
|
||||
|
||||
texts = []
|
||||
parts = msg.walk() if msg.is_multipart() else [msg]
|
||||
|
||||
for part in parts:
|
||||
content_type = part.get_content_type()
|
||||
if content_type not in ("text/plain", "text/html"):
|
||||
continue
|
||||
|
||||
payload = part.get_payload(decode=True)
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
try:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
except LookupError:
|
||||
text = payload.decode("utf-8", errors="replace")
|
||||
|
||||
# 如果是 HTML,移除标签
|
||||
if "<html" in text.lower():
|
||||
text = re.sub(r"<[^>]+>", " ", text)
|
||||
|
||||
texts.append(text)
|
||||
|
||||
# 合并并清理文本
|
||||
combined = " ".join(texts)
|
||||
combined = html_module.unescape(combined)
|
||||
combined = re.sub(r"\s+", " ", combined).strip()
|
||||
|
||||
return combined
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
测试 IMAP 连接
|
||||
|
||||
Returns:
|
||||
连接是否正常
|
||||
"""
|
||||
try:
|
||||
with self:
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
self._conn.search(None, "ALL")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}")
|
||||
return False
|
||||
485
src/services/outlook/service.py
Normal file
485
src/services/outlook/service.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Outlook 邮箱服务主类
|
||||
支持多种 IMAP/API 连接方式,自动故障切换
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from ..base import BaseEmailService, EmailServiceError, EmailServiceStatus, EmailServiceType
|
||||
from ...config.constants import EmailServiceType as ServiceType
|
||||
from ...config.settings import get_settings
|
||||
from .account import OutlookAccount
|
||||
from .base import ProviderType, EmailMessage
|
||||
from .email_parser import EmailParser, get_email_parser
|
||||
from .health_checker import HealthChecker, FailoverManager
|
||||
from .providers.base import OutlookProvider, ProviderConfig
|
||||
from .providers.imap_old import IMAPOldProvider
|
||||
from .providers.imap_new import IMAPNewProvider
|
||||
from .providers.graph_api import GraphAPIProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 默认提供者优先级
|
||||
DEFAULT_PROVIDER_PRIORITY = [
|
||||
ProviderType.GRAPH_API,
|
||||
ProviderType.IMAP_NEW,
|
||||
ProviderType.IMAP_OLD,
|
||||
]
|
||||
|
||||
|
||||
def get_email_code_settings() -> dict:
|
||||
"""获取验证码等待配置"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"timeout": settings.email_code_timeout,
|
||||
"poll_interval": settings.email_code_poll_interval,
|
||||
}
|
||||
|
||||
|
||||
class OutlookService(BaseEmailService):
|
||||
"""
|
||||
Outlook 邮箱服务
|
||||
支持多种 IMAP/API 连接方式,自动故障切换
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化 Outlook 服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- accounts: Outlook 账户列表
|
||||
- provider_priority: 提供者优先级列表
|
||||
- health_failure_threshold: 连续失败次数阈值
|
||||
- health_disable_duration: 禁用时长(秒)
|
||||
- timeout: 请求超时时间
|
||||
- proxy_url: 代理 URL
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(ServiceType.OUTLOOK, name)
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"accounts": [],
|
||||
"provider_priority": [p.value for p in DEFAULT_PROVIDER_PRIORITY],
|
||||
"health_failure_threshold": 5,
|
||||
"health_disable_duration": 60,
|
||||
"timeout": 30,
|
||||
"proxy_url": None,
|
||||
}
|
||||
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
# 解析提供者优先级
|
||||
self.provider_priority = [
|
||||
ProviderType(p) for p in self.config.get("provider_priority", [])
|
||||
]
|
||||
if not self.provider_priority:
|
||||
self.provider_priority = DEFAULT_PROVIDER_PRIORITY
|
||||
|
||||
# 提供者配置
|
||||
self.provider_config = ProviderConfig(
|
||||
timeout=self.config.get("timeout", 30),
|
||||
proxy_url=self.config.get("proxy_url"),
|
||||
health_failure_threshold=self.config.get("health_failure_threshold", 3),
|
||||
health_disable_duration=self.config.get("health_disable_duration", 300),
|
||||
)
|
||||
|
||||
# 获取默认 client_id(供无 client_id 的账户使用)
|
||||
try:
|
||||
_default_client_id = get_settings().outlook_default_client_id
|
||||
except Exception:
|
||||
_default_client_id = "24d9a0ed-8787-4584-883c-2fd79308940a"
|
||||
|
||||
# 解析账户
|
||||
self.accounts: List[OutlookAccount] = []
|
||||
self._current_account_index = 0
|
||||
self._account_lock = threading.Lock()
|
||||
|
||||
# 支持两种配置格式
|
||||
if "email" in self.config and "password" in self.config:
|
||||
account = OutlookAccount.from_config(self.config)
|
||||
if not account.client_id and _default_client_id:
|
||||
account.client_id = _default_client_id
|
||||
if account.validate():
|
||||
self.accounts.append(account)
|
||||
else:
|
||||
for account_config in self.config.get("accounts", []):
|
||||
account = OutlookAccount.from_config(account_config)
|
||||
if not account.client_id and _default_client_id:
|
||||
account.client_id = _default_client_id
|
||||
if account.validate():
|
||||
self.accounts.append(account)
|
||||
|
||||
if not self.accounts:
|
||||
logger.warning("未配置有效的 Outlook 账户")
|
||||
|
||||
# 健康检查器和故障切换管理器
|
||||
self.health_checker = HealthChecker(
|
||||
failure_threshold=self.provider_config.health_failure_threshold,
|
||||
disable_duration=self.provider_config.health_disable_duration,
|
||||
)
|
||||
self.failover_manager = FailoverManager(
|
||||
health_checker=self.health_checker,
|
||||
priority_order=self.provider_priority,
|
||||
)
|
||||
|
||||
# 邮件解析器
|
||||
self.email_parser = get_email_parser()
|
||||
|
||||
# 提供者实例缓存: (email, provider_type) -> OutlookProvider
|
||||
self._providers: Dict[tuple, OutlookProvider] = {}
|
||||
self._provider_lock = threading.Lock()
|
||||
|
||||
# IMAP 连接限制(防止限流)
|
||||
self._imap_semaphore = threading.Semaphore(5)
|
||||
|
||||
# 验证码去重机制
|
||||
self._used_codes: Dict[str, set] = {}
|
||||
|
||||
def _get_provider(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
) -> OutlookProvider:
|
||||
"""
|
||||
获取或创建提供者实例
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
|
||||
Returns:
|
||||
提供者实例
|
||||
"""
|
||||
cache_key = (account.email.lower(), provider_type)
|
||||
|
||||
with self._provider_lock:
|
||||
if cache_key not in self._providers:
|
||||
provider = self._create_provider(account, provider_type)
|
||||
self._providers[cache_key] = provider
|
||||
|
||||
return self._providers[cache_key]
|
||||
|
||||
def _create_provider(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
) -> OutlookProvider:
|
||||
"""
|
||||
创建提供者实例
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
|
||||
Returns:
|
||||
提供者实例
|
||||
"""
|
||||
if provider_type == ProviderType.IMAP_OLD:
|
||||
return IMAPOldProvider(account, self.provider_config)
|
||||
elif provider_type == ProviderType.IMAP_NEW:
|
||||
return IMAPNewProvider(account, self.provider_config)
|
||||
elif provider_type == ProviderType.GRAPH_API:
|
||||
return GraphAPIProvider(account, self.provider_config)
|
||||
else:
|
||||
raise ValueError(f"未知的提供者类型: {provider_type}")
|
||||
|
||||
def _get_provider_priority_for_account(self, account: OutlookAccount) -> List[ProviderType]:
|
||||
"""根据账户是否有 OAuth,返回适合的提供者优先级列表"""
|
||||
if account.has_oauth():
|
||||
return self.provider_priority
|
||||
else:
|
||||
# 无 OAuth,直接走旧版 IMAP(密码认证),跳过需要 OAuth 的提供者
|
||||
return [ProviderType.IMAP_OLD]
|
||||
|
||||
def _try_providers_for_emails(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
尝试多个提供者获取邮件
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 根据账户类型选择合适的提供者优先级
|
||||
priority = self._get_provider_priority_for_account(account)
|
||||
|
||||
# 按优先级尝试各提供者
|
||||
for provider_type in priority:
|
||||
# 检查提供者是否可用
|
||||
if not self.health_checker.is_available(provider_type):
|
||||
logger.debug(
|
||||
f"[{account.email}] {provider_type.value} 不可用,跳过"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
provider = self._get_provider(account, provider_type)
|
||||
|
||||
with self._imap_semaphore:
|
||||
with provider:
|
||||
emails = provider.get_recent_emails(count, only_unseen)
|
||||
|
||||
if emails:
|
||||
# 成功获取邮件
|
||||
self.health_checker.record_success(provider_type)
|
||||
logger.debug(
|
||||
f"[{account.email}] {provider_type.value} 获取到 {len(emails)} 封邮件"
|
||||
)
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
errors.append(f"{provider_type.value}: {error_msg}")
|
||||
self.health_checker.record_failure(provider_type, error_msg)
|
||||
logger.warning(
|
||||
f"[{account.email}] {provider_type.value} 获取邮件失败: {e}"
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f"[{account.email}] 所有提供者都失败: {'; '.join(errors)}"
|
||||
)
|
||||
return []
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
选择可用的 Outlook 账户
|
||||
|
||||
Args:
|
||||
config: 配置参数(未使用)
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典
|
||||
"""
|
||||
if not self.accounts:
|
||||
self.update_status(False, EmailServiceError("没有可用的 Outlook 账户"))
|
||||
raise EmailServiceError("没有可用的 Outlook 账户")
|
||||
|
||||
# 轮询选择账户
|
||||
with self._account_lock:
|
||||
account = self.accounts[self._current_account_index]
|
||||
self._current_account_index = (self._current_account_index + 1) % len(self.accounts)
|
||||
|
||||
email_info = {
|
||||
"email": account.email,
|
||||
"service_id": account.email,
|
||||
"account": {
|
||||
"email": account.email,
|
||||
"has_oauth": account.has_oauth()
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"选择 Outlook 账户: {account.email}")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = None,
|
||||
pattern: str = None,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从 Outlook 邮箱获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 未使用
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式(未使用)
|
||||
otp_sent_at: OTP 发送时间戳
|
||||
|
||||
Returns:
|
||||
验证码字符串
|
||||
"""
|
||||
# 查找对应的账户
|
||||
account = None
|
||||
for acc in self.accounts:
|
||||
if acc.email.lower() == email.lower():
|
||||
account = acc
|
||||
break
|
||||
|
||||
if not account:
|
||||
self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}"))
|
||||
return None
|
||||
|
||||
# 获取验证码等待配置
|
||||
code_settings = get_email_code_settings()
|
||||
actual_timeout = timeout or code_settings["timeout"]
|
||||
poll_interval = code_settings["poll_interval"]
|
||||
|
||||
logger.info(
|
||||
f"[{email}] 开始获取验证码,超时 {actual_timeout}s,"
|
||||
f"提供者优先级: {[p.value for p in self.provider_priority]}"
|
||||
)
|
||||
|
||||
# 初始化验证码去重集合
|
||||
if email not in self._used_codes:
|
||||
self._used_codes[email] = set()
|
||||
used_codes = self._used_codes[email]
|
||||
|
||||
# 计算最小时间戳(留出 60 秒时钟偏差)
|
||||
min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0
|
||||
|
||||
start_time = time.time()
|
||||
poll_count = 0
|
||||
|
||||
while time.time() - start_time < actual_timeout:
|
||||
poll_count += 1
|
||||
|
||||
# 渐进式邮件检查:前 3 次只检查未读
|
||||
only_unseen = poll_count <= 3
|
||||
|
||||
try:
|
||||
# 尝试多个提供者获取邮件
|
||||
emails = self._try_providers_for_emails(
|
||||
account,
|
||||
count=15,
|
||||
only_unseen=only_unseen,
|
||||
)
|
||||
|
||||
if emails:
|
||||
logger.debug(
|
||||
f"[{email}] 第 {poll_count} 次轮询获取到 {len(emails)} 封邮件"
|
||||
)
|
||||
|
||||
# 从邮件中查找验证码
|
||||
code = self.email_parser.find_verification_code_in_emails(
|
||||
emails,
|
||||
target_email=email,
|
||||
min_timestamp=min_timestamp,
|
||||
used_codes=used_codes,
|
||||
)
|
||||
|
||||
if code:
|
||||
used_codes.add(code)
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.info(
|
||||
f"[{email}] 找到验证码: {code},"
|
||||
f"总耗时 {elapsed}s,轮询 {poll_count} 次"
|
||||
)
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[{email}] 检查出错: {e}")
|
||||
|
||||
# 等待下次轮询
|
||||
time.sleep(poll_interval)
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次")
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""列出所有可用的 Outlook 账户"""
|
||||
return [
|
||||
{
|
||||
"email": account.email,
|
||||
"id": account.email,
|
||||
"has_oauth": account.has_oauth(),
|
||||
"type": "outlook"
|
||||
}
|
||||
for account in self.accounts
|
||||
]
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""删除邮箱(Outlook 不支持删除账户)"""
|
||||
logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查 Outlook 服务是否可用"""
|
||||
if not self.accounts:
|
||||
self.update_status(False, EmailServiceError("没有配置的账户"))
|
||||
return False
|
||||
|
||||
# 测试第一个账户的连接
|
||||
test_account = self.accounts[0]
|
||||
|
||||
# 尝试任一提供者连接
|
||||
for provider_type in self.provider_priority:
|
||||
try:
|
||||
provider = self._get_provider(test_account, provider_type)
|
||||
if provider.test_connection():
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Outlook 健康检查失败 ({test_account.email}, {provider_type.value}): {e}"
|
||||
)
|
||||
|
||||
self.update_status(False, EmailServiceError("健康检查失败"))
|
||||
return False
|
||||
|
||||
def get_provider_status(self) -> Dict[str, Any]:
|
||||
"""获取提供者状态"""
|
||||
return self.failover_manager.get_status()
|
||||
|
||||
def get_account_stats(self) -> Dict[str, Any]:
|
||||
"""获取账户统计信息"""
|
||||
total = len(self.accounts)
|
||||
oauth_count = sum(1 for acc in self.accounts if acc.has_oauth())
|
||||
|
||||
return {
|
||||
"total_accounts": total,
|
||||
"oauth_accounts": oauth_count,
|
||||
"password_accounts": total - oauth_count,
|
||||
"accounts": [acc.to_dict() for acc in self.accounts],
|
||||
"provider_status": self.get_provider_status(),
|
||||
}
|
||||
|
||||
def add_account(self, account_config: Dict[str, Any]) -> bool:
|
||||
"""添加新的 Outlook 账户"""
|
||||
try:
|
||||
account = OutlookAccount.from_config(account_config)
|
||||
if not account.validate():
|
||||
return False
|
||||
|
||||
self.accounts.append(account)
|
||||
logger.info(f"添加 Outlook 账户: {account.email}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加 Outlook 账户失败: {e}")
|
||||
return False
|
||||
|
||||
def remove_account(self, email: str) -> bool:
|
||||
"""移除 Outlook 账户"""
|
||||
for i, acc in enumerate(self.accounts):
|
||||
if acc.email.lower() == email.lower():
|
||||
self.accounts.pop(i)
|
||||
logger.info(f"移除 Outlook 账户: {email}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset_provider_health(self):
|
||||
"""重置所有提供者的健康状态"""
|
||||
self.health_checker.reset_all()
|
||||
logger.info("已重置所有提供者的健康状态")
|
||||
|
||||
def force_provider(self, provider_type: ProviderType):
|
||||
"""强制使用指定的提供者"""
|
||||
self.health_checker.force_enable(provider_type)
|
||||
# 禁用其他提供者
|
||||
for pt in ProviderType:
|
||||
if pt != provider_type:
|
||||
self.health_checker.force_disable(pt, 60)
|
||||
logger.info(f"已强制使用提供者: {provider_type.value}")
|
||||
239
src/services/outlook/token_manager.py
Normal file
239
src/services/outlook/token_manager.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Token 管理器
|
||||
支持多个 Microsoft Token 端点,自动选择合适的端点
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from curl_cffi import requests as _requests
|
||||
|
||||
from .base import ProviderType, TokenEndpoint, TokenInfo
|
||||
from .account import OutlookAccount
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 各提供者的 Scope 配置
|
||||
PROVIDER_SCOPES = {
|
||||
ProviderType.IMAP_OLD: "", # 旧版 IMAP 不需要特定 scope
|
||||
ProviderType.IMAP_NEW: "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
|
||||
ProviderType.GRAPH_API: "https://graph.microsoft.com/.default",
|
||||
}
|
||||
|
||||
# 各提供者的 Token 端点
|
||||
PROVIDER_TOKEN_URLS = {
|
||||
ProviderType.IMAP_OLD: TokenEndpoint.LIVE.value,
|
||||
ProviderType.IMAP_NEW: TokenEndpoint.CONSUMERS.value,
|
||||
ProviderType.GRAPH_API: TokenEndpoint.COMMON.value,
|
||||
}
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""
|
||||
Token 管理器
|
||||
支持多端点 Token 获取和缓存
|
||||
"""
|
||||
|
||||
# Token 缓存: key = (email, provider_type) -> TokenInfo
|
||||
_token_cache: Dict[tuple, TokenInfo] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
# 默认超时时间
|
||||
DEFAULT_TIMEOUT = 30
|
||||
# Token 刷新提前时间(秒)
|
||||
REFRESH_BUFFER = 120
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
proxy_url: Optional[str] = None,
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
proxy_url: 代理 URL(可选)
|
||||
timeout: 请求超时时间
|
||||
"""
|
||||
self.account = account
|
||||
self.provider_type = provider_type
|
||||
self.proxy_url = proxy_url
|
||||
self.timeout = timeout
|
||||
|
||||
# 获取端点和 Scope
|
||||
self.token_url = PROVIDER_TOKEN_URLS.get(provider_type, TokenEndpoint.LIVE.value)
|
||||
self.scope = PROVIDER_SCOPES.get(provider_type, "")
|
||||
|
||||
def get_cached_token(self) -> Optional[TokenInfo]:
|
||||
"""获取缓存的 Token"""
|
||||
cache_key = (self.account.email.lower(), self.provider_type)
|
||||
with self._cache_lock:
|
||||
token = self._token_cache.get(cache_key)
|
||||
if token and not token.is_expired(self.REFRESH_BUFFER):
|
||||
return token
|
||||
return None
|
||||
|
||||
def set_cached_token(self, token: TokenInfo):
|
||||
"""缓存 Token"""
|
||||
cache_key = (self.account.email.lower(), self.provider_type)
|
||||
with self._cache_lock:
|
||||
self._token_cache[cache_key] = token
|
||||
|
||||
def clear_cache(self):
|
||||
"""清除缓存"""
|
||||
cache_key = (self.account.email.lower(), self.provider_type)
|
||||
with self._cache_lock:
|
||||
self._token_cache.pop(cache_key, None)
|
||||
|
||||
def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
|
||||
"""
|
||||
获取 Access Token
|
||||
|
||||
Args:
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
Access Token 字符串,失败返回 None
|
||||
"""
|
||||
# 检查缓存
|
||||
if not force_refresh:
|
||||
cached = self.get_cached_token()
|
||||
if cached:
|
||||
logger.debug(f"[{self.account.email}] 使用缓存的 Token ({self.provider_type.value})")
|
||||
return cached.access_token
|
||||
|
||||
# 刷新 Token
|
||||
try:
|
||||
token = self._refresh_token()
|
||||
if token:
|
||||
self.set_cached_token(token)
|
||||
return token.access_token
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.account.email}] 获取 Token 失败 ({self.provider_type.value}): {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _refresh_token(self) -> Optional[TokenInfo]:
|
||||
"""
|
||||
刷新 Token
|
||||
|
||||
Returns:
|
||||
TokenInfo 对象,失败返回 None
|
||||
"""
|
||||
if not self.account.client_id or not self.account.refresh_token:
|
||||
raise ValueError("缺少 client_id 或 refresh_token")
|
||||
|
||||
logger.debug(f"[{self.account.email}] 正在刷新 Token ({self.provider_type.value})...")
|
||||
logger.debug(f"[{self.account.email}] Token URL: {self.token_url}")
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"client_id": self.account.client_id,
|
||||
"refresh_token": self.account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
}
|
||||
|
||||
# 添加 Scope(如果需要)
|
||||
if self.scope:
|
||||
data["scope"] = self.scope
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
proxies = None
|
||||
if self.proxy_url:
|
||||
proxies = {"http": self.proxy_url, "https": self.proxy_url}
|
||||
|
||||
try:
|
||||
resp = _requests.post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
proxies=proxies,
|
||||
timeout=self.timeout,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_body = resp.text
|
||||
logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}")
|
||||
logger.debug(f"[{self.account.email}] 错误响应: {error_body[:500]}")
|
||||
|
||||
if "service abuse" in error_body.lower():
|
||||
logger.warning(f"[{self.account.email}] 账号可能被封禁")
|
||||
elif "invalid_grant" in error_body.lower():
|
||||
logger.warning(f"[{self.account.email}] Refresh Token 已失效")
|
||||
|
||||
return None
|
||||
|
||||
response_data = resp.json()
|
||||
|
||||
# 解析响应
|
||||
token = TokenInfo.from_response(response_data, self.scope)
|
||||
logger.info(
|
||||
f"[{self.account.email}] Token 刷新成功 ({self.provider_type.value}), "
|
||||
f"有效期 {int(token.expires_at - time.time())} 秒"
|
||||
)
|
||||
return token
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[{self.account.email}] JSON 解析错误: {e}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.account.email}] 未知错误: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def clear_all_cache(cls):
|
||||
"""清除所有 Token 缓存"""
|
||||
with cls._cache_lock:
|
||||
cls._token_cache.clear()
|
||||
logger.info("已清除所有 Token 缓存")
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls) -> Dict[str, Any]:
|
||||
"""获取缓存统计"""
|
||||
with cls._cache_lock:
|
||||
return {
|
||||
"cache_size": len(cls._token_cache),
|
||||
"entries": [
|
||||
{
|
||||
"email": key[0],
|
||||
"provider": key[1].value,
|
||||
}
|
||||
for key in cls._token_cache.keys()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def create_token_manager(
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
proxy_url: Optional[str] = None,
|
||||
timeout: int = TokenManager.DEFAULT_TIMEOUT,
|
||||
) -> TokenManager:
|
||||
"""
|
||||
创建 Token 管理器的工厂函数
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
proxy_url: 代理 URL
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
TokenManager 实例
|
||||
"""
|
||||
return TokenManager(account, provider_type, proxy_url, timeout)
|
||||
@@ -734,3 +734,37 @@ async def test_cpa_connection(request: CPATestRequest):
|
||||
"success": success,
|
||||
"message": message
|
||||
}
|
||||
|
||||
|
||||
# ============== Outlook 设置 ==============
|
||||
|
||||
class OutlookSettings(BaseModel):
|
||||
"""Outlook 设置"""
|
||||
default_client_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/outlook")
|
||||
async def get_outlook_settings():
|
||||
"""获取 Outlook 设置"""
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"default_client_id": settings.outlook_default_client_id,
|
||||
"provider_priority": settings.outlook_provider_priority,
|
||||
"health_failure_threshold": settings.outlook_health_failure_threshold,
|
||||
"health_disable_duration": settings.outlook_health_disable_duration,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/outlook")
|
||||
async def update_outlook_settings(request: OutlookSettings):
|
||||
"""更新 Outlook 设置"""
|
||||
update_dict = {}
|
||||
|
||||
if request.default_client_id is not None:
|
||||
update_dict["outlook_default_client_id"] = request.default_client_id
|
||||
|
||||
if update_dict:
|
||||
update_settings(**update_dict)
|
||||
|
||||
return {"success": True, "message": "Outlook 设置已更新"}
|
||||
|
||||
170
src/web/routes/websocket.py
Normal file
170
src/web/routes/websocket.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
WebSocket 路由
|
||||
提供实时日志推送和任务状态更新
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from ..task_manager import task_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/ws/task/{task_uuid}")
|
||||
async def task_websocket(websocket: WebSocket, task_uuid: str):
|
||||
"""
|
||||
任务日志 WebSocket
|
||||
|
||||
消息格式:
|
||||
- 服务端发送: {"type": "log", "task_uuid": "xxx", "message": "...", "timestamp": "..."}
|
||||
- 服务端发送: {"type": "status", "task_uuid": "xxx", "status": "running|completed|failed|cancelled", ...}
|
||||
- 客户端发送: {"type": "ping"} - 心跳
|
||||
- 客户端发送: {"type": "cancel"} - 取消任务
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
|
||||
task_manager.register_websocket(task_uuid, websocket)
|
||||
logger.info(f"WebSocket 连接已建立: {task_uuid}")
|
||||
|
||||
try:
|
||||
# 发送当前状态
|
||||
status = task_manager.get_status(task_uuid)
|
||||
if status:
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
**status
|
||||
})
|
||||
|
||||
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
|
||||
history_logs = task_manager.get_unsent_logs(task_uuid, websocket)
|
||||
for log in history_logs:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"task_uuid": task_uuid,
|
||||
"message": log
|
||||
})
|
||||
|
||||
# 保持连接,等待客户端消息
|
||||
while True:
|
||||
try:
|
||||
# 使用 wait_for 实现超时,但不是断开连接
|
||||
# 而是发送心跳检测
|
||||
data = await asyncio.wait_for(
|
||||
websocket.receive_json(),
|
||||
timeout=30.0 # 30秒超时
|
||||
)
|
||||
|
||||
# 处理心跳
|
||||
if data.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
# 处理取消请求
|
||||
elif data.get("type") == "cancel":
|
||||
task_manager.cancel_task(task_uuid)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
"status": "cancelling",
|
||||
"message": "取消请求已提交"
|
||||
})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,发送心跳检测
|
||||
try:
|
||||
await websocket.send_json({"type": "ping"})
|
||||
except Exception:
|
||||
# 发送失败,可能是连接断开
|
||||
logger.info(f"WebSocket 心跳检测失败: {task_uuid}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: {task_uuid}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
|
||||
finally:
|
||||
task_manager.unregister_websocket(task_uuid, websocket)
|
||||
|
||||
|
||||
@router.websocket("/ws/batch/{batch_id}")
|
||||
async def batch_websocket(websocket: WebSocket, batch_id: str):
|
||||
"""
|
||||
批量任务 WebSocket
|
||||
|
||||
用于批量注册任务的实时状态更新
|
||||
|
||||
消息格式:
|
||||
- 服务端发送: {"type": "log", "batch_id": "xxx", "message": "...", "timestamp": "..."}
|
||||
- 服务端发送: {"type": "status", "batch_id": "xxx", "status": "running|completed|cancelled", ...}
|
||||
- 客户端发送: {"type": "ping"} - 心跳
|
||||
- 客户端发送: {"type": "cancel"} - 取消批量任务
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
|
||||
task_manager.register_batch_websocket(batch_id, websocket)
|
||||
logger.info(f"批量任务 WebSocket 连接已建立: {batch_id}")
|
||||
|
||||
try:
|
||||
# 发送当前状态
|
||||
status = task_manager.get_batch_status(batch_id)
|
||||
if status:
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
**status
|
||||
})
|
||||
|
||||
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
|
||||
history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
|
||||
for log in history_logs:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"batch_id": batch_id,
|
||||
"message": log
|
||||
})
|
||||
|
||||
# 保持连接,等待客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
websocket.receive_json(),
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
# 处理心跳
|
||||
if data.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
# 处理取消请求
|
||||
elif data.get("type") == "cancel":
|
||||
task_manager.cancel_batch(batch_id)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
"status": "cancelling",
|
||||
"message": "取消请求已提交"
|
||||
})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,发送心跳检测
|
||||
try:
|
||||
await websocket.send_json({"type": "ping"})
|
||||
except Exception:
|
||||
logger.info(f"批量任务 WebSocket 心跳检测失败: {batch_id}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"批量任务 WebSocket 断开: {batch_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 WebSocket 错误: {e}")
|
||||
|
||||
finally:
|
||||
task_manager.unregister_batch_websocket(batch_id, websocket)
|
||||
361
src/web/task_manager.py
Normal file
361
src/web/task_manager.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
任务管理器
|
||||
负责管理后台任务、日志队列和 WebSocket 推送
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Optional, List, Callable, Any
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局线程池
|
||||
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="reg_worker")
|
||||
|
||||
# 任务日志队列 (task_uuid -> list of logs)
|
||||
_log_queues: Dict[str, List[str]] = defaultdict(list)
|
||||
_log_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
|
||||
# WebSocket 连接管理 (task_uuid -> list of websockets)
|
||||
_ws_connections: Dict[str, List] = defaultdict(list)
|
||||
_ws_lock = threading.Lock()
|
||||
|
||||
# WebSocket 已发送日志索引 (task_uuid -> {websocket: sent_count})
|
||||
_ws_sent_index: Dict[str, Dict] = defaultdict(dict)
|
||||
|
||||
# 任务状态
|
||||
_task_status: Dict[str, dict] = {}
|
||||
|
||||
# 任务取消标志
|
||||
_task_cancelled: Dict[str, bool] = {}
|
||||
|
||||
# 批量任务状态 (batch_id -> dict)
|
||||
_batch_status: Dict[str, dict] = {}
|
||||
_batch_logs: Dict[str, List[str]] = defaultdict(list)
|
||||
_batch_locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.executor = _executor
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
def set_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
"""设置事件循环(在 FastAPI 启动时调用)"""
|
||||
self._loop = loop
|
||||
|
||||
def get_loop(self) -> Optional[asyncio.AbstractEventLoop]:
|
||||
"""获取事件循环"""
|
||||
return self._loop
|
||||
|
||||
def is_cancelled(self, task_uuid: str) -> bool:
|
||||
"""检查任务是否已取消"""
|
||||
return _task_cancelled.get(task_uuid, False)
|
||||
|
||||
def cancel_task(self, task_uuid: str):
|
||||
"""取消任务"""
|
||||
_task_cancelled[task_uuid] = True
|
||||
logger.info(f"任务 {task_uuid} 已标记为取消")
|
||||
|
||||
def add_log(self, task_uuid: str, log_message: str):
|
||||
"""添加日志并推送到 WebSocket(线程安全)"""
|
||||
# 先广播到 WebSocket,确保实时推送
|
||||
# 然后再添加到队列,这样 get_unsent_logs 不会获取到这条日志
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_log(task_uuid, log_message),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _log_locks[task_uuid]:
|
||||
_log_queues[task_uuid].append(log_message)
|
||||
|
||||
async def _broadcast_log(self, task_uuid: str, log_message: str):
|
||||
"""广播日志到所有 WebSocket 连接"""
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(task_uuid, []).copy()
|
||||
# 注意:不在这里更新 sent_index,因为日志已经通过 add_log 添加到队列
|
||||
# sent_index 应该只在 get_unsent_logs 或发送历史日志时更新
|
||||
# 这样可以避免竞态条件
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "log",
|
||||
"task_uuid": task_uuid,
|
||||
"message": log_message,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
# 发送成功后更新 sent_index
|
||||
with _ws_lock:
|
||||
ws_id = id(ws)
|
||||
if task_uuid in _ws_sent_index and ws_id in _ws_sent_index[task_uuid]:
|
||||
_ws_sent_index[task_uuid][ws_id] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送失败: {e}")
|
||||
|
||||
async def broadcast_status(self, task_uuid: str, status: str, **kwargs):
|
||||
"""广播任务状态更新"""
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(task_uuid, []).copy()
|
||||
|
||||
message = {
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
"status": status,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
**kwargs
|
||||
}
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送状态失败: {e}")
|
||||
|
||||
def register_websocket(self, task_uuid: str, websocket):
|
||||
"""注册 WebSocket 连接"""
|
||||
with _ws_lock:
|
||||
if task_uuid not in _ws_connections:
|
||||
_ws_connections[task_uuid] = []
|
||||
# 避免重复注册同一个连接
|
||||
if websocket not in _ws_connections[task_uuid]:
|
||||
_ws_connections[task_uuid].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _log_locks[task_uuid]:
|
||||
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
|
||||
logger.info(f"WebSocket 连接已注册: {task_uuid}")
|
||||
else:
|
||||
logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
|
||||
|
||||
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
|
||||
"""获取未发送给该 WebSocket 的日志"""
|
||||
with _ws_lock:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
|
||||
|
||||
with _log_locks[task_uuid]:
|
||||
all_logs = _log_queues.get(task_uuid, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
_ws_sent_index[task_uuid][ws_id] = len(all_logs)
|
||||
return unsent_logs
|
||||
|
||||
def unregister_websocket(self, task_uuid: str, websocket):
|
||||
"""注销 WebSocket 连接"""
|
||||
with _ws_lock:
|
||||
if task_uuid in _ws_connections:
|
||||
try:
|
||||
_ws_connections[task_uuid].remove(websocket)
|
||||
except ValueError:
|
||||
pass
|
||||
# 清理已发送索引
|
||||
if task_uuid in _ws_sent_index:
|
||||
_ws_sent_index[task_uuid].pop(id(websocket), None)
|
||||
logger.info(f"WebSocket 连接已注销: {task_uuid}")
|
||||
|
||||
def get_logs(self, task_uuid: str) -> List[str]:
|
||||
"""获取任务的所有日志"""
|
||||
with _log_locks[task_uuid]:
|
||||
return _log_queues.get(task_uuid, []).copy()
|
||||
|
||||
def update_status(self, task_uuid: str, status: str, **kwargs):
|
||||
"""更新任务状态"""
|
||||
if task_uuid not in _task_status:
|
||||
_task_status[task_uuid] = {}
|
||||
|
||||
_task_status[task_uuid]["status"] = status
|
||||
_task_status[task_uuid].update(kwargs)
|
||||
|
||||
def get_status(self, task_uuid: str) -> Optional[dict]:
|
||||
"""获取任务状态"""
|
||||
return _task_status.get(task_uuid)
|
||||
|
||||
def cleanup_task(self, task_uuid: str):
|
||||
"""清理任务数据"""
|
||||
# 保留日志队列一段时间,以便后续查询
|
||||
# 只清理取消标志
|
||||
if task_uuid in _task_cancelled:
|
||||
del _task_cancelled[task_uuid]
|
||||
|
||||
# ============== 批量任务管理 ==============
|
||||
|
||||
def init_batch(self, batch_id: str, total: int):
|
||||
"""初始化批量任务"""
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False
|
||||
}
|
||||
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
|
||||
|
||||
def add_batch_log(self, batch_id: str, log_message: str):
|
||||
"""添加批量任务日志并推送"""
|
||||
# 先广播到 WebSocket,确保实时推送
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_log(batch_id, log_message),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _batch_locks[batch_id]:
|
||||
_batch_logs[batch_id].append(log_message)
|
||||
|
||||
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
|
||||
"""广播批量任务日志"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(key, []).copy()
|
||||
# 注意:不在这里更新 sent_index,避免竞态条件
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "log",
|
||||
"batch_id": batch_id,
|
||||
"message": log_message,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
# 发送成功后更新 sent_index
|
||||
with _ws_lock:
|
||||
ws_id = id(ws)
|
||||
if key in _ws_sent_index and ws_id in _ws_sent_index[key]:
|
||||
_ws_sent_index[key][ws_id] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送批量日志失败: {e}")
|
||||
|
||||
def update_batch_status(self, batch_id: str, **kwargs):
|
||||
"""更新批量任务状态"""
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
|
||||
# 异步广播状态更新
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_status(batch_id),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
async def _broadcast_batch_status(self, batch_id: str):
|
||||
"""广播批量任务状态"""
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(f"batch_{batch_id}", []).copy()
|
||||
|
||||
status = _batch_status.get(batch_id, {})
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
**status
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送批量状态失败: {e}")
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> Optional[dict]:
|
||||
"""获取批量任务状态"""
|
||||
return _batch_status.get(batch_id)
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
with _batch_locks[batch_id]:
|
||||
return _batch_logs.get(batch_id, []).copy()
|
||||
|
||||
def is_batch_cancelled(self, batch_id: str) -> bool:
|
||||
"""检查批量任务是否已取消"""
|
||||
status = _batch_status.get(batch_id, {})
|
||||
return status.get("cancelled", False)
|
||||
|
||||
def cancel_batch(self, batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
|
||||
def register_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注册批量任务 WebSocket 连接"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
if key not in _ws_connections:
|
||||
_ws_connections[key] = []
|
||||
# 避免重复注册同一个连接
|
||||
if websocket not in _ws_connections[key]:
|
||||
_ws_connections[key].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _batch_locks[batch_id]:
|
||||
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
|
||||
logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
|
||||
else:
|
||||
logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
|
||||
|
||||
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
|
||||
"""获取未发送给该 WebSocket 的批量任务日志"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
|
||||
|
||||
with _batch_locks[batch_id]:
|
||||
all_logs = _batch_logs.get(batch_id, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
_ws_sent_index[key][ws_id] = len(all_logs)
|
||||
return unsent_logs
|
||||
|
||||
def unregister_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注销批量任务 WebSocket 连接"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
if key in _ws_connections:
|
||||
try:
|
||||
_ws_connections[key].remove(websocket)
|
||||
except ValueError:
|
||||
pass
|
||||
# 清理已发送索引
|
||||
if key in _ws_sent_index:
|
||||
_ws_sent_index[key].pop(id(websocket), None)
|
||||
logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")
|
||||
|
||||
def create_log_callback(self, task_uuid: str) -> Callable[[str], None]:
|
||||
"""创建日志回调函数"""
|
||||
def callback(msg: str):
|
||||
self.add_log(task_uuid, msg)
|
||||
return callback
|
||||
|
||||
def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]:
|
||||
"""创建检查取消的回调函数"""
|
||||
def callback() -> bool:
|
||||
return self.is_cancelled(task_uuid)
|
||||
return callback
|
||||
|
||||
|
||||
# 全局实例
|
||||
task_manager = TaskManager()
|
||||
Reference in New Issue
Block a user