From 3d8a90cda9a1f74daa41c36beadb1650c33dcd10 Mon Sep 17 00:00:00 2001 From: cnlimiter Date: Sun, 15 Mar 2026 03:52:24 +0800 Subject: [PATCH] =?UTF-8?q?feat(webui):=20=E6=B7=BB=E5=8A=A0WebSocket?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=AE=9E=E7=8E=B0=E5=AE=9E=E6=97=B6=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=8A=B6=E6=80=81=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在注册任务和批量任务中集成WebSocket连接 - 添加TaskManager管理任务状态和日志推送 - 前端app.js重构支持WebSocket与轮询降级机制 - 配置模块重构为完全基于数据库存储 --- cli.py | 2 +- src/config/__init__.py | 19 +- src/config/constants.py | 2 - src/config/settings.py | 560 +++++++++++++++++++++++++++++---- src/database/init_db.py | 65 +--- src/services/outlook.py | 27 +- src/web/app.py | 11 + src/web/routes/registration.py | 203 +++++++++--- src/web/routes/settings.py | 52 +-- static/js/app.js | 317 +++++++++++++++++-- webui.py | 18 +- 11 files changed, 1028 insertions(+), 248 deletions(-) diff --git a/cli.py b/cli.py index 106950a..c58bfa9 100644 --- a/cli.py +++ b/cli.py @@ -139,7 +139,7 @@ def main() -> None: sleep_max = max(sleep_min, args.sleep_max) count = 0 - print("[Info] Yasal's Seamless OpenAI Auto-Registrar Started for ZJH (重构版本)") + print("[Info] OpenAI Auto-Registrar") while True: count += 1 diff --git a/src/config/__init__.py b/src/config/__init__.py index 994ab9c..da2f93b 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -2,7 +2,18 @@ 配置模块 """ -from .settings import Settings, get_settings, update_settings, get_database_url +from .settings import ( + Settings, + get_settings, + update_settings, + get_database_url, + init_default_settings, + get_setting_definition, + get_all_setting_definitions, + SETTING_DEFINITIONS, + SettingCategory, + SettingDefinition, +) from .constants import ( AccountStatus, TaskStatus, @@ -22,6 +33,12 @@ __all__ = [ 'get_settings', 'update_settings', 'get_database_url', + 'init_default_settings', + 'get_setting_definition', + 'get_all_setting_definitions', + 'SETTING_DEFINITIONS', + 'SettingCategory', + 'SettingDefinition', 'AccountStatus', 'TaskStatus', 'EmailServiceType', diff --git a/src/config/constants.py b/src/config/constants.py index 8f97e16..ab795a4 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -114,8 +114,6 @@ EMAIL_SERVICE_DEFAULTS = { # 验证码相关 OTP_CODE_PATTERN = r"(? .env 文件 > 默认值 - """ +class SettingCategory(str, Enum): + """设置分类""" + GENERAL = "general" + DATABASE = "database" + WEBUI = "webui" + LOG = "log" + OPENAI = "openai" + PROXY = "proxy" + REGISTRATION = "registration" + EMAIL = "email" + TEMPMAIL = "tempmail" + CUSTOM_DOMAIN = "custom_domain" + SECURITY = "security" + CPA = "cpa" - model_config = SettingsConfigDict( - env_file=".env", - env_file_encoding="utf-8", - case_sensitive=False, - extra="ignore", - ) +@dataclass +class SettingDefinition: + """设置定义""" + db_key: str + default_value: Any + category: SettingCategory + description: str = "" + is_secret: bool = False + + +# 所有配置项定义(包含数据库键名、默认值、分类、描述) +SETTING_DEFINITIONS: Dict[str, SettingDefinition] = { # 应用信息 - app_name: str = Field(default="OpenAI/Codex CLI 自动注册系统") - app_version: str = Field(default="2.0.0") - debug: bool = Field(default=False) + "app_name": SettingDefinition( + db_key="app.name", + default_value="OpenAI/Codex CLI 自动注册系统", + category=SettingCategory.GENERAL, + description="应用名称" + ), + "app_version": SettingDefinition( + db_key="app.version", + default_value="2.0.0", + category=SettingCategory.GENERAL, + description="应用版本" + ), + "debug": SettingDefinition( + db_key="app.debug", + default_value=False, + category=SettingCategory.GENERAL, + description="调试模式" + ), # 数据库配置 - database_url: str = Field( - default=os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - 'data', - 'database.db' - ) - ) + "database_url": SettingDefinition( + db_key="database.url", + default_value="data/database.db", + category=SettingCategory.DATABASE, + description="数据库路径或连接字符串" + ), + + # Web UI 配置 + "webui_host": SettingDefinition( + db_key="webui.host", + default_value="0.0.0.0", + category=SettingCategory.WEBUI, + description="Web UI 监听地址" + ), + "webui_port": SettingDefinition( + db_key="webui.port", + default_value=8000, + category=SettingCategory.WEBUI, + description="Web UI 监听端口" + ), + "webui_secret_key": SettingDefinition( + db_key="webui.secret_key", + default_value="your-secret-key-change-in-production", + category=SettingCategory.WEBUI, + description="Web UI 密钥", + is_secret=True + ), + + # 日志配置 + "log_level": SettingDefinition( + db_key="log.level", + default_value="INFO", + category=SettingCategory.LOG, + description="日志级别" + ), + "log_file": SettingDefinition( + db_key="log.file", + default_value="logs/app.log", + category=SettingCategory.LOG, + description="日志文件路径" + ), + "log_retention_days": SettingDefinition( + db_key="log.retention_days", + default_value=30, + category=SettingCategory.LOG, + description="日志保留天数" + ), + + # OpenAI 配置 + "openai_client_id": SettingDefinition( + db_key="openai.client_id", + default_value="app_EMoamEEZ73f0CkXaXp7hrann", + category=SettingCategory.OPENAI, + description="OpenAI OAuth 客户端 ID" + ), + "openai_auth_url": SettingDefinition( + db_key="openai.auth_url", + default_value="https://auth.openai.com/oauth/authorize", + category=SettingCategory.OPENAI, + description="OpenAI OAuth 授权 URL" + ), + "openai_token_url": SettingDefinition( + db_key="openai.token_url", + default_value="https://auth.openai.com/oauth/token", + category=SettingCategory.OPENAI, + description="OpenAI OAuth Token URL" + ), + "openai_redirect_uri": SettingDefinition( + db_key="openai.redirect_uri", + default_value="http://localhost:1455/auth/callback", + category=SettingCategory.OPENAI, + description="OpenAI OAuth 回调 URI" + ), + "openai_scope": SettingDefinition( + db_key="openai.scope", + default_value="openid email profile offline_access", + category=SettingCategory.OPENAI, + description="OpenAI OAuth 权限范围" + ), + + # 代理配置 + "proxy_enabled": SettingDefinition( + db_key="proxy.enabled", + default_value=False, + category=SettingCategory.PROXY, + description="是否启用代理" + ), + "proxy_type": SettingDefinition( + db_key="proxy.type", + default_value="http", + category=SettingCategory.PROXY, + description="代理类型 (http/socks5)" + ), + "proxy_host": SettingDefinition( + db_key="proxy.host", + default_value="127.0.0.1", + category=SettingCategory.PROXY, + description="代理服务器地址" + ), + "proxy_port": SettingDefinition( + db_key="proxy.port", + default_value=7890, + category=SettingCategory.PROXY, + description="代理服务器端口" + ), + "proxy_username": SettingDefinition( + db_key="proxy.username", + default_value="", + category=SettingCategory.PROXY, + description="代理用户名" + ), + "proxy_password": SettingDefinition( + db_key="proxy.password", + default_value="", + category=SettingCategory.PROXY, + description="代理密码", + is_secret=True + ), + + # 注册配置 + "registration_max_retries": SettingDefinition( + db_key="registration.max_retries", + default_value=3, + category=SettingCategory.REGISTRATION, + description="注册最大重试次数" + ), + "registration_timeout": SettingDefinition( + db_key="registration.timeout", + default_value=120, + category=SettingCategory.REGISTRATION, + description="注册超时时间(秒)" + ), + "registration_default_password_length": SettingDefinition( + db_key="registration.default_password_length", + default_value=12, + category=SettingCategory.REGISTRATION, + description="默认密码长度" + ), + "registration_sleep_min": SettingDefinition( + db_key="registration.sleep_min", + default_value=5, + category=SettingCategory.REGISTRATION, + description="注册间隔最小值(秒)" + ), + "registration_sleep_max": SettingDefinition( + db_key="registration.sleep_max", + default_value=30, + category=SettingCategory.REGISTRATION, + description="注册间隔最大值(秒)" + ), + + # 邮箱服务配置 + "email_service_priority": SettingDefinition( + db_key="email.service_priority", + default_value={"tempmail": 0, "outlook": 1, "custom_domain": 2}, + category=SettingCategory.EMAIL, + description="邮箱服务优先级" + ), + + # Tempmail.lol 配置 + "tempmail_base_url": SettingDefinition( + db_key="tempmail.base_url", + default_value="https://api.tempmail.lol/v2", + category=SettingCategory.TEMPMAIL, + description="Tempmail API 地址" + ), + "tempmail_timeout": SettingDefinition( + db_key="tempmail.timeout", + default_value=30, + category=SettingCategory.TEMPMAIL, + description="Tempmail 超时时间(秒)" + ), + "tempmail_max_retries": SettingDefinition( + db_key="tempmail.max_retries", + default_value=3, + category=SettingCategory.TEMPMAIL, + description="Tempmail 最大重试次数" + ), + + # 自定义域名邮箱配置 + "custom_domain_base_url": SettingDefinition( + db_key="custom_domain.base_url", + default_value="", + category=SettingCategory.CUSTOM_DOMAIN, + description="自定义域名 API 地址" + ), + "custom_domain_api_key": SettingDefinition( + db_key="custom_domain.api_key", + default_value="", + category=SettingCategory.CUSTOM_DOMAIN, + description="自定义域名 API 密钥", + is_secret=True + ), + + # 安全配置 + "encryption_key": SettingDefinition( + db_key="security.encryption_key", + default_value="your-encryption-key-change-in-production", + category=SettingCategory.SECURITY, + description="加密密钥", + is_secret=True + ), + + # CPA 上传配置 + "cpa_enabled": SettingDefinition( + db_key="cpa.enabled", + default_value=False, + category=SettingCategory.CPA, + description="是否启用 CPA 上传" + ), + "cpa_api_url": SettingDefinition( + db_key="cpa.api_url", + default_value="", + category=SettingCategory.CPA, + description="CPA API 地址" + ), + "cpa_api_token": SettingDefinition( + db_key="cpa.api_token", + default_value="", + category=SettingCategory.CPA, + description="CPA API Token", + is_secret=True + ), + + # 验证码配置 + "email_code_timeout": SettingDefinition( + db_key="email_code.timeout", + default_value=120, + category=SettingCategory.EMAIL, + description="验证码等待超时时间(秒)" + ), + "email_code_poll_interval": SettingDefinition( + db_key="email_code.poll_interval", + default_value=3, + category=SettingCategory.EMAIL, + description="验证码轮询间隔(秒)" + ), +} + +# 属性名到数据库键名的映射(用于向后兼容) +DB_SETTING_KEYS = {name: defn.db_key for name, defn in SETTING_DEFINITIONS.items()} + +# 类型定义映射 +SETTING_TYPES: Dict[str, Type] = { + "debug": bool, + "webui_port": int, + "log_retention_days": int, + "proxy_enabled": bool, + "proxy_port": int, + "registration_max_retries": int, + "registration_timeout": int, + "registration_default_password_length": int, + "registration_sleep_min": int, + "registration_sleep_max": int, + "email_service_priority": dict, + "tempmail_timeout": int, + "tempmail_max_retries": int, + "cpa_enabled": bool, + "email_code_timeout": int, + "email_code_poll_interval": int, +} + +# 需要作为 SecretStr 处理的字段 +SECRET_FIELDS = {name for name, defn in SETTING_DEFINITIONS.items() if defn.is_secret} + + +def _convert_value(attr_name: str, value: str) -> Any: + """将数据库字符串值转换为正确的类型""" + if attr_name in SECRET_FIELDS: + return SecretStr(value) if value else SecretStr("") + + target_type = SETTING_TYPES.get(attr_name, str) + + if target_type == bool: + if isinstance(value, bool): + return value + return str(value).lower() in ("true", "1", "yes", "on") + elif target_type == int: + if isinstance(value, int): + return value + return int(value) if value else 0 + elif target_type == dict: + if isinstance(value, dict): + return value + import json + return json.loads(value) if value else {} + else: + return value + + +def _value_to_string(value: Any) -> str: + """将值转换为数据库存储的字符串""" + if isinstance(value, SecretStr): + return value.get_secret_value() + elif isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, dict): + import json + return json.dumps(value) + elif value is None: + return "" + else: + return str(value) + + +def init_default_settings() -> None: + """ + 初始化数据库中的默认设置 + 如果设置项不存在,则创建并设置默认值 + """ + try: + from ..database.session import get_db + from ..database.crud import get_setting, set_setting + + with get_db() as db: + for attr_name, defn in SETTING_DEFINITIONS.items(): + existing = get_setting(db, defn.db_key) + if not existing: + default_value = _value_to_string(defn.default_value) + set_setting( + db, + defn.db_key, + default_value, + category=defn.category.value, + description=defn.description + ) + print(f"[Settings] 初始化默认设置: {defn.db_key} = {default_value if not defn.is_secret else '***'}") + except Exception as e: + print(f"[Settings] 初始化默认设置失败: {e}") + + +def _load_settings_from_db() -> Dict[str, Any]: + """从数据库加载所有设置""" + try: + from ..database.session import get_db + from ..database.crud import get_setting + + settings_dict = {} + with get_db() as db: + for attr_name, defn in SETTING_DEFINITIONS.items(): + db_setting = get_setting(db, defn.db_key) + if db_setting: + settings_dict[attr_name] = _convert_value(attr_name, db_setting.value) + else: + # 数据库中没有此设置,使用默认值 + settings_dict[attr_name] = _convert_value(attr_name, _value_to_string(defn.default_value)) + return settings_dict + except Exception as e: + print(f"[Settings] 从数据库加载设置失败: {e},使用默认值") + return {name: defn.default_value for name, defn in SETTING_DEFINITIONS.items()} + + +def _save_settings_to_db(**kwargs) -> None: + """保存设置到数据库""" + try: + from ..database.session import get_db + from ..database.crud import set_setting + + with get_db() as db: + for attr_name, value in kwargs.items(): + if attr_name in SETTING_DEFINITIONS: + defn = SETTING_DEFINITIONS[attr_name] + str_value = _value_to_string(value) + set_setting( + db, + defn.db_key, + str_value, + category=defn.category.value, + description=defn.description + ) + except Exception as e: + print(f"[Settings] 保存设置到数据库失败: {e}") + + +class Settings(BaseModel): + """ + 应用配置 - 完全基于数据库存储 + """ + + # 应用信息 + app_name: str = "OpenAI/Codex CLI 自动注册系统" + app_version: str = "2.0.0" + debug: bool = False + + # 数据库配置 + database_url: str = "data/database.db" @field_validator('database_url', mode='before') @classmethod @@ -48,31 +459,29 @@ class Settings(BaseSettings): return v # Web UI 配置 - webui_host: str = Field(default="0.0.0.0") - webui_port: int = Field(default=8000) - webui_secret_key: SecretStr = Field( - default=SecretStr("your-secret-key-change-in-production") - ) + webui_host: str = "0.0.0.0" + webui_port: int = 8000 + webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production") # 日志配置 - log_level: str = Field(default="INFO") - log_file: str = Field(default="logs/app.log") - log_retention_days: int = Field(default=30) + log_level: str = "INFO" + log_file: str = "logs/app.log" + log_retention_days: int = 30 # OpenAI 配置 - openai_client_id: str = Field(default="app_EMoamEEZ73f0CkXaXp7hrann") - openai_auth_url: str = Field(default="https://auth.openai.com/oauth/authorize") - openai_token_url: str = Field(default="https://auth.openai.com/oauth/token") - openai_redirect_uri: str = Field(default="http://localhost:1455/auth/callback") - openai_scope: str = Field(default="openid email profile offline_access") + openai_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann" + openai_auth_url: str = "https://auth.openai.com/oauth/authorize" + openai_token_url: str = "https://auth.openai.com/oauth/token" + openai_redirect_uri: str = "http://localhost:1455/auth/callback" + openai_scope: str = "openid email profile offline_access" # 代理配置 - proxy_enabled: bool = Field(default=False) - proxy_type: str = Field(default="http") # http, socks5 - proxy_host: str = Field(default="127.0.0.1") - proxy_port: int = Field(default=7890) - proxy_username: Optional[str] = Field(default=None) - proxy_password: Optional[SecretStr] = Field(default=None) + proxy_enabled: bool = False + proxy_type: str = "http" + proxy_host: str = "127.0.0.1" + proxy_port: int = 7890 + proxy_username: Optional[str] = None + proxy_password: Optional[SecretStr] = None @property def proxy_url(self) -> Optional[str]: @@ -94,35 +503,35 @@ class Settings(BaseSettings): return f"{scheme}://{auth}{self.proxy_host}:{self.proxy_port}" # 注册配置 - registration_max_retries: int = Field(default=3) - registration_timeout: int = Field(default=120) # 秒 - registration_default_password_length: int = Field(default=12) - registration_sleep_min: int = Field(default=5) - registration_sleep_max: int = Field(default=30) + registration_max_retries: int = 3 + registration_timeout: int = 120 + registration_default_password_length: int = 12 + registration_sleep_min: int = 5 + registration_sleep_max: int = 30 # 邮箱服务配置 - email_service_priority: Dict[str, int] = Field( - default={"tempmail": 0, "outlook": 1, "custom_domain": 2} - ) + email_service_priority: Dict[str, int] = {"tempmail": 0, "outlook": 1, "custom_domain": 2} # Tempmail.lol 配置 - tempmail_base_url: str = Field(default="https://api.tempmail.lol/v2") - tempmail_timeout: int = Field(default=30) - tempmail_max_retries: int = Field(default=3) + tempmail_base_url: str = "https://api.tempmail.lol/v2" + tempmail_timeout: int = 30 + tempmail_max_retries: int = 3 # 自定义域名邮箱配置 - custom_domain_base_url: str = Field(default="") - custom_domain_api_key: Optional[SecretStr] = Field(default=None) + custom_domain_base_url: str = "" + custom_domain_api_key: Optional[SecretStr] = None # 安全配置 - encryption_key: SecretStr = Field( - default=SecretStr("your-encryption-key-change-in-production") - ) + encryption_key: SecretStr = SecretStr("your-encryption-key-change-in-production") # CPA 上传配置 - cpa_enabled: bool = Field(default=False) - cpa_api_url: str = Field(default="") # 例如: https://cpa.example.com - cpa_api_token: SecretStr = Field(default=SecretStr("")) + cpa_enabled: bool = False + cpa_api_url: str = "" + cpa_api_token: SecretStr = SecretStr("") + + # 验证码配置 + email_code_timeout: int = 120 + email_code_poll_interval: int = 3 # 全局配置实例 @@ -132,25 +541,34 @@ _settings: Optional[Settings] = None def get_settings() -> Settings: """ 获取全局配置实例(单例模式) + 完全从数据库加载配置 """ global _settings if _settings is None: - _settings = Settings() + # 先初始化默认设置(如果数据库中没有的话) + init_default_settings() + # 从数据库加载所有设置 + settings_dict = _load_settings_from_db() + _settings = Settings(**settings_dict) return _settings def update_settings(**kwargs) -> Settings: """ - 更新配置(用于测试或运行时配置更改) + 更新配置并保存到数据库 """ global _settings if _settings is None: - _settings = Settings() + _settings = get_settings() # 创建新的配置实例 updated_data = _settings.model_dump() updated_data.update(kwargs) _settings = Settings(**updated_data) + + # 保存到数据库 + _save_settings_to_db(**kwargs) + return _settings @@ -171,3 +589,13 @@ def get_database_url() -> str: return f"sqlite:///{abs_path}" return url + + +def get_setting_definition(attr_name: str) -> Optional[SettingDefinition]: + """获取设置项的定义信息""" + return SETTING_DEFINITIONS.get(attr_name) + + +def get_all_setting_definitions() -> Dict[str, SettingDefinition]: + """获取所有设置项的定义""" + return SETTING_DEFINITIONS.copy() diff --git a/src/database/init_db.py b/src/database/init_db.py index 38f9601..58ea4b0 100644 --- a/src/database/init_db.py +++ b/src/database/init_db.py @@ -2,55 +2,10 @@ 数据库初始化和初始化数据 """ -import json -from datetime import datetime from .session import init_database -from .crud import set_setting from .models import Base -def init_default_settings(db): - """初始化默认设置""" - # 通用设置 - default_settings = [ - ("system.name", "OpenAI/Codex CLI 自动注册系统", "系统名称", "general"), - ("system.version", "2.0.0", "系统版本", "general"), - ("logs.retention_days", "30", "日志保留天数", "general"), - - # OpenAI 配置 - ("openai.client_id", "app_EMoamEEZ73f0CkXaXp7hrann", "OpenAI OAuth Client ID", "openai"), - ("openai.auth_url", "https://auth.openai.com/oauth/authorize", "OpenAI 认证地址", "openai"), - ("openai.token_url", "https://auth.openai.com/oauth/token", "OpenAI Token 地址", "openai"), - ("openai.redirect_uri", "http://localhost:1455/auth/callback", "OpenAI 回调地址", "openai"), - ("openai.scope", "openid email profile offline_access", "OpenAI 权限范围", "openai"), - - # 代理设置 - ("proxy.enabled", "false", "是否启用代理", "proxy"), - ("proxy.type", "http", "代理类型 (http/socks5)", "proxy"), - ("proxy.host", "127.0.0.1", "代理主机", "proxy"), - ("proxy.port", "7890", "代理端口", "proxy"), - - # 注册设置 - ("registration.max_retries", "3", "最大重试次数", "registration"), - ("registration.timeout", "120", "超时时间(秒)", "registration"), - ("registration.default_password_length", "12", "默认密码长度", "registration"), - - # Web UI 设置 - ("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"), - ("webui.port", "8000", "Web UI 监听端口", "webui"), - ("webui.debug", "true", "调试模式", "webui"), - ] - - for key, value, description, category in default_settings: - set_setting(db, key, value, description, category) - - -def init_default_email_services(db): - """初始化默认邮箱服务(仅模板,需要用户配置)""" - # 这里只创建模板配置,实际配置需要用户通过 Web UI 设置 - pass - - def initialize_database(database_url: str = None): """ 初始化数据库 @@ -59,15 +14,13 @@ def initialize_database(database_url: str = None): # 初始化数据库连接和表 db_manager = init_database(database_url) - # 在事务中设置默认配置 - with db_manager.session_scope() as session: - # 初始化默认设置 - init_default_settings(session) + # 创建表 + db_manager.create_tables() - # 初始化默认邮箱服务 - init_default_email_services(session) + # 初始化默认设置(从 settings 模块导入以避免循环导入) + from ..config.settings import init_default_settings + init_default_settings() - print("数据库初始化完成") return db_manager @@ -86,9 +39,9 @@ def reset_database(database_url: str = None): db_manager.create_tables() print("已重新创建所有表") - # 初始化数据 - with db_manager.session_scope() as session: - init_default_settings(session) + # 初始化默认设置 + from ..config.settings import init_default_settings + init_default_settings() print("数据库重置完成") return db_manager @@ -130,4 +83,4 @@ if __name__ == "__main__": else: print("操作已取消") else: - initialize_database(args.url) \ No newline at end of file + initialize_database(args.url) diff --git a/src/services/outlook.py b/src/services/outlook.py index 51c5571..3fd6a7d 100644 --- a/src/services/outlook.py +++ b/src/services/outlook.py @@ -27,35 +27,22 @@ from ..config.constants import ( OTP_CODE_SEMANTIC_PATTERN, OPENAI_EMAIL_SENDERS, OPENAI_VERIFICATION_KEYWORDS, - OTP_WAIT_TIMEOUT, - OTP_POLL_INTERVAL, ) -from ..database import crud -from ..database.session import get_db +from ..config.settings import get_settings def get_email_code_settings() -> dict: """ - 从数据库获取验证码等待配置 + 获取验证码等待配置 Returns: dict: 包含 timeout 和 poll_interval 的字典 """ - try: - with get_db() as db: - timeout_setting = crud.get_setting(db, "email_code.timeout") - poll_interval_setting = crud.get_setting(db, "email_code.poll_interval") - - return { - "timeout": int(timeout_setting.value) if timeout_setting else OTP_WAIT_TIMEOUT, - "poll_interval": int(poll_interval_setting.value) if poll_interval_setting else OTP_POLL_INTERVAL, - } - except Exception as e: - logger.warning(f"获取验证码配置失败,使用默认值: {e}") - return { - "timeout": OTP_WAIT_TIMEOUT, - "poll_interval": OTP_POLL_INTERVAL, - } + settings = get_settings() + return { + "timeout": settings.email_code_timeout, + "poll_interval": settings.email_code_poll_interval, + } logger = logging.getLogger(__name__) diff --git a/src/web/app.py b/src/web/app.py index b99e6bd..bfed89e 100644 --- a/src/web/app.py +++ b/src/web/app.py @@ -15,6 +15,8 @@ from fastapi.responses import HTMLResponse from ..config.settings import get_settings from .routes import api_router +from .routes.websocket import router as ws_router +from .task_manager import task_manager logger = logging.getLogger(__name__) @@ -65,6 +67,9 @@ def create_app() -> FastAPI: # 注册 API 路由 app.include_router(api_router, prefix="/api") + # 注册 WebSocket 路由 + app.include_router(ws_router, prefix="/api") + # 模板引擎 templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) @@ -91,6 +96,12 @@ def create_app() -> FastAPI: @app.on_event("startup") async def startup_event(): """应用启动事件""" + import asyncio + + # 设置 TaskManager 的事件循环 + loop = asyncio.get_event_loop() + task_manager.set_loop(loop) + logger.info("=" * 50) logger.info(f"{settings.app_name} v{settings.app_version} 启动中...") logger.info(f"调试模式: {settings.debug}") diff --git a/src/web/routes/registration.py b/src/web/routes/registration.py index ff6ced7..0e7e3d0 100644 --- a/src/web/routes/registration.py +++ b/src/web/routes/registration.py @@ -18,6 +18,7 @@ from ...database.models import RegistrationTask, Proxy from ...core.register import RegistrationEngine, RegistrationResult from ...services import EmailServiceFactory, EmailServiceType from ...config.settings import get_settings +from ..task_manager import task_manager logger = logging.getLogger(__name__) router = APIRouter() @@ -169,10 +170,19 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse: ) -async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None): - """异步执行注册任务""" +def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None): + """ + 在线程池中执行的同步注册任务 + + 这个函数会被 run_in_executor 调用,运行在独立线程中 + """ with get_db() as db: try: + # 检查是否已取消 + if task_manager.is_cancelled(task_uuid): + logger.info(f"任务 {task_uuid} 已取消,跳过执行") + return + # 更新任务状态为运行中 task = crud.update_registration_task( db, task_uuid, @@ -184,6 +194,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: logger.error(f"任务不存在: {task_uuid}") return + # 更新 TaskManager 状态 + task_manager.update_status(task_uuid, "running") + # 确定使用的代理 # 如果前端传入了代理参数,使用传入的 # 否则从代理列表或系统设置中获取 @@ -284,10 +297,8 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: email_service = EmailServiceFactory.create(service_type, config) - # 创建注册引擎 - def log_callback(msg): - with get_db() as db_inner: - crud.append_task_log(db_inner, task_uuid, msg) + # 创建注册引擎 - 使用 TaskManager 的日志回调 + log_callback = task_manager.create_log_callback(task_uuid) engine = RegistrationEngine( email_service=email_service, @@ -314,6 +325,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: result=result.to_dict() ) + # 更新 TaskManager 状态 + task_manager.update_status(task_uuid, "completed", email=result.email) + logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}") else: # 更新任务状态为失败 @@ -324,6 +338,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: error_message=result.error_message ) + # 更新 TaskManager 状态 + task_manager.update_status(task_uuid, "failed", error=result.error_message) + logger.warning(f"注册任务失败: {task_uuid}, 原因: {result.error_message}") except Exception as e: @@ -337,10 +354,45 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy: completed_at=datetime.utcnow(), error_message=str(e) ) + + # 更新 TaskManager 状态 + task_manager.update_status(task_uuid, "failed", error=str(e)) except: pass +async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None): + """ + 异步执行注册任务 + + 使用 run_in_executor 将同步任务放入线程池执行,避免阻塞主事件循环 + """ + loop = task_manager.get_loop() + if loop is None: + loop = asyncio.get_event_loop() + task_manager.set_loop(loop) + + # 初始化 TaskManager 状态 + task_manager.update_status(task_uuid, "pending") + task_manager.add_log(task_uuid, f"[系统] 任务 {task_uuid[:8]} 已加入队列") + + try: + # 在线程池中执行同步任务 + await loop.run_in_executor( + task_manager.executor, + _run_sync_registration_task, + task_uuid, + email_service_type, + proxy, + email_service_config, + email_service_id + ) + except Exception as e: + logger.error(f"线程池执行异常: {task_uuid}, 错误: {e}") + task_manager.add_log(task_uuid, f"[错误] 线程池执行异常: {str(e)}") + task_manager.update_status(task_uuid, "failed", error=str(e)) + + async def run_batch_registration( batch_id: str, task_uuids: List[str], @@ -351,7 +403,11 @@ async def run_batch_registration( interval_min: int, interval_max: int ): - """异步执行批量注册任务""" + """ + 异步执行批量注册任务 + + 使用线程池执行每个注册任务,避免阻塞主事件循环 + """ batch_tasks[batch_id] = { "total": len(task_uuids), "completed": 0, @@ -375,7 +431,7 @@ async def run_batch_registration( batch_tasks[batch_id]["current_index"] = i - # 运行单个注册任务 + # 运行单个注册任务(使用线程池) await run_registration_task( task_uuid, email_service_type, proxy, email_service_config, email_service_id ) @@ -802,7 +858,7 @@ async def get_outlook_accounts_for_registration(): ) -async def run_outlook_batch_registration( +def _run_sync_outlook_batch_registration( batch_id: str, service_ids: List[int], skip_registered: bool, @@ -811,13 +867,15 @@ async def run_outlook_batch_registration( interval_max: int ): """ - 异步执行 Outlook 批量注册任务 - - 遍历选中的 Outlook 服务,检查邮箱是否已注册,执行注册任务 + 在线程池中执行的同步 Outlook 批量注册任务 """ from ...database.models import EmailService as EmailServiceModel from ...database.models import Account + # 初始化 TaskManager 批量任务 + task_manager.init_batch(batch_id, len(service_ids)) + + # 兼容旧版 batch_tasks(用于 REST API 轮询降级) batch_tasks[batch_id] = { "total": len(service_ids), "completed": 0, @@ -830,14 +888,28 @@ async def run_outlook_batch_registration( "logs": [] } + def add_batch_log(msg: str): + """同时添加日志到两个系统""" + batch_tasks[batch_id]["logs"].append(msg) + task_manager.add_batch_log(batch_id, msg) + + def update_batch_status(**kwargs): + """同时更新两个系统的状态""" + for key, value in kwargs.items(): + if key in batch_tasks[batch_id]: + batch_tasks[batch_id][key] = value + task_manager.update_batch_status(batch_id, **kwargs) + try: for i, service_id in enumerate(service_ids): # 检查是否已取消 - if batch_tasks[batch_id]["cancelled"]: + if task_manager.is_batch_cancelled(batch_id): + add_batch_log(f"[取消] 批量任务已取消") + update_batch_status(finished=True, status="cancelled") logger.info(f"Outlook 批量任务 {batch_id} 已取消") break - batch_tasks[batch_id]["current_index"] = i + update_batch_status(current_index=i) with get_db() as db: # 获取邮箱服务 @@ -846,9 +918,9 @@ async def run_outlook_batch_registration( ).first() if not service: - batch_tasks[batch_id]["logs"].append(f"[跳过] 服务 ID {service_id} 不存在") - batch_tasks[batch_id]["skipped"] += 1 - batch_tasks[batch_id]["completed"] += 1 + add_batch_log(f"[跳过] 服务 ID {service_id} 不存在") + update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1, + completed=batch_tasks[batch_id]["completed"] + 1) continue config = service.config or {} @@ -861,9 +933,9 @@ async def run_outlook_batch_registration( ).first() if existing_account: - batch_tasks[batch_id]["logs"].append(f"[跳过] {email} 已注册 (账号 ID: {existing_account.id})") - batch_tasks[batch_id]["skipped"] += 1 - batch_tasks[batch_id]["completed"] += 1 + add_batch_log(f"[跳过] {email} 已注册 (账号 ID: {existing_account.id})") + update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1, + completed=batch_tasks[batch_id]["completed"] + 1) continue # 创建注册任务 @@ -875,38 +947,80 @@ async def run_outlook_batch_registration( email_service_id=service_id ) - batch_tasks[batch_id]["logs"].append(f"[注册] 开始注册 {email}...") + add_batch_log(f"[注册] 开始注册 {email}...") - # 运行单个注册任务 - await run_registration_task( - task_uuid, "outlook", proxy, None, service_id - ) + # 同步执行注册任务 + _run_sync_registration_task(task_uuid, "outlook", proxy, None, service_id) # 更新统计 with get_db() as db: task = crud.get_registration_task(db, task_uuid) if task: - batch_tasks[batch_id]["completed"] += 1 + new_completed = batch_tasks[batch_id]["completed"] + 1 + new_success = batch_tasks[batch_id]["success"] + new_failed = batch_tasks[batch_id]["failed"] + if task.status == "completed": - batch_tasks[batch_id]["success"] += 1 - batch_tasks[batch_id]["logs"].append(f"[成功] {email} 注册成功") + new_success += 1 + add_batch_log(f"[成功] {email} 注册成功") elif task.status == "failed": - batch_tasks[batch_id]["failed"] += 1 - batch_tasks[batch_id]["logs"].append(f"[失败] {email} 注册失败: {task.error_message}") + new_failed += 1 + add_batch_log(f"[失败] {email} 注册失败: {task.error_message}") + + update_batch_status( + completed=new_completed, + success=new_success, + failed=new_failed + ) # 如果不是最后一个任务,等待随机间隔 - if i < len(service_ids) - 1 and not batch_tasks[batch_id]["cancelled"]: + if i < len(service_ids) - 1 and not task_manager.is_batch_cancelled(batch_id): wait_time = random.randint(interval_min, interval_max) logger.info(f"Outlook 批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务") - await asyncio.sleep(wait_time) + import time + time.sleep(wait_time) - logger.info(f"Outlook 批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}, 跳过 {batch_tasks[batch_id]['skipped']}") + # 完成批量任务 + if not task_manager.is_batch_cancelled(batch_id): + add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}, 跳过: {batch_tasks[batch_id]['skipped']}") + update_batch_status(finished=True, status="completed") + logger.info(f"Outlook 批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}, 跳过 {batch_tasks[batch_id]['skipped']}") except Exception as e: logger.error(f"Outlook 批量任务 {batch_id} 异常: {e}") - batch_tasks[batch_id]["logs"].append(f"[错误] 批量任务异常: {str(e)}") - finally: - batch_tasks[batch_id]["finished"] = True + add_batch_log(f"[错误] 批量任务异常: {str(e)}") + update_batch_status(finished=True, status="failed") + + +async def run_outlook_batch_registration( + batch_id: str, + service_ids: List[int], + skip_registered: bool, + proxy: Optional[str], + interval_min: int, + interval_max: int +): + """ + 异步执行 Outlook 批量注册任务 + + 使用线程池执行,避免阻塞主事件循环 + """ + loop = task_manager.get_loop() + if loop is None: + loop = asyncio.get_event_loop() + task_manager.set_loop(loop) + + # 在线程池中执行 + await loop.run_in_executor( + task_manager.executor, + _run_sync_outlook_batch_registration, + batch_id, + service_ids, + skip_registered, + proxy, + interval_min, + interval_max + ) @router.post("/outlook-batch", response_model=OutlookBatchRegistrationResponse) @@ -1027,3 +1141,20 @@ async def get_outlook_batch_status(batch_id: str): "logs": batch.get("logs", []), "progress": f"{batch['completed']}/{batch['total']}" } + + +@router.post("/outlook-batch/{batch_id}/cancel") +async def cancel_outlook_batch(batch_id: str): + """取消 Outlook 批量任务""" + if batch_id not in batch_tasks: + raise HTTPException(status_code=404, detail="批量任务不存在") + + batch = batch_tasks[batch_id] + if batch.get("finished"): + raise HTTPException(status_code=400, detail="批量任务已完成") + + # 同时更新两个系统的取消状态 + batch["cancelled"] = True + task_manager.cancel_batch(batch_id) + + return {"success": True, "message": "批量任务取消请求已提交"} diff --git a/src/web/routes/settings.py b/src/web/routes/settings.py index d378922..3f8f6ea 100644 --- a/src/web/routes/settings.py +++ b/src/web/routes/settings.py @@ -11,7 +11,6 @@ from pydantic import BaseModel from ...database import crud from ...database.session import get_db from ...config.settings import get_settings, update_settings -from ...config.constants import OTP_WAIT_TIMEOUT, OTP_POLL_INTERVAL logger = logging.getLogger(__name__) router = APIRouter() @@ -72,11 +71,6 @@ async def get_all_settings(): """获取所有设置""" settings = get_settings() - # 从数据库获取验证码设置 - with get_db() as db: - timeout_setting = crud.get_setting(db, "email_code.timeout") - poll_interval_setting = crud.get_setting(db, "email_code.poll_interval") - return { "proxy": { "enabled": settings.proxy_enabled, @@ -104,8 +98,8 @@ async def get_all_settings(): "max_retries": settings.tempmail_max_retries, }, "email_code": { - "timeout": int(timeout_setting.value) if timeout_setting else OTP_WAIT_TIMEOUT, - "poll_interval": int(poll_interval_setting.value) if poll_interval_setting else OTP_POLL_INTERVAL, + "timeout": settings.email_code_timeout, + "poll_interval": settings.email_code_poll_interval, }, } @@ -409,40 +403,26 @@ async def update_tempmail_settings(request: TempmailSettings): @router.get("/email-code") async def get_email_code_settings(): """获取验证码等待设置""" - with get_db() as db: - timeout_setting = crud.get_setting(db, "email_code.timeout") - poll_interval_setting = crud.get_setting(db, "email_code.poll_interval") - - return { - "timeout": int(timeout_setting.value) if timeout_setting else OTP_WAIT_TIMEOUT, - "poll_interval": int(poll_interval_setting.value) if poll_interval_setting else OTP_POLL_INTERVAL, - } + settings = get_settings() + return { + "timeout": settings.email_code_timeout, + "poll_interval": settings.email_code_poll_interval, + } @router.post("/email-code") async def update_email_code_settings(request: EmailCodeSettings): """更新验证码等待设置""" - with get_db() as db: - # 验证参数范围 - if request.timeout < 30 or request.timeout > 600: - raise HTTPException(status_code=400, detail="超时时间必须在 30-600 秒之间") - if request.poll_interval < 1 or request.poll_interval > 30: - raise HTTPException(status_code=400, detail="轮询间隔必须在 1-30 秒之间") + # 验证参数范围 + if request.timeout < 30 or request.timeout > 600: + raise HTTPException(status_code=400, detail="超时时间必须在 30-600 秒之间") + if request.poll_interval < 1 or request.poll_interval > 30: + raise HTTPException(status_code=400, detail="轮询间隔必须在 1-30 秒之间") - crud.set_setting( - db, - "email_code.timeout", - str(request.timeout), - description="验证码等待超时(秒)", - category="email" - ) - crud.set_setting( - db, - "email_code.poll_interval", - str(request.poll_interval), - description="验证码轮询间隔(秒)", - category="email" - ) + update_settings( + email_code_timeout=request.timeout, + email_code_poll_interval=request.poll_interval, + ) return {"success": True, "message": "验证码等待设置已更新"} diff --git a/static/js/app.js b/static/js/app.js index 6dcd544..909c8d2 100644 --- a/static/js/app.js +++ b/static/js/app.js @@ -18,6 +18,13 @@ let availableServices = { custom_domain: { available: false, services: [] } }; +// WebSocket 相关变量 +let webSocket = null; +let batchWebSocket = null; // 批量任务 WebSocket +let useWebSocket = true; // 是否使用 WebSocket +let wsHeartbeatInterval = null; // 心跳定时器 +let batchWsHeartbeatInterval = null; // 批量任务心跳定时器 + // DOM 元素 const elements = { form: document.getElementById('registration-form'), @@ -297,8 +304,8 @@ async function handleSingleRegistration(requestData) { showTaskStatus(data); updateTaskStatus('running'); - // 开始轮询日志 - startLogPolling(data.task_uuid); + // 优先使用 WebSocket + connectWebSocket(data.task_uuid); } catch (error) { addLog('error', `[错误] 启动失败: ${error.message}`); @@ -307,6 +314,118 @@ async function handleSingleRegistration(requestData) { } } + +// ============== WebSocket 功能 ============== + +// 连接 WebSocket +function connectWebSocket(taskUuid) { + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const wsUrl = `${protocol}//${window.location.host}/api/ws/task/${taskUuid}`; + + try { + webSocket = new WebSocket(wsUrl); + + webSocket.onopen = () => { + console.log('WebSocket 连接成功'); + useWebSocket = true; + // 停止轮询(如果有) + stopLogPolling(); + // 开始心跳 + startWebSocketHeartbeat(); + }; + + webSocket.onmessage = (event) => { + const data = JSON.parse(event.data); + + if (data.type === 'log') { + const logType = getLogType(data.message); + addLog(logType, data.message); + } else if (data.type === 'status') { + updateTaskStatus(data.status); + + // 检查是否完成 + if (['completed', 'failed', 'cancelled', 'cancelling'].includes(data.status)) { + disconnectWebSocket(); + resetButtons(); + + if (data.status === 'completed') { + addLog('success', '[成功] 注册成功!'); + toast.success('注册成功!'); + // 刷新账号列表 + loadRecentAccounts(); + } else if (data.status === 'failed') { + addLog('error', '[错误] 注册失败'); + toast.error('注册失败'); + } else if (data.status === 'cancelled' || data.status === 'cancelling') { + addLog('warning', '[警告] 任务已取消'); + } + } + } else if (data.type === 'pong') { + // 心跳响应,忽略 + } + }; + + webSocket.onclose = (event) => { + console.log('WebSocket 连接关闭:', event.code); + stopWebSocketHeartbeat(); + + // 如果任务仍在运行,切换到轮询 + if (currentTask && ['pending', 'running'].includes(currentTask.status)) { + console.log('切换到轮询模式'); + useWebSocket = false; + startLogPolling(currentTask.task_uuid); + } + }; + + webSocket.onerror = (error) => { + console.error('WebSocket 错误:', error); + // 切换到轮询 + useWebSocket = false; + stopWebSocketHeartbeat(); + startLogPolling(taskUuid); + }; + + } catch (error) { + console.error('WebSocket 连接失败:', error); + useWebSocket = false; + startLogPolling(taskUuid); + } +} + +// 断开 WebSocket +function disconnectWebSocket() { + stopWebSocketHeartbeat(); + if (webSocket) { + webSocket.close(); + webSocket = null; + } +} + +// 开始心跳 +function startWebSocketHeartbeat() { + stopWebSocketHeartbeat(); + wsHeartbeatInterval = setInterval(() => { + if (webSocket && webSocket.readyState === WebSocket.OPEN) { + webSocket.send(JSON.stringify({ type: 'ping' })); + } + }, 25000); // 每 25 秒发送一次心跳 +} + +// 停止心跳 +function stopWebSocketHeartbeat() { + if (wsHeartbeatInterval) { + clearInterval(wsHeartbeatInterval); + wsHeartbeatInterval = null; + } +} + +// 发送取消请求 +function cancelViaWebSocket() { + if (webSocket && webSocket.readyState === WebSocket.OPEN) { + webSocket.send(JSON.stringify({ type: 'cancel' })); + } +} + // 批量注册 async function handleBatchRegistration(requestData) { const count = parseInt(elements.batchCount.value) || 5; @@ -340,26 +459,61 @@ async function handleBatchRegistration(requestData) { // 取消任务 async function handleCancelTask() { if (isBatchMode && currentBatch) { - try { - await api.post(`/registration/batch/${currentBatch.batch_id}/cancel`); + // 优先通过 WebSocket 取消批量任务 + if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) { + cancelBatchViaWebSocket(); addLog('warning', '[警告] 批量任务取消请求已提交'); toast.info('任务取消请求已提交'); - stopBatchPolling(); - resetButtons(); - } catch (error) { - addLog('error', `[错误] 取消失败: ${error.message}`); - toast.error(error.message); + } else { + // 降级到 REST API + try { + await api.post(`/registration/batch/${currentBatch.batch_id}/cancel`); + addLog('warning', '[警告] 批量任务取消请求已提交'); + toast.info('任务取消请求已提交'); + stopBatchPolling(); + resetButtons(); + } catch (error) { + addLog('error', `[错误] 取消失败: ${error.message}`); + toast.error(error.message); + } + } + } else if (isOutlookBatchMode && currentBatch) { + // Outlook 批量任务取消 + if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) { + cancelBatchViaWebSocket(); + addLog('warning', '[警告] Outlook 批量任务取消请求已提交'); + toast.info('任务取消请求已提交'); + } else { + // 降级到 REST API + try { + await api.post(`/registration/outlook-batch/${currentBatch.batch_id}/cancel`); + addLog('warning', '[警告] Outlook 批量任务取消请求已提交'); + toast.info('任务取消请求已提交'); + stopBatchPolling(); + resetButtons(); + } catch (error) { + addLog('error', `[错误] 取消失败: ${error.message}`); + toast.error(error.message); + } } } else if (currentTask) { - try { - await api.post(`/registration/tasks/${currentTask.task_uuid}/cancel`); - addLog('warning', '[警告] 任务已取消'); - toast.info('任务已取消'); - stopLogPolling(); - resetButtons(); - } catch (error) { - addLog('error', `[错误] 取消失败: ${error.message}`); - toast.error(error.message); + // 优先通过 WebSocket 取消 + if (useWebSocket && webSocket && webSocket.readyState === WebSocket.OPEN) { + cancelViaWebSocket(); + addLog('warning', '[警告] 任务取消请求已提交'); + toast.info('任务取消请求已提交'); + } else { + // 降级到 REST API + try { + await api.post(`/registration/tasks/${currentTask.task_uuid}/cancel`); + addLog('warning', '[警告] 任务已取消'); + toast.info('任务已取消'); + stopLogPolling(); + resetButtons(); + } catch (error) { + addLog('error', `[错误] 取消失败: ${error.message}`); + toast.error(error.message); + } } } } @@ -634,6 +788,9 @@ function resetButtons() { currentTask = null; currentBatch = null; isBatchMode = false; + // 断开 WebSocket + disconnectWebSocket(); + disconnectBatchWebSocket(); // 注意:不重置 isOutlookBatchMode,因为用户可能想继续使用 Outlook 批量模式 } @@ -765,8 +922,8 @@ async function handleOutlookBatchRegistration() { // 初始化批量状态显示 showBatchStatus({ count: data.to_register }); - // 开始轮询批量状态 - startOutlookBatchPolling(data.batch_id); + // 优先使用 WebSocket + connectBatchWebSocket(data.batch_id); } catch (error) { addLog('error', `[错误] 启动失败: ${error.message}`); @@ -775,7 +932,125 @@ async function handleOutlookBatchRegistration() { } } -// 开始轮询 Outlook 批量状态 +// ============== 批量任务 WebSocket 功能 ============== + +// 连接批量任务 WebSocket +function connectBatchWebSocket(batchId) { + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const wsUrl = `${protocol}//${window.location.host}/api/ws/batch/${batchId}`; + + try { + batchWebSocket = new WebSocket(wsUrl); + + batchWebSocket.onopen = () => { + console.log('批量任务 WebSocket 连接成功'); + // 停止轮询(如果有) + stopBatchPolling(); + // 开始心跳 + startBatchWebSocketHeartbeat(); + }; + + batchWebSocket.onmessage = (event) => { + const data = JSON.parse(event.data); + + if (data.type === 'log') { + const logType = getLogType(data.message); + addLog(logType, data.message); + } else if (data.type === 'status') { + // 更新进度 + if (data.total !== undefined) { + updateBatchProgress({ + total: data.total, + completed: data.completed || 0, + success: data.success || 0, + failed: data.failed || 0 + }); + } + + // 检查是否完成 + if (['completed', 'failed', 'cancelled', 'cancelling'].includes(data.status)) { + disconnectBatchWebSocket(); + resetButtons(); + + if (data.status === 'completed') { + addLog('success', `[完成] Outlook 批量任务完成!成功: ${data.success}, 失败: ${data.failed}, 跳过: ${data.skipped || 0}`); + if (data.success > 0) { + toast.success(`Outlook 批量注册完成,成功 ${data.success} 个`); + loadRecentAccounts(); + } else { + toast.warning('Outlook 批量注册完成,但没有成功注册任何账号'); + } + } else if (data.status === 'failed') { + addLog('error', '[错误] 批量任务执行失败'); + toast.error('批量任务执行失败'); + } else if (data.status === 'cancelled' || data.status === 'cancelling') { + addLog('warning', '[警告] 批量任务已取消'); + } + } + } else if (data.type === 'pong') { + // 心跳响应,忽略 + } + }; + + batchWebSocket.onclose = (event) => { + console.log('批量任务 WebSocket 连接关闭:', event.code); + stopBatchWebSocketHeartbeat(); + + // 如果任务仍在运行,切换到轮询 + if (currentBatch && !['completed', 'failed', 'cancelled'].includes(currentBatch.status)) { + console.log('切换到轮询模式'); + startOutlookBatchPolling(currentBatch.batch_id); + } + }; + + batchWebSocket.onerror = (error) => { + console.error('批量任务 WebSocket 错误:', error); + stopBatchWebSocketHeartbeat(); + // 切换到轮询 + startOutlookBatchPolling(batchId); + }; + + } catch (error) { + console.error('批量任务 WebSocket 连接失败:', error); + startOutlookBatchPolling(batchId); + } +} + +// 断开批量任务 WebSocket +function disconnectBatchWebSocket() { + stopBatchWebSocketHeartbeat(); + if (batchWebSocket) { + batchWebSocket.close(); + batchWebSocket = null; + } +} + +// 开始批量任务心跳 +function startBatchWebSocketHeartbeat() { + stopBatchWebSocketHeartbeat(); + batchWsHeartbeatInterval = setInterval(() => { + if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) { + batchWebSocket.send(JSON.stringify({ type: 'ping' })); + } + }, 25000); // 每 25 秒发送一次心跳 +} + +// 停止批量任务心跳 +function stopBatchWebSocketHeartbeat() { + if (batchWsHeartbeatInterval) { + clearInterval(batchWsHeartbeatInterval); + batchWsHeartbeatInterval = null; + } +} + +// 发送批量任务取消请求 +function cancelBatchViaWebSocket() { + if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) { + batchWebSocket.send(JSON.stringify({ type: 'cancel' })); + } +} + +// 开始轮询 Outlook 批量状态(降级方案) function startOutlookBatchPolling(batchId) { batchPollingInterval = setInterval(async () => { try { diff --git a/webui.py b/webui.py index 452a991..5335b8d 100644 --- a/webui.py +++ b/webui.py @@ -18,7 +18,14 @@ from src.config.settings import get_settings def setup_application(): """设置应用程序""" - # 获取配置 + # 初始化数据库(必须先于获取设置) + try: + initialize_database() + except Exception as e: + print(f"数据库初始化失败: {e}") + raise + + # 获取配置(需要数据库已初始化) settings = get_settings() # 配置日志 @@ -28,14 +35,7 @@ def setup_application(): ) logger = logging.getLogger(__name__) - - # 初始化数据库 - try: - initialize_database() - logger.info("数据库初始化完成") - except Exception as e: - logger.error(f"数据库初始化失败: {e}") - raise + logger.info("数据库初始化完成") # 检查数据目录 data_dir = project_root / "data"