feat(webui): 添加WebSocket支持实现实时任务状态更新

- 在注册任务和批量任务中集成WebSocket连接
- 添加TaskManager管理任务状态和日志推送
- 前端app.js重构支持WebSocket与轮询降级机制
- 配置模块重构为完全基于数据库存储
This commit is contained in:
cnlimiter
2026-03-15 03:52:24 +08:00
parent 76efc047b3
commit 3d8a90cda9
11 changed files with 1028 additions and 248 deletions

2
cli.py
View File

@@ -139,7 +139,7 @@ def main() -> None:
sleep_max = max(sleep_min, args.sleep_max) sleep_max = max(sleep_min, args.sleep_max)
count = 0 count = 0
print("[Info] Yasal's Seamless OpenAI Auto-Registrar Started for ZJH (重构版本)") print("[Info] OpenAI Auto-Registrar")
while True: while True:
count += 1 count += 1

View File

@@ -2,7 +2,18 @@
配置模块 配置模块
""" """
from .settings import Settings, get_settings, update_settings, get_database_url from .settings import (
Settings,
get_settings,
update_settings,
get_database_url,
init_default_settings,
get_setting_definition,
get_all_setting_definitions,
SETTING_DEFINITIONS,
SettingCategory,
SettingDefinition,
)
from .constants import ( from .constants import (
AccountStatus, AccountStatus,
TaskStatus, TaskStatus,
@@ -22,6 +33,12 @@ __all__ = [
'get_settings', 'get_settings',
'update_settings', 'update_settings',
'get_database_url', 'get_database_url',
'init_default_settings',
'get_setting_definition',
'get_all_setting_definitions',
'SETTING_DEFINITIONS',
'SettingCategory',
'SettingDefinition',
'AccountStatus', 'AccountStatus',
'TaskStatus', 'TaskStatus',
'EmailServiceType', 'EmailServiceType',

View File

@@ -114,8 +114,6 @@ EMAIL_SERVICE_DEFAULTS = {
# 验证码相关 # 验证码相关
OTP_CODE_PATTERN = r"(?<!\d)(\d{6})(?!\d)" OTP_CODE_PATTERN = r"(?<!\d)(\d{6})(?!\d)"
OTP_WAIT_TIMEOUT = 120 # 秒
OTP_POLL_INTERVAL = 3 # 秒
OTP_MAX_ATTEMPTS = 40 # 最大轮询次数 OTP_MAX_ATTEMPTS = 40 # 最大轮询次数
# 验证码提取正则(增强版) # 验证码提取正则(增强版)

View File

@@ -1,40 +1,451 @@
""" """
配置管理 - Pydantic 设置模型 配置管理 - 完全基于数据库存储
所有配置都从数据库读取,不再使用环境变量或 .env 文件
""" """
import os import os
from typing import Optional, Dict, Any from typing import Optional, Dict, Any, Type
from pydantic import Field, field_validator from enum import Enum
from pydantic import BaseModel, field_validator
from pydantic.types import SecretStr from pydantic.types import SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict from dataclasses import dataclass
class Settings(BaseSettings): class SettingCategory(str, Enum):
""" """设置分类"""
应用配置 GENERAL = "general"
优先级:环境变量 > .env 文件 > 默认值 DATABASE = "database"
""" WEBUI = "webui"
LOG = "log"
OPENAI = "openai"
PROXY = "proxy"
REGISTRATION = "registration"
EMAIL = "email"
TEMPMAIL = "tempmail"
CUSTOM_DOMAIN = "custom_domain"
SECURITY = "security"
CPA = "cpa"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
@dataclass
class SettingDefinition:
"""设置定义"""
db_key: str
default_value: Any
category: SettingCategory
description: str = ""
is_secret: bool = False
# 所有配置项定义(包含数据库键名、默认值、分类、描述)
SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
# 应用信息 # 应用信息
app_name: str = Field(default="OpenAI/Codex CLI 自动注册系统") "app_name": SettingDefinition(
app_version: str = Field(default="2.0.0") db_key="app.name",
debug: bool = Field(default=False) default_value="OpenAI/Codex CLI 自动注册系统",
category=SettingCategory.GENERAL,
description="应用名称"
),
"app_version": SettingDefinition(
db_key="app.version",
default_value="2.0.0",
category=SettingCategory.GENERAL,
description="应用版本"
),
"debug": SettingDefinition(
db_key="app.debug",
default_value=False,
category=SettingCategory.GENERAL,
description="调试模式"
),
# 数据库配置 # 数据库配置
database_url: str = Field( "database_url": SettingDefinition(
default=os.path.join( db_key="database.url",
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), default_value="data/database.db",
'data', category=SettingCategory.DATABASE,
'database.db' description="数据库路径或连接字符串"
) ),
)
# Web UI 配置
"webui_host": SettingDefinition(
db_key="webui.host",
default_value="0.0.0.0",
category=SettingCategory.WEBUI,
description="Web UI 监听地址"
),
"webui_port": SettingDefinition(
db_key="webui.port",
default_value=8000,
category=SettingCategory.WEBUI,
description="Web UI 监听端口"
),
"webui_secret_key": SettingDefinition(
db_key="webui.secret_key",
default_value="your-secret-key-change-in-production",
category=SettingCategory.WEBUI,
description="Web UI 密钥",
is_secret=True
),
# 日志配置
"log_level": SettingDefinition(
db_key="log.level",
default_value="INFO",
category=SettingCategory.LOG,
description="日志级别"
),
"log_file": SettingDefinition(
db_key="log.file",
default_value="logs/app.log",
category=SettingCategory.LOG,
description="日志文件路径"
),
"log_retention_days": SettingDefinition(
db_key="log.retention_days",
default_value=30,
category=SettingCategory.LOG,
description="日志保留天数"
),
# OpenAI 配置
"openai_client_id": SettingDefinition(
db_key="openai.client_id",
default_value="app_EMoamEEZ73f0CkXaXp7hrann",
category=SettingCategory.OPENAI,
description="OpenAI OAuth 客户端 ID"
),
"openai_auth_url": SettingDefinition(
db_key="openai.auth_url",
default_value="https://auth.openai.com/oauth/authorize",
category=SettingCategory.OPENAI,
description="OpenAI OAuth 授权 URL"
),
"openai_token_url": SettingDefinition(
db_key="openai.token_url",
default_value="https://auth.openai.com/oauth/token",
category=SettingCategory.OPENAI,
description="OpenAI OAuth Token URL"
),
"openai_redirect_uri": SettingDefinition(
db_key="openai.redirect_uri",
default_value="http://localhost:1455/auth/callback",
category=SettingCategory.OPENAI,
description="OpenAI OAuth 回调 URI"
),
"openai_scope": SettingDefinition(
db_key="openai.scope",
default_value="openid email profile offline_access",
category=SettingCategory.OPENAI,
description="OpenAI OAuth 权限范围"
),
# 代理配置
"proxy_enabled": SettingDefinition(
db_key="proxy.enabled",
default_value=False,
category=SettingCategory.PROXY,
description="是否启用代理"
),
"proxy_type": SettingDefinition(
db_key="proxy.type",
default_value="http",
category=SettingCategory.PROXY,
description="代理类型 (http/socks5)"
),
"proxy_host": SettingDefinition(
db_key="proxy.host",
default_value="127.0.0.1",
category=SettingCategory.PROXY,
description="代理服务器地址"
),
"proxy_port": SettingDefinition(
db_key="proxy.port",
default_value=7890,
category=SettingCategory.PROXY,
description="代理服务器端口"
),
"proxy_username": SettingDefinition(
db_key="proxy.username",
default_value="",
category=SettingCategory.PROXY,
description="代理用户名"
),
"proxy_password": SettingDefinition(
db_key="proxy.password",
default_value="",
category=SettingCategory.PROXY,
description="代理密码",
is_secret=True
),
# 注册配置
"registration_max_retries": SettingDefinition(
db_key="registration.max_retries",
default_value=3,
category=SettingCategory.REGISTRATION,
description="注册最大重试次数"
),
"registration_timeout": SettingDefinition(
db_key="registration.timeout",
default_value=120,
category=SettingCategory.REGISTRATION,
description="注册超时时间(秒)"
),
"registration_default_password_length": SettingDefinition(
db_key="registration.default_password_length",
default_value=12,
category=SettingCategory.REGISTRATION,
description="默认密码长度"
),
"registration_sleep_min": SettingDefinition(
db_key="registration.sleep_min",
default_value=5,
category=SettingCategory.REGISTRATION,
description="注册间隔最小值(秒)"
),
"registration_sleep_max": SettingDefinition(
db_key="registration.sleep_max",
default_value=30,
category=SettingCategory.REGISTRATION,
description="注册间隔最大值(秒)"
),
# 邮箱服务配置
"email_service_priority": SettingDefinition(
db_key="email.service_priority",
default_value={"tempmail": 0, "outlook": 1, "custom_domain": 2},
category=SettingCategory.EMAIL,
description="邮箱服务优先级"
),
# Tempmail.lol 配置
"tempmail_base_url": SettingDefinition(
db_key="tempmail.base_url",
default_value="https://api.tempmail.lol/v2",
category=SettingCategory.TEMPMAIL,
description="Tempmail API 地址"
),
"tempmail_timeout": SettingDefinition(
db_key="tempmail.timeout",
default_value=30,
category=SettingCategory.TEMPMAIL,
description="Tempmail 超时时间(秒)"
),
"tempmail_max_retries": SettingDefinition(
db_key="tempmail.max_retries",
default_value=3,
category=SettingCategory.TEMPMAIL,
description="Tempmail 最大重试次数"
),
# 自定义域名邮箱配置
"custom_domain_base_url": SettingDefinition(
db_key="custom_domain.base_url",
default_value="",
category=SettingCategory.CUSTOM_DOMAIN,
description="自定义域名 API 地址"
),
"custom_domain_api_key": SettingDefinition(
db_key="custom_domain.api_key",
default_value="",
category=SettingCategory.CUSTOM_DOMAIN,
description="自定义域名 API 密钥",
is_secret=True
),
# 安全配置
"encryption_key": SettingDefinition(
db_key="security.encryption_key",
default_value="your-encryption-key-change-in-production",
category=SettingCategory.SECURITY,
description="加密密钥",
is_secret=True
),
# CPA 上传配置
"cpa_enabled": SettingDefinition(
db_key="cpa.enabled",
default_value=False,
category=SettingCategory.CPA,
description="是否启用 CPA 上传"
),
"cpa_api_url": SettingDefinition(
db_key="cpa.api_url",
default_value="",
category=SettingCategory.CPA,
description="CPA API 地址"
),
"cpa_api_token": SettingDefinition(
db_key="cpa.api_token",
default_value="",
category=SettingCategory.CPA,
description="CPA API Token",
is_secret=True
),
# 验证码配置
"email_code_timeout": SettingDefinition(
db_key="email_code.timeout",
default_value=120,
category=SettingCategory.EMAIL,
description="验证码等待超时时间(秒)"
),
"email_code_poll_interval": SettingDefinition(
db_key="email_code.poll_interval",
default_value=3,
category=SettingCategory.EMAIL,
description="验证码轮询间隔(秒)"
),
}
# 属性名到数据库键名的映射(用于向后兼容)
DB_SETTING_KEYS = {name: defn.db_key for name, defn in SETTING_DEFINITIONS.items()}
# 类型定义映射
SETTING_TYPES: Dict[str, Type] = {
"debug": bool,
"webui_port": int,
"log_retention_days": int,
"proxy_enabled": bool,
"proxy_port": int,
"registration_max_retries": int,
"registration_timeout": int,
"registration_default_password_length": int,
"registration_sleep_min": int,
"registration_sleep_max": int,
"email_service_priority": dict,
"tempmail_timeout": int,
"tempmail_max_retries": int,
"cpa_enabled": bool,
"email_code_timeout": int,
"email_code_poll_interval": int,
}
# 需要作为 SecretStr 处理的字段
SECRET_FIELDS = {name for name, defn in SETTING_DEFINITIONS.items() if defn.is_secret}
def _convert_value(attr_name: str, value: str) -> Any:
"""将数据库字符串值转换为正确的类型"""
if attr_name in SECRET_FIELDS:
return SecretStr(value) if value else SecretStr("")
target_type = SETTING_TYPES.get(attr_name, str)
if target_type == bool:
if isinstance(value, bool):
return value
return str(value).lower() in ("true", "1", "yes", "on")
elif target_type == int:
if isinstance(value, int):
return value
return int(value) if value else 0
elif target_type == dict:
if isinstance(value, dict):
return value
import json
return json.loads(value) if value else {}
else:
return value
def _value_to_string(value: Any) -> str:
"""将值转换为数据库存储的字符串"""
if isinstance(value, SecretStr):
return value.get_secret_value()
elif isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, dict):
import json
return json.dumps(value)
elif value is None:
return ""
else:
return str(value)
def init_default_settings() -> None:
"""
初始化数据库中的默认设置
如果设置项不存在,则创建并设置默认值
"""
try:
from ..database.session import get_db
from ..database.crud import get_setting, set_setting
with get_db() as db:
for attr_name, defn in SETTING_DEFINITIONS.items():
existing = get_setting(db, defn.db_key)
if not existing:
default_value = _value_to_string(defn.default_value)
set_setting(
db,
defn.db_key,
default_value,
category=defn.category.value,
description=defn.description
)
print(f"[Settings] 初始化默认设置: {defn.db_key} = {default_value if not defn.is_secret else '***'}")
except Exception as e:
print(f"[Settings] 初始化默认设置失败: {e}")
def _load_settings_from_db() -> Dict[str, Any]:
"""从数据库加载所有设置"""
try:
from ..database.session import get_db
from ..database.crud import get_setting
settings_dict = {}
with get_db() as db:
for attr_name, defn in SETTING_DEFINITIONS.items():
db_setting = get_setting(db, defn.db_key)
if db_setting:
settings_dict[attr_name] = _convert_value(attr_name, db_setting.value)
else:
# 数据库中没有此设置,使用默认值
settings_dict[attr_name] = _convert_value(attr_name, _value_to_string(defn.default_value))
return settings_dict
except Exception as e:
print(f"[Settings] 从数据库加载设置失败: {e},使用默认值")
return {name: defn.default_value for name, defn in SETTING_DEFINITIONS.items()}
def _save_settings_to_db(**kwargs) -> None:
"""保存设置到数据库"""
try:
from ..database.session import get_db
from ..database.crud import set_setting
with get_db() as db:
for attr_name, value in kwargs.items():
if attr_name in SETTING_DEFINITIONS:
defn = SETTING_DEFINITIONS[attr_name]
str_value = _value_to_string(value)
set_setting(
db,
defn.db_key,
str_value,
category=defn.category.value,
description=defn.description
)
except Exception as e:
print(f"[Settings] 保存设置到数据库失败: {e}")
class Settings(BaseModel):
"""
应用配置 - 完全基于数据库存储
"""
# 应用信息
app_name: str = "OpenAI/Codex CLI 自动注册系统"
app_version: str = "2.0.0"
debug: bool = False
# 数据库配置
database_url: str = "data/database.db"
@field_validator('database_url', mode='before') @field_validator('database_url', mode='before')
@classmethod @classmethod
@@ -48,31 +459,29 @@ class Settings(BaseSettings):
return v return v
# Web UI 配置 # Web UI 配置
webui_host: str = Field(default="0.0.0.0") webui_host: str = "0.0.0.0"
webui_port: int = Field(default=8000) webui_port: int = 8000
webui_secret_key: SecretStr = Field( webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
default=SecretStr("your-secret-key-change-in-production")
)
# 日志配置 # 日志配置
log_level: str = Field(default="INFO") log_level: str = "INFO"
log_file: str = Field(default="logs/app.log") log_file: str = "logs/app.log"
log_retention_days: int = Field(default=30) log_retention_days: int = 30
# OpenAI 配置 # OpenAI 配置
openai_client_id: str = Field(default="app_EMoamEEZ73f0CkXaXp7hrann") openai_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann"
openai_auth_url: str = Field(default="https://auth.openai.com/oauth/authorize") openai_auth_url: str = "https://auth.openai.com/oauth/authorize"
openai_token_url: str = Field(default="https://auth.openai.com/oauth/token") openai_token_url: str = "https://auth.openai.com/oauth/token"
openai_redirect_uri: str = Field(default="http://localhost:1455/auth/callback") openai_redirect_uri: str = "http://localhost:1455/auth/callback"
openai_scope: str = Field(default="openid email profile offline_access") openai_scope: str = "openid email profile offline_access"
# 代理配置 # 代理配置
proxy_enabled: bool = Field(default=False) proxy_enabled: bool = False
proxy_type: str = Field(default="http") # http, socks5 proxy_type: str = "http"
proxy_host: str = Field(default="127.0.0.1") proxy_host: str = "127.0.0.1"
proxy_port: int = Field(default=7890) proxy_port: int = 7890
proxy_username: Optional[str] = Field(default=None) proxy_username: Optional[str] = None
proxy_password: Optional[SecretStr] = Field(default=None) proxy_password: Optional[SecretStr] = None
@property @property
def proxy_url(self) -> Optional[str]: def proxy_url(self) -> Optional[str]:
@@ -94,35 +503,35 @@ class Settings(BaseSettings):
return f"{scheme}://{auth}{self.proxy_host}:{self.proxy_port}" return f"{scheme}://{auth}{self.proxy_host}:{self.proxy_port}"
# 注册配置 # 注册配置
registration_max_retries: int = Field(default=3) registration_max_retries: int = 3
registration_timeout: int = Field(default=120) # 秒 registration_timeout: int = 120
registration_default_password_length: int = Field(default=12) registration_default_password_length: int = 12
registration_sleep_min: int = Field(default=5) registration_sleep_min: int = 5
registration_sleep_max: int = Field(default=30) registration_sleep_max: int = 30
# 邮箱服务配置 # 邮箱服务配置
email_service_priority: Dict[str, int] = Field( email_service_priority: Dict[str, int] = {"tempmail": 0, "outlook": 1, "custom_domain": 2}
default={"tempmail": 0, "outlook": 1, "custom_domain": 2}
)
# Tempmail.lol 配置 # Tempmail.lol 配置
tempmail_base_url: str = Field(default="https://api.tempmail.lol/v2") tempmail_base_url: str = "https://api.tempmail.lol/v2"
tempmail_timeout: int = Field(default=30) tempmail_timeout: int = 30
tempmail_max_retries: int = Field(default=3) tempmail_max_retries: int = 3
# 自定义域名邮箱配置 # 自定义域名邮箱配置
custom_domain_base_url: str = Field(default="") custom_domain_base_url: str = ""
custom_domain_api_key: Optional[SecretStr] = Field(default=None) custom_domain_api_key: Optional[SecretStr] = None
# 安全配置 # 安全配置
encryption_key: SecretStr = Field( encryption_key: SecretStr = SecretStr("your-encryption-key-change-in-production")
default=SecretStr("your-encryption-key-change-in-production")
)
# CPA 上传配置 # CPA 上传配置
cpa_enabled: bool = Field(default=False) cpa_enabled: bool = False
cpa_api_url: str = Field(default="") # 例如: https://cpa.example.com cpa_api_url: str = ""
cpa_api_token: SecretStr = Field(default=SecretStr("")) cpa_api_token: SecretStr = SecretStr("")
# 验证码配置
email_code_timeout: int = 120
email_code_poll_interval: int = 3
# 全局配置实例 # 全局配置实例
@@ -132,25 +541,34 @@ _settings: Optional[Settings] = None
def get_settings() -> Settings: def get_settings() -> Settings:
""" """
获取全局配置实例(单例模式) 获取全局配置实例(单例模式)
完全从数据库加载配置
""" """
global _settings global _settings
if _settings is None: if _settings is None:
_settings = Settings() # 先初始化默认设置(如果数据库中没有的话)
init_default_settings()
# 从数据库加载所有设置
settings_dict = _load_settings_from_db()
_settings = Settings(**settings_dict)
return _settings return _settings
def update_settings(**kwargs) -> Settings: def update_settings(**kwargs) -> Settings:
""" """
更新配置(用于测试或运行时配置更改) 更新配置并保存到数据库
""" """
global _settings global _settings
if _settings is None: if _settings is None:
_settings = Settings() _settings = get_settings()
# 创建新的配置实例 # 创建新的配置实例
updated_data = _settings.model_dump() updated_data = _settings.model_dump()
updated_data.update(kwargs) updated_data.update(kwargs)
_settings = Settings(**updated_data) _settings = Settings(**updated_data)
# 保存到数据库
_save_settings_to_db(**kwargs)
return _settings return _settings
@@ -171,3 +589,13 @@ def get_database_url() -> str:
return f"sqlite:///{abs_path}" return f"sqlite:///{abs_path}"
return url return url
def get_setting_definition(attr_name: str) -> Optional[SettingDefinition]:
"""获取设置项的定义信息"""
return SETTING_DEFINITIONS.get(attr_name)
def get_all_setting_definitions() -> Dict[str, SettingDefinition]:
"""获取所有设置项的定义"""
return SETTING_DEFINITIONS.copy()

View File

@@ -2,55 +2,10 @@
数据库初始化和初始化数据 数据库初始化和初始化数据
""" """
import json
from datetime import datetime
from .session import init_database from .session import init_database
from .crud import set_setting
from .models import Base from .models import Base
def init_default_settings(db):
"""初始化默认设置"""
# 通用设置
default_settings = [
("system.name", "OpenAI/Codex CLI 自动注册系统", "系统名称", "general"),
("system.version", "2.0.0", "系统版本", "general"),
("logs.retention_days", "30", "日志保留天数", "general"),
# OpenAI 配置
("openai.client_id", "app_EMoamEEZ73f0CkXaXp7hrann", "OpenAI OAuth Client ID", "openai"),
("openai.auth_url", "https://auth.openai.com/oauth/authorize", "OpenAI 认证地址", "openai"),
("openai.token_url", "https://auth.openai.com/oauth/token", "OpenAI Token 地址", "openai"),
("openai.redirect_uri", "http://localhost:1455/auth/callback", "OpenAI 回调地址", "openai"),
("openai.scope", "openid email profile offline_access", "OpenAI 权限范围", "openai"),
# 代理设置
("proxy.enabled", "false", "是否启用代理", "proxy"),
("proxy.type", "http", "代理类型 (http/socks5)", "proxy"),
("proxy.host", "127.0.0.1", "代理主机", "proxy"),
("proxy.port", "7890", "代理端口", "proxy"),
# 注册设置
("registration.max_retries", "3", "最大重试次数", "registration"),
("registration.timeout", "120", "超时时间(秒)", "registration"),
("registration.default_password_length", "12", "默认密码长度", "registration"),
# Web UI 设置
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
("webui.port", "8000", "Web UI 监听端口", "webui"),
("webui.debug", "true", "调试模式", "webui"),
]
for key, value, description, category in default_settings:
set_setting(db, key, value, description, category)
def init_default_email_services(db):
"""初始化默认邮箱服务(仅模板,需要用户配置)"""
# 这里只创建模板配置,实际配置需要用户通过 Web UI 设置
pass
def initialize_database(database_url: str = None): def initialize_database(database_url: str = None):
""" """
初始化数据库 初始化数据库
@@ -59,15 +14,13 @@ def initialize_database(database_url: str = None):
# 初始化数据库连接和表 # 初始化数据库连接和表
db_manager = init_database(database_url) db_manager = init_database(database_url)
# 在事务中设置默认配置 # 创建表
with db_manager.session_scope() as session: db_manager.create_tables()
# 初始化默认设置
init_default_settings(session)
# 初始化默认邮箱服务 # 初始化默认设置(从 settings 模块导入以避免循环导入)
init_default_email_services(session) from ..config.settings import init_default_settings
init_default_settings()
print("数据库初始化完成")
return db_manager return db_manager
@@ -86,9 +39,9 @@ def reset_database(database_url: str = None):
db_manager.create_tables() db_manager.create_tables()
print("已重新创建所有表") print("已重新创建所有表")
# 初始化数据 # 初始化默认设置
with db_manager.session_scope() as session: from ..config.settings import init_default_settings
init_default_settings(session) init_default_settings()
print("数据库重置完成") print("数据库重置完成")
return db_manager return db_manager
@@ -130,4 +83,4 @@ if __name__ == "__main__":
else: else:
print("操作已取消") print("操作已取消")
else: else:
initialize_database(args.url) initialize_database(args.url)

View File

@@ -27,35 +27,22 @@ from ..config.constants import (
OTP_CODE_SEMANTIC_PATTERN, OTP_CODE_SEMANTIC_PATTERN,
OPENAI_EMAIL_SENDERS, OPENAI_EMAIL_SENDERS,
OPENAI_VERIFICATION_KEYWORDS, OPENAI_VERIFICATION_KEYWORDS,
OTP_WAIT_TIMEOUT,
OTP_POLL_INTERVAL,
) )
from ..database import crud from ..config.settings import get_settings
from ..database.session import get_db
def get_email_code_settings() -> dict: def get_email_code_settings() -> dict:
""" """
从数据库获取验证码等待配置 获取验证码等待配置
Returns: Returns:
dict: 包含 timeout 和 poll_interval 的字典 dict: 包含 timeout 和 poll_interval 的字典
""" """
try: settings = get_settings()
with get_db() as db: return {
timeout_setting = crud.get_setting(db, "email_code.timeout") "timeout": settings.email_code_timeout,
poll_interval_setting = crud.get_setting(db, "email_code.poll_interval") "poll_interval": settings.email_code_poll_interval,
}
return {
"timeout": int(timeout_setting.value) if timeout_setting else OTP_WAIT_TIMEOUT,
"poll_interval": int(poll_interval_setting.value) if poll_interval_setting else OTP_POLL_INTERVAL,
}
except Exception as e:
logger.warning(f"获取验证码配置失败,使用默认值: {e}")
return {
"timeout": OTP_WAIT_TIMEOUT,
"poll_interval": OTP_POLL_INTERVAL,
}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -15,6 +15,8 @@ from fastapi.responses import HTMLResponse
from ..config.settings import get_settings from ..config.settings import get_settings
from .routes import api_router from .routes import api_router
from .routes.websocket import router as ws_router
from .task_manager import task_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -65,6 +67,9 @@ def create_app() -> FastAPI:
# 注册 API 路由 # 注册 API 路由
app.include_router(api_router, prefix="/api") app.include_router(api_router, prefix="/api")
# 注册 WebSocket 路由
app.include_router(ws_router, prefix="/api")
# 模板引擎 # 模板引擎
templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
@@ -91,6 +96,12 @@ def create_app() -> FastAPI:
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
"""应用启动事件""" """应用启动事件"""
import asyncio
# 设置 TaskManager 的事件循环
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
logger.info("=" * 50) logger.info("=" * 50)
logger.info(f"{settings.app_name} v{settings.app_version} 启动中...") logger.info(f"{settings.app_name} v{settings.app_version} 启动中...")
logger.info(f"调试模式: {settings.debug}") logger.info(f"调试模式: {settings.debug}")

View File

@@ -18,6 +18,7 @@ from ...database.models import RegistrationTask, Proxy
from ...core.register import RegistrationEngine, RegistrationResult from ...core.register import RegistrationEngine, RegistrationResult
from ...services import EmailServiceFactory, EmailServiceType from ...services import EmailServiceFactory, EmailServiceType
from ...config.settings import get_settings from ...config.settings import get_settings
from ..task_manager import task_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@@ -169,10 +170,19 @@ def task_to_response(task: RegistrationTask) -> RegistrationTaskResponse:
) )
async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None): def _run_sync_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None):
"""异步执行注册任务""" """
在线程池中执行的同步注册任务
这个函数会被 run_in_executor 调用,运行在独立线程中
"""
with get_db() as db: with get_db() as db:
try: try:
# 检查是否已取消
if task_manager.is_cancelled(task_uuid):
logger.info(f"任务 {task_uuid} 已取消,跳过执行")
return
# 更新任务状态为运行中 # 更新任务状态为运行中
task = crud.update_registration_task( task = crud.update_registration_task(
db, task_uuid, db, task_uuid,
@@ -184,6 +194,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
logger.error(f"任务不存在: {task_uuid}") logger.error(f"任务不存在: {task_uuid}")
return return
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "running")
# 确定使用的代理 # 确定使用的代理
# 如果前端传入了代理参数,使用传入的 # 如果前端传入了代理参数,使用传入的
# 否则从代理列表或系统设置中获取 # 否则从代理列表或系统设置中获取
@@ -284,10 +297,8 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
email_service = EmailServiceFactory.create(service_type, config) email_service = EmailServiceFactory.create(service_type, config)
# 创建注册引擎 # 创建注册引擎 - 使用 TaskManager 的日志回调
def log_callback(msg): log_callback = task_manager.create_log_callback(task_uuid)
with get_db() as db_inner:
crud.append_task_log(db_inner, task_uuid, msg)
engine = RegistrationEngine( engine = RegistrationEngine(
email_service=email_service, email_service=email_service,
@@ -314,6 +325,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
result=result.to_dict() result=result.to_dict()
) )
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "completed", email=result.email)
logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}") logger.info(f"注册任务完成: {task_uuid}, 邮箱: {result.email}")
else: else:
# 更新任务状态为失败 # 更新任务状态为失败
@@ -324,6 +338,9 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
error_message=result.error_message error_message=result.error_message
) )
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "failed", error=result.error_message)
logger.warning(f"注册任务失败: {task_uuid}, 原因: {result.error_message}") logger.warning(f"注册任务失败: {task_uuid}, 原因: {result.error_message}")
except Exception as e: except Exception as e:
@@ -337,10 +354,45 @@ async def run_registration_task(task_uuid: str, email_service_type: str, proxy:
completed_at=datetime.utcnow(), completed_at=datetime.utcnow(),
error_message=str(e) error_message=str(e)
) )
# 更新 TaskManager 状态
task_manager.update_status(task_uuid, "failed", error=str(e))
except: except:
pass pass
async def run_registration_task(task_uuid: str, email_service_type: str, proxy: Optional[str], email_service_config: Optional[dict], email_service_id: Optional[int] = None):
"""
异步执行注册任务
使用 run_in_executor 将同步任务放入线程池执行,避免阻塞主事件循环
"""
loop = task_manager.get_loop()
if loop is None:
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
# 初始化 TaskManager 状态
task_manager.update_status(task_uuid, "pending")
task_manager.add_log(task_uuid, f"[系统] 任务 {task_uuid[:8]} 已加入队列")
try:
# 在线程池中执行同步任务
await loop.run_in_executor(
task_manager.executor,
_run_sync_registration_task,
task_uuid,
email_service_type,
proxy,
email_service_config,
email_service_id
)
except Exception as e:
logger.error(f"线程池执行异常: {task_uuid}, 错误: {e}")
task_manager.add_log(task_uuid, f"[错误] 线程池执行异常: {str(e)}")
task_manager.update_status(task_uuid, "failed", error=str(e))
async def run_batch_registration( async def run_batch_registration(
batch_id: str, batch_id: str,
task_uuids: List[str], task_uuids: List[str],
@@ -351,7 +403,11 @@ async def run_batch_registration(
interval_min: int, interval_min: int,
interval_max: int interval_max: int
): ):
"""异步执行批量注册任务""" """
异步执行批量注册任务
使用线程池执行每个注册任务,避免阻塞主事件循环
"""
batch_tasks[batch_id] = { batch_tasks[batch_id] = {
"total": len(task_uuids), "total": len(task_uuids),
"completed": 0, "completed": 0,
@@ -375,7 +431,7 @@ async def run_batch_registration(
batch_tasks[batch_id]["current_index"] = i batch_tasks[batch_id]["current_index"] = i
# 运行单个注册任务 # 运行单个注册任务(使用线程池)
await run_registration_task( await run_registration_task(
task_uuid, email_service_type, proxy, email_service_config, email_service_id task_uuid, email_service_type, proxy, email_service_config, email_service_id
) )
@@ -802,7 +858,7 @@ async def get_outlook_accounts_for_registration():
) )
async def run_outlook_batch_registration( def _run_sync_outlook_batch_registration(
batch_id: str, batch_id: str,
service_ids: List[int], service_ids: List[int],
skip_registered: bool, skip_registered: bool,
@@ -811,13 +867,15 @@ async def run_outlook_batch_registration(
interval_max: int interval_max: int
): ):
""" """
异步执行 Outlook 批量注册任务 在线程池中执行的同步 Outlook 批量注册任务
遍历选中的 Outlook 服务,检查邮箱是否已注册,执行注册任务
""" """
from ...database.models import EmailService as EmailServiceModel from ...database.models import EmailService as EmailServiceModel
from ...database.models import Account from ...database.models import Account
# 初始化 TaskManager 批量任务
task_manager.init_batch(batch_id, len(service_ids))
# 兼容旧版 batch_tasks用于 REST API 轮询降级)
batch_tasks[batch_id] = { batch_tasks[batch_id] = {
"total": len(service_ids), "total": len(service_ids),
"completed": 0, "completed": 0,
@@ -830,14 +888,28 @@ async def run_outlook_batch_registration(
"logs": [] "logs": []
} }
def add_batch_log(msg: str):
"""同时添加日志到两个系统"""
batch_tasks[batch_id]["logs"].append(msg)
task_manager.add_batch_log(batch_id, msg)
def update_batch_status(**kwargs):
"""同时更新两个系统的状态"""
for key, value in kwargs.items():
if key in batch_tasks[batch_id]:
batch_tasks[batch_id][key] = value
task_manager.update_batch_status(batch_id, **kwargs)
try: try:
for i, service_id in enumerate(service_ids): for i, service_id in enumerate(service_ids):
# 检查是否已取消 # 检查是否已取消
if batch_tasks[batch_id]["cancelled"]: if task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[取消] 批量任务已取消")
update_batch_status(finished=True, status="cancelled")
logger.info(f"Outlook 批量任务 {batch_id} 已取消") logger.info(f"Outlook 批量任务 {batch_id} 已取消")
break break
batch_tasks[batch_id]["current_index"] = i update_batch_status(current_index=i)
with get_db() as db: with get_db() as db:
# 获取邮箱服务 # 获取邮箱服务
@@ -846,9 +918,9 @@ async def run_outlook_batch_registration(
).first() ).first()
if not service: if not service:
batch_tasks[batch_id]["logs"].append(f"[跳过] 服务 ID {service_id} 不存在") add_batch_log(f"[跳过] 服务 ID {service_id} 不存在")
batch_tasks[batch_id]["skipped"] += 1 update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1,
batch_tasks[batch_id]["completed"] += 1 completed=batch_tasks[batch_id]["completed"] + 1)
continue continue
config = service.config or {} config = service.config or {}
@@ -861,9 +933,9 @@ async def run_outlook_batch_registration(
).first() ).first()
if existing_account: if existing_account:
batch_tasks[batch_id]["logs"].append(f"[跳过] {email} 已注册 (账号 ID: {existing_account.id})") add_batch_log(f"[跳过] {email} 已注册 (账号 ID: {existing_account.id})")
batch_tasks[batch_id]["skipped"] += 1 update_batch_status(skipped=batch_tasks[batch_id]["skipped"] + 1,
batch_tasks[batch_id]["completed"] += 1 completed=batch_tasks[batch_id]["completed"] + 1)
continue continue
# 创建注册任务 # 创建注册任务
@@ -875,38 +947,80 @@ async def run_outlook_batch_registration(
email_service_id=service_id email_service_id=service_id
) )
batch_tasks[batch_id]["logs"].append(f"[注册] 开始注册 {email}...") add_batch_log(f"[注册] 开始注册 {email}...")
# 运行单个注册任务 # 同步执行注册任务
await run_registration_task( _run_sync_registration_task(task_uuid, "outlook", proxy, None, service_id)
task_uuid, "outlook", proxy, None, service_id
)
# 更新统计 # 更新统计
with get_db() as db: with get_db() as db:
task = crud.get_registration_task(db, task_uuid) task = crud.get_registration_task(db, task_uuid)
if task: if task:
batch_tasks[batch_id]["completed"] += 1 new_completed = batch_tasks[batch_id]["completed"] + 1
new_success = batch_tasks[batch_id]["success"]
new_failed = batch_tasks[batch_id]["failed"]
if task.status == "completed": if task.status == "completed":
batch_tasks[batch_id]["success"] += 1 new_success += 1
batch_tasks[batch_id]["logs"].append(f"[成功] {email} 注册成功") add_batch_log(f"[成功] {email} 注册成功")
elif task.status == "failed": elif task.status == "failed":
batch_tasks[batch_id]["failed"] += 1 new_failed += 1
batch_tasks[batch_id]["logs"].append(f"[失败] {email} 注册失败: {task.error_message}") add_batch_log(f"[失败] {email} 注册失败: {task.error_message}")
update_batch_status(
completed=new_completed,
success=new_success,
failed=new_failed
)
# 如果不是最后一个任务,等待随机间隔 # 如果不是最后一个任务,等待随机间隔
if i < len(service_ids) - 1 and not batch_tasks[batch_id]["cancelled"]: if i < len(service_ids) - 1 and not task_manager.is_batch_cancelled(batch_id):
wait_time = random.randint(interval_min, interval_max) wait_time = random.randint(interval_min, interval_max)
logger.info(f"Outlook 批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务") logger.info(f"Outlook 批量任务 {batch_id}: 等待 {wait_time} 秒后继续下一个任务")
await asyncio.sleep(wait_time) import time
time.sleep(wait_time)
logger.info(f"Outlook 批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}, 跳过 {batch_tasks[batch_id]['skipped']}") # 完成批量任务
if not task_manager.is_batch_cancelled(batch_id):
add_batch_log(f"[完成] 批量任务完成!成功: {batch_tasks[batch_id]['success']}, 失败: {batch_tasks[batch_id]['failed']}, 跳过: {batch_tasks[batch_id]['skipped']}")
update_batch_status(finished=True, status="completed")
logger.info(f"Outlook 批量任务 {batch_id} 完成: 成功 {batch_tasks[batch_id]['success']}, 失败 {batch_tasks[batch_id]['failed']}, 跳过 {batch_tasks[batch_id]['skipped']}")
except Exception as e: except Exception as e:
logger.error(f"Outlook 批量任务 {batch_id} 异常: {e}") logger.error(f"Outlook 批量任务 {batch_id} 异常: {e}")
batch_tasks[batch_id]["logs"].append(f"[错误] 批量任务异常: {str(e)}") add_batch_log(f"[错误] 批量任务异常: {str(e)}")
finally: update_batch_status(finished=True, status="failed")
batch_tasks[batch_id]["finished"] = True
async def run_outlook_batch_registration(
batch_id: str,
service_ids: List[int],
skip_registered: bool,
proxy: Optional[str],
interval_min: int,
interval_max: int
):
"""
异步执行 Outlook 批量注册任务
使用线程池执行,避免阻塞主事件循环
"""
loop = task_manager.get_loop()
if loop is None:
loop = asyncio.get_event_loop()
task_manager.set_loop(loop)
# 在线程池中执行
await loop.run_in_executor(
task_manager.executor,
_run_sync_outlook_batch_registration,
batch_id,
service_ids,
skip_registered,
proxy,
interval_min,
interval_max
)
@router.post("/outlook-batch", response_model=OutlookBatchRegistrationResponse) @router.post("/outlook-batch", response_model=OutlookBatchRegistrationResponse)
@@ -1027,3 +1141,20 @@ async def get_outlook_batch_status(batch_id: str):
"logs": batch.get("logs", []), "logs": batch.get("logs", []),
"progress": f"{batch['completed']}/{batch['total']}" "progress": f"{batch['completed']}/{batch['total']}"
} }
@router.post("/outlook-batch/{batch_id}/cancel")
async def cancel_outlook_batch(batch_id: str):
"""取消 Outlook 批量任务"""
if batch_id not in batch_tasks:
raise HTTPException(status_code=404, detail="批量任务不存在")
batch = batch_tasks[batch_id]
if batch.get("finished"):
raise HTTPException(status_code=400, detail="批量任务已完成")
# 同时更新两个系统的取消状态
batch["cancelled"] = True
task_manager.cancel_batch(batch_id)
return {"success": True, "message": "批量任务取消请求已提交"}

View File

@@ -11,7 +11,6 @@ from pydantic import BaseModel
from ...database import crud from ...database import crud
from ...database.session import get_db from ...database.session import get_db
from ...config.settings import get_settings, update_settings from ...config.settings import get_settings, update_settings
from ...config.constants import OTP_WAIT_TIMEOUT, OTP_POLL_INTERVAL
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@@ -72,11 +71,6 @@ async def get_all_settings():
"""获取所有设置""" """获取所有设置"""
settings = get_settings() settings = get_settings()
# 从数据库获取验证码设置
with get_db() as db:
timeout_setting = crud.get_setting(db, "email_code.timeout")
poll_interval_setting = crud.get_setting(db, "email_code.poll_interval")
return { return {
"proxy": { "proxy": {
"enabled": settings.proxy_enabled, "enabled": settings.proxy_enabled,
@@ -104,8 +98,8 @@ async def get_all_settings():
"max_retries": settings.tempmail_max_retries, "max_retries": settings.tempmail_max_retries,
}, },
"email_code": { "email_code": {
"timeout": int(timeout_setting.value) if timeout_setting else OTP_WAIT_TIMEOUT, "timeout": settings.email_code_timeout,
"poll_interval": int(poll_interval_setting.value) if poll_interval_setting else OTP_POLL_INTERVAL, "poll_interval": settings.email_code_poll_interval,
}, },
} }
@@ -409,40 +403,26 @@ async def update_tempmail_settings(request: TempmailSettings):
@router.get("/email-code") @router.get("/email-code")
async def get_email_code_settings(): async def get_email_code_settings():
"""获取验证码等待设置""" """获取验证码等待设置"""
with get_db() as db: settings = get_settings()
timeout_setting = crud.get_setting(db, "email_code.timeout") return {
poll_interval_setting = crud.get_setting(db, "email_code.poll_interval") "timeout": settings.email_code_timeout,
"poll_interval": settings.email_code_poll_interval,
return { }
"timeout": int(timeout_setting.value) if timeout_setting else OTP_WAIT_TIMEOUT,
"poll_interval": int(poll_interval_setting.value) if poll_interval_setting else OTP_POLL_INTERVAL,
}
@router.post("/email-code") @router.post("/email-code")
async def update_email_code_settings(request: EmailCodeSettings): async def update_email_code_settings(request: EmailCodeSettings):
"""更新验证码等待设置""" """更新验证码等待设置"""
with get_db() as db: # 验证参数范围
# 验证参数范围 if request.timeout < 30 or request.timeout > 600:
if request.timeout < 30 or request.timeout > 600: raise HTTPException(status_code=400, detail="超时时间必须在 30-600 秒之间")
raise HTTPException(status_code=400, detail="超时时间必须在 30-600 秒之间") if request.poll_interval < 1 or request.poll_interval > 30:
if request.poll_interval < 1 or request.poll_interval > 30: raise HTTPException(status_code=400, detail="轮询间隔必须在 1-30 秒之间")
raise HTTPException(status_code=400, detail="轮询间隔必须在 1-30 秒之间")
crud.set_setting( update_settings(
db, email_code_timeout=request.timeout,
"email_code.timeout", email_code_poll_interval=request.poll_interval,
str(request.timeout), )
description="验证码等待超时(秒)",
category="email"
)
crud.set_setting(
db,
"email_code.poll_interval",
str(request.poll_interval),
description="验证码轮询间隔(秒)",
category="email"
)
return {"success": True, "message": "验证码等待设置已更新"} return {"success": True, "message": "验证码等待设置已更新"}

View File

@@ -18,6 +18,13 @@ let availableServices = {
custom_domain: { available: false, services: [] } custom_domain: { available: false, services: [] }
}; };
// WebSocket 相关变量
let webSocket = null;
let batchWebSocket = null; // 批量任务 WebSocket
let useWebSocket = true; // 是否使用 WebSocket
let wsHeartbeatInterval = null; // 心跳定时器
let batchWsHeartbeatInterval = null; // 批量任务心跳定时器
// DOM 元素 // DOM 元素
const elements = { const elements = {
form: document.getElementById('registration-form'), form: document.getElementById('registration-form'),
@@ -297,8 +304,8 @@ async function handleSingleRegistration(requestData) {
showTaskStatus(data); showTaskStatus(data);
updateTaskStatus('running'); updateTaskStatus('running');
// 开始轮询日志 // 优先使用 WebSocket
startLogPolling(data.task_uuid); connectWebSocket(data.task_uuid);
} catch (error) { } catch (error) {
addLog('error', `[错误] 启动失败: ${error.message}`); addLog('error', `[错误] 启动失败: ${error.message}`);
@@ -307,6 +314,118 @@ async function handleSingleRegistration(requestData) {
} }
} }
// ============== WebSocket 功能 ==============
// 连接 WebSocket
function connectWebSocket(taskUuid) {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.host}/api/ws/task/${taskUuid}`;
try {
webSocket = new WebSocket(wsUrl);
webSocket.onopen = () => {
console.log('WebSocket 连接成功');
useWebSocket = true;
// 停止轮询(如果有)
stopLogPolling();
// 开始心跳
startWebSocketHeartbeat();
};
webSocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'log') {
const logType = getLogType(data.message);
addLog(logType, data.message);
} else if (data.type === 'status') {
updateTaskStatus(data.status);
// 检查是否完成
if (['completed', 'failed', 'cancelled', 'cancelling'].includes(data.status)) {
disconnectWebSocket();
resetButtons();
if (data.status === 'completed') {
addLog('success', '[成功] 注册成功!');
toast.success('注册成功!');
// 刷新账号列表
loadRecentAccounts();
} else if (data.status === 'failed') {
addLog('error', '[错误] 注册失败');
toast.error('注册失败');
} else if (data.status === 'cancelled' || data.status === 'cancelling') {
addLog('warning', '[警告] 任务已取消');
}
}
} else if (data.type === 'pong') {
// 心跳响应,忽略
}
};
webSocket.onclose = (event) => {
console.log('WebSocket 连接关闭:', event.code);
stopWebSocketHeartbeat();
// 如果任务仍在运行,切换到轮询
if (currentTask && ['pending', 'running'].includes(currentTask.status)) {
console.log('切换到轮询模式');
useWebSocket = false;
startLogPolling(currentTask.task_uuid);
}
};
webSocket.onerror = (error) => {
console.error('WebSocket 错误:', error);
// 切换到轮询
useWebSocket = false;
stopWebSocketHeartbeat();
startLogPolling(taskUuid);
};
} catch (error) {
console.error('WebSocket 连接失败:', error);
useWebSocket = false;
startLogPolling(taskUuid);
}
}
// 断开 WebSocket
function disconnectWebSocket() {
stopWebSocketHeartbeat();
if (webSocket) {
webSocket.close();
webSocket = null;
}
}
// 开始心跳
function startWebSocketHeartbeat() {
stopWebSocketHeartbeat();
wsHeartbeatInterval = setInterval(() => {
if (webSocket && webSocket.readyState === WebSocket.OPEN) {
webSocket.send(JSON.stringify({ type: 'ping' }));
}
}, 25000); // 每 25 秒发送一次心跳
}
// 停止心跳
function stopWebSocketHeartbeat() {
if (wsHeartbeatInterval) {
clearInterval(wsHeartbeatInterval);
wsHeartbeatInterval = null;
}
}
// 发送取消请求
function cancelViaWebSocket() {
if (webSocket && webSocket.readyState === WebSocket.OPEN) {
webSocket.send(JSON.stringify({ type: 'cancel' }));
}
}
// 批量注册 // 批量注册
async function handleBatchRegistration(requestData) { async function handleBatchRegistration(requestData) {
const count = parseInt(elements.batchCount.value) || 5; const count = parseInt(elements.batchCount.value) || 5;
@@ -340,26 +459,61 @@ async function handleBatchRegistration(requestData) {
// 取消任务 // 取消任务
async function handleCancelTask() { async function handleCancelTask() {
if (isBatchMode && currentBatch) { if (isBatchMode && currentBatch) {
try { // 优先通过 WebSocket 取消批量任务
await api.post(`/registration/batch/${currentBatch.batch_id}/cancel`); if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) {
cancelBatchViaWebSocket();
addLog('warning', '[警告] 批量任务取消请求已提交'); addLog('warning', '[警告] 批量任务取消请求已提交');
toast.info('任务取消请求已提交'); toast.info('任务取消请求已提交');
stopBatchPolling(); } else {
resetButtons(); // 降级到 REST API
} catch (error) { try {
addLog('error', `[错误] 取消失败: ${error.message}`); await api.post(`/registration/batch/${currentBatch.batch_id}/cancel`);
toast.error(error.message); addLog('warning', '[警告] 批量任务取消请求已提交');
toast.info('任务取消请求已提交');
stopBatchPolling();
resetButtons();
} catch (error) {
addLog('error', `[错误] 取消失败: ${error.message}`);
toast.error(error.message);
}
}
} else if (isOutlookBatchMode && currentBatch) {
// Outlook 批量任务取消
if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) {
cancelBatchViaWebSocket();
addLog('warning', '[警告] Outlook 批量任务取消请求已提交');
toast.info('任务取消请求已提交');
} else {
// 降级到 REST API
try {
await api.post(`/registration/outlook-batch/${currentBatch.batch_id}/cancel`);
addLog('warning', '[警告] Outlook 批量任务取消请求已提交');
toast.info('任务取消请求已提交');
stopBatchPolling();
resetButtons();
} catch (error) {
addLog('error', `[错误] 取消失败: ${error.message}`);
toast.error(error.message);
}
} }
} else if (currentTask) { } else if (currentTask) {
try { // 优先通过 WebSocket 取消
await api.post(`/registration/tasks/${currentTask.task_uuid}/cancel`); if (useWebSocket && webSocket && webSocket.readyState === WebSocket.OPEN) {
addLog('warning', '[警告] 任务已取消'); cancelViaWebSocket();
toast.info('任务取消'); addLog('warning', '[警告] 任务取消请求已提交');
stopLogPolling(); toast.info('任务取消请求已提交');
resetButtons(); } else {
} catch (error) { // 降级到 REST API
addLog('error', `[错误] 取消失败: ${error.message}`); try {
toast.error(error.message); await api.post(`/registration/tasks/${currentTask.task_uuid}/cancel`);
addLog('warning', '[警告] 任务已取消');
toast.info('任务已取消');
stopLogPolling();
resetButtons();
} catch (error) {
addLog('error', `[错误] 取消失败: ${error.message}`);
toast.error(error.message);
}
} }
} }
} }
@@ -634,6 +788,9 @@ function resetButtons() {
currentTask = null; currentTask = null;
currentBatch = null; currentBatch = null;
isBatchMode = false; isBatchMode = false;
// 断开 WebSocket
disconnectWebSocket();
disconnectBatchWebSocket();
// 注意:不重置 isOutlookBatchMode因为用户可能想继续使用 Outlook 批量模式 // 注意:不重置 isOutlookBatchMode因为用户可能想继续使用 Outlook 批量模式
} }
@@ -765,8 +922,8 @@ async function handleOutlookBatchRegistration() {
// 初始化批量状态显示 // 初始化批量状态显示
showBatchStatus({ count: data.to_register }); showBatchStatus({ count: data.to_register });
// 开始轮询批量状态 // 优先使用 WebSocket
startOutlookBatchPolling(data.batch_id); connectBatchWebSocket(data.batch_id);
} catch (error) { } catch (error) {
addLog('error', `[错误] 启动失败: ${error.message}`); addLog('error', `[错误] 启动失败: ${error.message}`);
@@ -775,7 +932,125 @@ async function handleOutlookBatchRegistration() {
} }
} }
// 开始轮询 Outlook 批量状态 // ============== 批量任务 WebSocket 功能 ==============
// 连接批量任务 WebSocket
function connectBatchWebSocket(batchId) {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.host}/api/ws/batch/${batchId}`;
try {
batchWebSocket = new WebSocket(wsUrl);
batchWebSocket.onopen = () => {
console.log('批量任务 WebSocket 连接成功');
// 停止轮询(如果有)
stopBatchPolling();
// 开始心跳
startBatchWebSocketHeartbeat();
};
batchWebSocket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'log') {
const logType = getLogType(data.message);
addLog(logType, data.message);
} else if (data.type === 'status') {
// 更新进度
if (data.total !== undefined) {
updateBatchProgress({
total: data.total,
completed: data.completed || 0,
success: data.success || 0,
failed: data.failed || 0
});
}
// 检查是否完成
if (['completed', 'failed', 'cancelled', 'cancelling'].includes(data.status)) {
disconnectBatchWebSocket();
resetButtons();
if (data.status === 'completed') {
addLog('success', `[完成] Outlook 批量任务完成!成功: ${data.success}, 失败: ${data.failed}, 跳过: ${data.skipped || 0}`);
if (data.success > 0) {
toast.success(`Outlook 批量注册完成,成功 ${data.success}`);
loadRecentAccounts();
} else {
toast.warning('Outlook 批量注册完成,但没有成功注册任何账号');
}
} else if (data.status === 'failed') {
addLog('error', '[错误] 批量任务执行失败');
toast.error('批量任务执行失败');
} else if (data.status === 'cancelled' || data.status === 'cancelling') {
addLog('warning', '[警告] 批量任务已取消');
}
}
} else if (data.type === 'pong') {
// 心跳响应,忽略
}
};
batchWebSocket.onclose = (event) => {
console.log('批量任务 WebSocket 连接关闭:', event.code);
stopBatchWebSocketHeartbeat();
// 如果任务仍在运行,切换到轮询
if (currentBatch && !['completed', 'failed', 'cancelled'].includes(currentBatch.status)) {
console.log('切换到轮询模式');
startOutlookBatchPolling(currentBatch.batch_id);
}
};
batchWebSocket.onerror = (error) => {
console.error('批量任务 WebSocket 错误:', error);
stopBatchWebSocketHeartbeat();
// 切换到轮询
startOutlookBatchPolling(batchId);
};
} catch (error) {
console.error('批量任务 WebSocket 连接失败:', error);
startOutlookBatchPolling(batchId);
}
}
// 断开批量任务 WebSocket
function disconnectBatchWebSocket() {
stopBatchWebSocketHeartbeat();
if (batchWebSocket) {
batchWebSocket.close();
batchWebSocket = null;
}
}
// 开始批量任务心跳
function startBatchWebSocketHeartbeat() {
stopBatchWebSocketHeartbeat();
batchWsHeartbeatInterval = setInterval(() => {
if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) {
batchWebSocket.send(JSON.stringify({ type: 'ping' }));
}
}, 25000); // 每 25 秒发送一次心跳
}
// 停止批量任务心跳
function stopBatchWebSocketHeartbeat() {
if (batchWsHeartbeatInterval) {
clearInterval(batchWsHeartbeatInterval);
batchWsHeartbeatInterval = null;
}
}
// 发送批量任务取消请求
function cancelBatchViaWebSocket() {
if (batchWebSocket && batchWebSocket.readyState === WebSocket.OPEN) {
batchWebSocket.send(JSON.stringify({ type: 'cancel' }));
}
}
// 开始轮询 Outlook 批量状态(降级方案)
function startOutlookBatchPolling(batchId) { function startOutlookBatchPolling(batchId) {
batchPollingInterval = setInterval(async () => { batchPollingInterval = setInterval(async () => {
try { try {

View File

@@ -18,7 +18,14 @@ from src.config.settings import get_settings
def setup_application(): def setup_application():
"""设置应用程序""" """设置应用程序"""
# 获取配置 # 初始化数据库(必须先于获取设置)
try:
initialize_database()
except Exception as e:
print(f"数据库初始化失败: {e}")
raise
# 获取配置(需要数据库已初始化)
settings = get_settings() settings = get_settings()
# 配置日志 # 配置日志
@@ -28,14 +35,7 @@ def setup_application():
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info("数据库初始化完成")
# 初始化数据库
try:
initialize_database()
logger.info("数据库初始化完成")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
raise
# 检查数据目录 # 检查数据目录
data_dir = project_root / "data" data_dir = project_root / "data"