Files
codex-register/src/database/crud.py
Jay Hsueh 05e480a756 feat(newapi): 添加 NEWAPI 上传功能及服务管理接口
- 新增 `newapi_upload.py` 文件,包含上传到 NEWAPI 的功能。
- 在数据库模型中添加 `NewapiService` 表及相关字段。
- 更新 CRUD 操作以支持 NEWAPI 服务的创建、更新、查询和删除。
- 添加新的 API 路由以管理 NEWAPI 服务。
- 前端实现批量上传和单个账号上传到 NEWAPI 的功能。
- 更新相关页面以支持 NEWAPI 服务的选择和管理。
2026-03-24 17:46:33 +08:00

867 lines
24 KiB
Python

"""
数据库 CRUD 操作
"""
from typing import List, Optional, Dict, Any, Union, Iterable, Set
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy import and_, or_, desc, asc, func
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService, NewapiService
TOKEN_FIELD_NAMES = ("access_token", "refresh_token", "id_token", "session_token")


def _default_token_sync_status(token_values: Dict[str, Any]) -> str:
    """Derive the token sync status from a mapping of token values.

    Args:
        token_values: Mapping of token field name to its value. Only the keys
            listed in ``TOKEN_FIELD_NAMES`` are inspected.

    Returns:
        ``"pending"`` when at least one inspected field holds a truthy value,
        otherwise ``"not_ready"``.
    """
    for field_name in TOKEN_FIELD_NAMES:
        if token_values.get(field_name):
            return "pending"
    return "not_ready"
# ============================================================================
# 账户 CRUD
# ============================================================================
def create_account(
    db: Session,
    email: str,
    email_service: str,
    password: Optional[str] = None,
    client_id: Optional[str] = None,
    session_token: Optional[str] = None,
    email_service_id: Optional[str] = None,
    account_id: Optional[str] = None,
    workspace_id: Optional[str] = None,
    access_token: Optional[str] = None,
    refresh_token: Optional[str] = None,
    id_token: Optional[str] = None,
    cookies: Optional[str] = None,
    proxy_used: Optional[str] = None,
    expires_at: Optional['datetime'] = None,
    extra_data: Optional[Dict[str, Any]] = None,
    status: Optional[str] = None,
    source: Optional[str] = None,
    token_sync_status: Optional[str] = None,
) -> Account:
    """Create a new account row, commit it and return the refreshed object.

    Falsy ``extra_data``, ``status`` and ``source`` fall back to ``{}``,
    ``'active'`` and ``'register'`` respectively. When ``token_sync_status``
    is falsy it is derived from the supplied tokens via
    ``_default_token_sync_status``.

    Returns:
        The persisted ``Account`` instance.
    """
    sync_status = token_sync_status
    if not sync_status:
        # No explicit status: infer it from whichever tokens were supplied.
        sync_status = _default_token_sync_status(
            {
                "access_token": access_token,
                "refresh_token": refresh_token,
                "id_token": id_token,
                "session_token": session_token,
            }
        )

    account_fields = {
        "email": email,
        "password": password,
        "client_id": client_id,
        "session_token": session_token,
        "email_service": email_service,
        "email_service_id": email_service_id,
        "account_id": account_id,
        "workspace_id": workspace_id,
        "access_token": access_token,
        "refresh_token": refresh_token,
        "id_token": id_token,
        "cookies": cookies,
        "proxy_used": proxy_used,
        "expires_at": expires_at,
        "extra_data": extra_data if extra_data else {},
        "status": status if status else 'active',
        "source": source if source else 'register',
        "registered_at": datetime.utcnow(),
        "token_sync_status": sync_status,
        "token_sync_updated_at": datetime.utcnow(),
    }
    new_account = Account(**account_fields)

    db.add(new_account)
    db.commit()
    db.refresh(new_account)
    return new_account
def get_account_by_id(db: Session, account_id: int) -> Optional[Account]:
    """Fetch a single account by its primary key.

    Returns:
        The first matching ``Account``, or ``None`` when no row matches.
    """
    id_matches = Account.id == account_id
    return db.query(Account).filter(id_matches).first()
def get_account_by_email(db: Session, email: str) -> Optional[Account]:
    """Fetch a single account by its email address.

    Returns:
        The first matching ``Account``, or ``None`` when no row matches.
    """
    email_matches = Account.email == email
    return db.query(Account).filter(email_matches).first()
def get_accounts(
    db: Session,
    skip: int = 0,
    limit: int = 100,
    email_service: Optional[str] = None,
    status: Optional[str] = None,
    search: Optional[str] = None
) -> List[Account]:
    """List accounts with optional filtering and pagination.

    Args:
        skip: Number of rows to skip (query offset).
        limit: Maximum number of rows to return.
        email_service: When truthy, keep only accounts of this email service.
        status: When truthy, keep only accounts with this status.
        search: When truthy, case-insensitive substring match against
            email, account_id or workspace_id.

    Returns:
        Accounts ordered by ``created_at`` descending.
    """
    criteria = []
    if email_service:
        criteria.append(Account.email_service == email_service)
    if status:
        criteria.append(Account.status == status)
    if search:
        pattern = f"%{search}%"
        criteria.append(
            or_(
                Account.email.ilike(pattern),
                Account.account_id.ilike(pattern),
                Account.workspace_id.ilike(pattern),
            )
        )

    query = db.query(Account)
    for criterion in criteria:
        query = query.filter(criterion)

    page = query.order_by(desc(Account.created_at)).offset(skip).limit(limit)
    return page.all()
def update_account(
    db: Session,
    account_id: int,
    **kwargs
) -> Optional[Account]:
    """Update fields of an existing account and commit the change.

    Keyword arguments whose value is ``None``, or whose name is not an
    attribute of the account, are ignored. When at least one token field
    (see ``TOKEN_FIELD_NAMES``) is given a non-``None`` value, the token sync
    status is re-derived from the tokens that will be stored after the update
    (unless the caller passes ``token_sync_status`` explicitly) and
    ``token_sync_updated_at`` is set to the current UTC time.

    Args:
        db: Active database session.
        account_id: Primary key of the account to update.
        **kwargs: Attribute names mapped to their new values.

    Returns:
        The refreshed ``Account``, or ``None`` when the account does not exist.
    """
    db_account = get_account_by_id(db, account_id)
    if not db_account:
        return None

    # None values are never written by the loop below, so a token passed as
    # None must not count as a token change nor hide the stored token when
    # deriving the sync status.
    touches_token = any(kwargs.get(field) is not None for field in TOKEN_FIELD_NAMES)
    if touches_token:
        persisted_token_values = {
            field: (
                kwargs[field]
                if kwargs.get(field) is not None
                else getattr(db_account, field)
            )
            for field in TOKEN_FIELD_NAMES
        }
        kwargs.setdefault("token_sync_status", _default_token_sync_status(persisted_token_values))
        kwargs["token_sync_updated_at"] = datetime.utcnow()

    for key, value in kwargs.items():
        if hasattr(db_account, key) and value is not None:
            setattr(db_account, key, value)

    db.commit()
    db.refresh(db_account)
    return db_account
def delete_account(db: Session, account_id: int) -> bool:
    """Delete the account with the given id and commit.

    Returns:
        ``True`` when an account was found and deleted, ``False`` otherwise.
    """
    target = get_account_by_id(db, account_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
def delete_accounts_batch(db: Session, account_ids: List[int]) -> int:
    """Delete every account whose id is in ``account_ids`` and commit.

    The delete is issued as a single bulk statement with
    ``synchronize_session=False``.

    Returns:
        The value returned by the bulk delete (the matched row count).
    """
    id_in_batch = Account.id.in_(account_ids)
    deleted_count = db.query(Account).filter(id_in_batch).delete(synchronize_session=False)
    db.commit()
    return deleted_count
def get_accounts_count(
    db: Session,
    email_service: Optional[str] = None,
    status: Optional[str] = None
) -> int:
    """Count accounts, optionally restricted by email service and/or status.

    A filter is applied only when the corresponding argument is truthy.
    """
    criteria = []
    if email_service:
        criteria.append(Account.email_service == email_service)
    if status:
        criteria.append(Account.status == status)

    counter = db.query(func.count(Account.id))
    for criterion in criteria:
        counter = counter.filter(criterion)
    return counter.scalar()
# ============================================================================
# 邮箱服务 CRUD
# ============================================================================
def create_email_service(
    db: Session,
    service_type: str,
    name: str,
    config: Dict[str, Any],
    enabled: bool = True,
    priority: int = 0
) -> EmailService:
    """Create an email service configuration, commit it and return it.

    Returns:
        The persisted, refreshed ``EmailService`` instance.
    """
    service_fields = {
        "service_type": service_type,
        "name": name,
        "config": config,
        "enabled": enabled,
        "priority": priority,
    }
    new_service = EmailService(**service_fields)

    db.add(new_service)
    db.commit()
    db.refresh(new_service)
    return new_service
def get_email_service_by_id(db: Session, service_id: int) -> Optional[EmailService]:
    """Fetch a single email service by its primary key.

    Returns:
        The first matching ``EmailService``, or ``None`` when no row matches.
    """
    id_matches = EmailService.id == service_id
    return db.query(EmailService).filter(id_matches).first()
def get_email_services(
    db: Session,
    service_type: Optional[str] = None,
    enabled: Optional[bool] = None,
    skip: int = 0,
    limit: int = 100
) -> List[EmailService]:
    """List email services with optional filtering and pagination.

    Args:
        service_type: When truthy, keep only services of this type.
        enabled: When not ``None``, keep only services with this enabled flag.
        skip: Number of rows to skip (query offset).
        limit: Maximum number of rows to return.

    Returns:
        Services ordered by ``priority`` ascending, then ``last_used``
        descending.
    """
    criteria = []
    if service_type:
        criteria.append(EmailService.service_type == service_type)
    if enabled is not None:
        criteria.append(EmailService.enabled == enabled)

    query = db.query(EmailService)
    for criterion in criteria:
        query = query.filter(criterion)

    ordering = (asc(EmailService.priority), desc(EmailService.last_used))
    return query.order_by(*ordering).offset(skip).limit(limit).all()
def update_email_service(
    db: Session,
    service_id: int,
    **kwargs
) -> Optional[EmailService]:
    """Update an email service configuration and commit the change.

    Keyword arguments that are not attributes of the service, or whose value
    is ``None``, are skipped.

    Returns:
        The refreshed ``EmailService``, or ``None`` when it does not exist.
    """
    service = get_email_service_by_id(db, service_id)
    if not service:
        return None

    for field_name, new_value in kwargs.items():
        if not hasattr(service, field_name) or new_value is None:
            continue
        setattr(service, field_name, new_value)

    db.commit()
    db.refresh(service)
    return service
def delete_email_service(db: Session, service_id: int) -> bool:
    """Delete the email service with the given id and commit.

    Returns:
        ``True`` when a service was found and deleted, ``False`` otherwise.
    """
    target = get_email_service_by_id(db, service_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
# ============================================================================
# 注册任务 CRUD
# ============================================================================
def create_registration_task(
    db: Session,
    task_uuid: str,
    email_service_id: Optional[int] = None,
    proxy: Optional[str] = None
) -> RegistrationTask:
    """Create a registration task with status ``'pending'`` and commit it.

    Returns:
        The persisted, refreshed ``RegistrationTask`` instance.
    """
    task_fields = {
        "task_uuid": task_uuid,
        "email_service_id": email_service_id,
        "proxy": proxy,
        "status": 'pending',
    }
    new_task = RegistrationTask(**task_fields)

    db.add(new_task)
    db.commit()
    db.refresh(new_task)
    return new_task
def get_registration_task_by_uuid(db: Session, task_uuid: str) -> Optional[RegistrationTask]:
    """Fetch a single registration task by its UUID.

    Returns:
        The first matching ``RegistrationTask``, or ``None`` when none matches.
    """
    uuid_matches = RegistrationTask.task_uuid == task_uuid
    return db.query(RegistrationTask).filter(uuid_matches).first()
def get_registration_tasks(
    db: Session,
    status: Optional[str] = None,
    skip: int = 0,
    limit: int = 100
) -> List[RegistrationTask]:
    """List registration tasks, newest first, with optional status filter.

    Args:
        status: When truthy, keep only tasks with this status.
        skip: Number of rows to skip (query offset).
        limit: Maximum number of rows to return.

    Returns:
        Tasks ordered by ``created_at`` descending.
    """
    base_query = db.query(RegistrationTask)
    filtered = (
        base_query.filter(RegistrationTask.status == status)
        if status
        else base_query
    )
    page = filtered.order_by(desc(RegistrationTask.created_at)).offset(skip).limit(limit)
    return page.all()
def update_registration_task(
    db: Session,
    task_uuid: str,
    **kwargs
) -> Optional[RegistrationTask]:
    """Update a registration task and commit the change.

    Every keyword argument naming an existing attribute of the task is
    assigned, including ``None`` values; unknown names are skipped.

    Returns:
        The refreshed ``RegistrationTask``, or ``None`` when it does not exist.
    """
    task = get_registration_task_by_uuid(db, task_uuid)
    if not task:
        return None

    for field_name, new_value in kwargs.items():
        if not hasattr(task, field_name):
            continue
        setattr(task, field_name, new_value)

    db.commit()
    db.refresh(task)
    return task
def append_task_log(db: Session, task_uuid: str, log_message: str) -> bool:
    """Append a log line to a registration task and commit.

    The message is appended after a newline when the task already has logs;
    otherwise it becomes the whole log text.

    Returns:
        ``True`` when the task exists and was updated, ``False`` otherwise.
    """
    task = get_registration_task_by_uuid(db, task_uuid)
    if not task:
        return False

    existing_logs = task.logs
    task.logs = (existing_logs + f"\n{log_message}") if existing_logs else log_message
    db.commit()
    return True
def delete_registration_task(db: Session, task_uuid: str) -> bool:
    """Delete the registration task with the given UUID and commit.

    Returns:
        ``True`` when a task was found and deleted, ``False`` otherwise.
    """
    target = get_registration_task_by_uuid(db, task_uuid)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
def fail_incomplete_registration_tasks(db: Session, error_message: str) -> List[str]:
    """Mark every ``pending`` or ``running`` registration task as failed.

    Each affected task gets status ``"failed"``, the given error message, a
    shared UTC completion timestamp, and a system log line (added only if that
    exact line is not already present in the task's logs). Changes are
    committed once at the end.

    Returns:
        UUIDs of the tasks that were marked failed; empty when none matched.
    """
    unfinished_states = ("pending", "running")
    unfinished = (
        db.query(RegistrationTask)
        .filter(RegistrationTask.status.in_(unfinished_states))
        .all()
    )
    if not unfinished:
        return []

    finished_at = datetime.utcnow()
    system_note = f"[系统] {error_message}"
    failed_uuids: List[str] = []

    for task in unfinished:
        task.status = "failed"
        task.error_message = error_message
        task.completed_at = finished_at
        if not task.logs:
            task.logs = system_note
        elif system_note not in task.logs:
            # Avoid duplicating the note if it is already in the logs.
            task.logs = f"{task.logs}\n{system_note}"
        failed_uuids.append(task.task_uuid)

    db.commit()
    return failed_uuids
# Aliases added for the API routes: shorter names bound to the same functions.
get_account = get_account_by_id
get_registration_task = get_registration_task_by_uuid
# ============================================================================
# 设置 CRUD
# ============================================================================
def get_setting(db: Session, key: str) -> Optional[Setting]:
    """Fetch a single setting by its key.

    Returns:
        The first matching ``Setting``, or ``None`` when no row matches.
    """
    key_matches = Setting.key == key
    return db.query(Setting).filter(key_matches).first()
def get_settings_by_category(db: Session, category: str) -> List[Setting]:
    """Return every setting that belongs to the given category."""
    in_category = Setting.category == category
    return db.query(Setting).filter(in_category).all()
def set_setting(
    db: Session,
    key: str,
    value: str,
    description: Optional[str] = None,
    category: str = 'general'
) -> Setting:
    """Create the setting ``key`` or update it if it already exists.

    On update, ``value`` and ``category`` are always overwritten (so omitting
    ``category`` sets it to ``'general'``), ``description`` is replaced only
    when a truthy one is supplied, and ``updated_at`` is set to the current
    UTC time.

    Returns:
        The persisted, refreshed ``Setting`` instance.
    """
    setting = get_setting(db, key)
    if not setting:
        setting = Setting(
            key=key,
            value=value,
            description=description,
            category=category,
        )
        db.add(setting)
    else:
        setting.value = value
        setting.description = description if description else setting.description
        setting.category = category
        setting.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(setting)
    return setting
def delete_setting(db: Session, key: str) -> bool:
    """Delete the setting with the given key and commit.

    Returns:
        ``True`` when a setting was found and deleted, ``False`` otherwise.
    """
    target = get_setting(db, key)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
# ============================================================================
# 代理 CRUD
# ============================================================================
def create_proxy(
    db: Session,
    name: str,
    type: str,
    host: str,
    port: int,
    username: Optional[str] = None,
    password: Optional[str] = None,
    enabled: bool = True,
    priority: int = 0
) -> Proxy:
    """Create a proxy configuration, commit it and return it.

    Returns:
        The persisted, refreshed ``Proxy`` instance.
    """
    proxy_fields = {
        "name": name,
        "type": type,
        "host": host,
        "port": port,
        "username": username,
        "password": password,
        "enabled": enabled,
        "priority": priority,
    }
    new_proxy = Proxy(**proxy_fields)

    db.add(new_proxy)
    db.commit()
    db.refresh(new_proxy)
    return new_proxy
def get_proxy_by_id(db: Session, proxy_id: int) -> Optional[Proxy]:
    """Fetch a single proxy by its primary key.

    Returns:
        The first matching ``Proxy``, or ``None`` when no row matches.
    """
    id_matches = Proxy.id == proxy_id
    return db.query(Proxy).filter(id_matches).first()
def get_proxies(
    db: Session,
    enabled: Optional[bool] = None,
    skip: int = 0,
    limit: int = 100
) -> List[Proxy]:
    """List proxies, newest first, with optional enabled filter.

    Args:
        enabled: When not ``None``, keep only proxies with this enabled flag.
        skip: Number of rows to skip (query offset).
        limit: Maximum number of rows to return.

    Returns:
        Proxies ordered by ``created_at`` descending.
    """
    base_query = db.query(Proxy)
    filtered = (
        base_query
        if enabled is None
        else base_query.filter(Proxy.enabled == enabled)
    )
    page = filtered.order_by(desc(Proxy.created_at)).offset(skip).limit(limit)
    return page.all()
def get_enabled_proxies(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> List[Proxy]:
    """Return all enabled proxies, optionally leaving some ids out.

    Args:
        exclude_ids: Proxy ids to exclude; each is converted with ``int()``.
            ``None`` or an empty iterable excludes nothing.
    """
    enabled_query = db.query(Proxy).filter(Proxy.enabled == True)

    skipped_ids: Set[int] = set()
    for raw_id in (exclude_ids or []):
        skipped_ids.add(int(raw_id))

    if skipped_ids:
        enabled_query = enabled_query.filter(~Proxy.id.in_(skipped_ids))
    return enabled_query.all()
def update_proxy(
    db: Session,
    proxy_id: int,
    **kwargs
) -> Optional[Proxy]:
    """Update a proxy configuration and commit the change.

    Every keyword argument naming an existing attribute of the proxy is
    assigned, including ``None`` values; unknown names are skipped.

    Returns:
        The refreshed ``Proxy``, or ``None`` when it does not exist.
    """
    proxy = get_proxy_by_id(db, proxy_id)
    if not proxy:
        return None

    for field_name, new_value in kwargs.items():
        if not hasattr(proxy, field_name):
            continue
        setattr(proxy, field_name, new_value)

    db.commit()
    db.refresh(proxy)
    return proxy
def delete_proxy(db: Session, proxy_id: int) -> bool:
    """Delete the proxy with the given id and commit.

    Returns:
        ``True`` when a proxy was found and deleted, ``False`` otherwise.
    """
    target = get_proxy_by_id(db, proxy_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
    """Set a proxy's ``last_used`` to the current UTC time and commit.

    Returns:
        ``True`` when the proxy exists and was updated, ``False`` otherwise.
    """
    proxy = get_proxy_by_id(db, proxy_id)
    if proxy:
        proxy.last_used = datetime.utcnow()
        db.commit()
        return True
    return False
def get_random_proxy(db: Session, exclude_ids: Optional[Iterable[int]] = None) -> Optional[Proxy]:
    """Pick an enabled proxy, preferring one flagged ``is_default``.

    An enabled default proxy (not in ``exclude_ids``) is returned when one
    exists. Otherwise a proxy is chosen at random among the remaining enabled
    ones.

    Args:
        exclude_ids: Proxy ids to exclude; each is converted with ``int()``.

    Returns:
        The chosen ``Proxy``, or ``None`` when no enabled proxy is available.
    """
    import random

    skipped_ids: Set[int] = set()
    for raw_id in (exclude_ids or []):
        skipped_ids.add(int(raw_id))

    # The default proxy takes precedence over random selection.
    preferred_query = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True)
    if skipped_ids:
        preferred_query = preferred_query.filter(~Proxy.id.in_(skipped_ids))
    preferred = preferred_query.first()
    if preferred:
        return preferred

    candidates = get_enabled_proxies(db, exclude_ids=skipped_ids)
    return random.choice(candidates) if candidates else None
def set_proxy_default(db: Session, proxy_id: int) -> Optional[Proxy]:
    """Make the given proxy the default and clear the flag on all others.

    The target proxy is looked up first, so an unknown ``proxy_id`` leaves the
    existing default flags untouched and no pending changes in the session.

    Args:
        db: Active database session.
        proxy_id: Primary key of the proxy to mark as default.

    Returns:
        The refreshed ``Proxy`` now flagged as default, or ``None`` when no
        proxy with that id exists.
    """
    proxy = db.query(Proxy).filter(Proxy.id == proxy_id).first()
    if not proxy:
        # Nothing to promote: do not clear the current default.
        return None

    # Clear every existing default flag, then set the new default.
    db.query(Proxy).filter(Proxy.is_default == True).update({"is_default": False})
    proxy.is_default = True
    db.commit()
    db.refresh(proxy)
    return proxy
def get_proxies_count(db: Session, enabled: Optional[bool] = None) -> int:
    """Count proxies, optionally restricted to a given enabled flag.

    The filter is applied only when ``enabled`` is not ``None``.
    """
    counter = db.query(func.count(Proxy.id))
    if enabled is None:
        return counter.scalar()
    return counter.filter(Proxy.enabled == enabled).scalar()
# ============================================================================
# CPA 服务 CRUD
# ============================================================================
def create_cpa_service(
    db: Session,
    name: str,
    api_url: str,
    api_token: str,
    enabled: bool = True,
    include_proxy_url: bool = False,
    priority: int = 0
) -> CpaService:
    """Create a CPA service configuration, commit it and return it.

    Returns:
        The persisted, refreshed ``CpaService`` instance.
    """
    service_fields = {
        "name": name,
        "api_url": api_url,
        "api_token": api_token,
        "enabled": enabled,
        "include_proxy_url": include_proxy_url,
        "priority": priority,
    }
    new_service = CpaService(**service_fields)

    db.add(new_service)
    db.commit()
    db.refresh(new_service)
    return new_service
def get_cpa_service_by_id(db: Session, service_id: int) -> Optional[CpaService]:
    """Fetch a single CPA service by its primary key.

    Returns:
        The first matching ``CpaService``, or ``None`` when no row matches.
    """
    id_matches = CpaService.id == service_id
    return db.query(CpaService).filter(id_matches).first()
def get_cpa_services(
    db: Session,
    enabled: Optional[bool] = None
) -> List[CpaService]:
    """List CPA services ordered by priority, then id (both ascending).

    Args:
        enabled: When not ``None``, keep only services with this enabled flag.
    """
    base_query = db.query(CpaService)
    filtered = (
        base_query
        if enabled is None
        else base_query.filter(CpaService.enabled == enabled)
    )
    ordering = (asc(CpaService.priority), asc(CpaService.id))
    return filtered.order_by(*ordering).all()
def update_cpa_service(
    db: Session,
    service_id: int,
    **kwargs
) -> Optional[CpaService]:
    """Update a CPA service configuration and commit the change.

    Every keyword argument naming an existing attribute of the service is
    assigned, including ``None`` values; unknown names are skipped.

    Returns:
        The refreshed ``CpaService``, or ``None`` when it does not exist.
    """
    service = get_cpa_service_by_id(db, service_id)
    if not service:
        return None

    for field_name, new_value in kwargs.items():
        if not hasattr(service, field_name):
            continue
        setattr(service, field_name, new_value)

    db.commit()
    db.refresh(service)
    return service
def delete_cpa_service(db: Session, service_id: int) -> bool:
    """Delete the CPA service with the given id and commit.

    Returns:
        ``True`` when a service was found and deleted, ``False`` otherwise.
    """
    target = get_cpa_service_by_id(db, service_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
# ============================================================================
# Sub2API 服务 CRUD
# ============================================================================
def create_sub2api_service(
    db: Session,
    name: str,
    api_url: str,
    api_key: str,
    enabled: bool = True,
    priority: int = 0
) -> Sub2ApiService:
    """Create a Sub2API service configuration, commit it and return it.

    Returns:
        The persisted, refreshed ``Sub2ApiService`` instance.
    """
    service_fields = {
        "name": name,
        "api_url": api_url,
        "api_key": api_key,
        "enabled": enabled,
        "priority": priority,
    }
    new_service = Sub2ApiService(**service_fields)

    db.add(new_service)
    db.commit()
    db.refresh(new_service)
    return new_service
def get_sub2api_service_by_id(db: Session, service_id: int) -> Optional[Sub2ApiService]:
    """Fetch a single Sub2API service by its primary key.

    Returns:
        The first matching ``Sub2ApiService``, or ``None`` when none matches.
    """
    id_matches = Sub2ApiService.id == service_id
    return db.query(Sub2ApiService).filter(id_matches).first()
def get_sub2api_services(
    db: Session,
    enabled: Optional[bool] = None
) -> List[Sub2ApiService]:
    """List Sub2API services ordered by priority, then id (both ascending).

    Args:
        enabled: When not ``None``, keep only services with this enabled flag.
    """
    base_query = db.query(Sub2ApiService)
    filtered = (
        base_query
        if enabled is None
        else base_query.filter(Sub2ApiService.enabled == enabled)
    )
    ordering = (asc(Sub2ApiService.priority), asc(Sub2ApiService.id))
    return filtered.order_by(*ordering).all()
def update_sub2api_service(db: Session, service_id: int, **kwargs) -> Optional[Sub2ApiService]:
    """Update a Sub2API service configuration and commit the change.

    Every keyword argument naming an existing attribute of the service is
    assigned, including ``None`` values. Unknown names are skipped, matching
    the behaviour of ``update_cpa_service`` and ``update_proxy``.

    Args:
        db: Active database session.
        service_id: Primary key of the service to update.
        **kwargs: Attribute names mapped to their new values.

    Returns:
        The refreshed ``Sub2ApiService``, or ``None`` when it does not exist.
    """
    svc = get_sub2api_service_by_id(db, service_id)
    if not svc:
        return None
    for key, value in kwargs.items():
        # Guard against stray keys creating ad-hoc, non-persisted attributes.
        if hasattr(svc, key):
            setattr(svc, key, value)
    db.commit()
    db.refresh(svc)
    return svc
def delete_sub2api_service(db: Session, service_id: int) -> bool:
    """Delete the Sub2API service with the given id and commit.

    Returns:
        ``True`` when a service was found and deleted, ``False`` otherwise.
    """
    target = get_sub2api_service_by_id(db, service_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
# ============================================================================
# Team Manager 服务 CRUD
# ============================================================================
def create_tm_service(
    db: Session,
    name: str,
    api_url: str,
    api_key: str,
    enabled: bool = True,
    priority: int = 0,
):
    """Create a Team Manager service configuration, commit it and return it.

    Returns:
        The persisted, refreshed ``TeamManagerService`` instance.
    """
    from .models import TeamManagerService

    service_fields = {
        "name": name,
        "api_url": api_url,
        "api_key": api_key,
        "enabled": enabled,
        "priority": priority,
    }
    new_service = TeamManagerService(**service_fields)

    db.add(new_service)
    db.commit()
    db.refresh(new_service)
    return new_service
def get_tm_service_by_id(db: Session, service_id: int):
    """Fetch a single Team Manager service by its primary key.

    Returns:
        The first matching ``TeamManagerService``, or ``None`` when none
        matches.
    """
    from .models import TeamManagerService

    id_matches = TeamManagerService.id == service_id
    return db.query(TeamManagerService).filter(id_matches).first()
def get_tm_services(db: Session, enabled=None):
    """List Team Manager services ordered by priority, then id (ascending).

    Args:
        enabled: When not ``None``, keep only services with this enabled flag.
    """
    from .models import TeamManagerService

    base_query = db.query(TeamManagerService)
    filtered = (
        base_query
        if enabled is None
        else base_query.filter(TeamManagerService.enabled == enabled)
    )
    ordering = (TeamManagerService.priority.asc(), TeamManagerService.id.asc())
    return filtered.order_by(*ordering).all()
def update_tm_service(db: Session, service_id: int, **kwargs):
    """Update a Team Manager service configuration and commit the change.

    Every keyword argument naming an existing attribute of the service is
    assigned, including ``None`` values. Unknown names are skipped, matching
    the behaviour of ``update_cpa_service`` and ``update_proxy``.

    Args:
        db: Active database session.
        service_id: Primary key of the service to update.
        **kwargs: Attribute names mapped to their new values.

    Returns:
        The refreshed service object, or ``None`` when it does not exist.
    """
    svc = get_tm_service_by_id(db, service_id)
    if not svc:
        return None
    for k, v in kwargs.items():
        # Guard against stray keys creating ad-hoc, non-persisted attributes.
        if hasattr(svc, k):
            setattr(svc, k, v)
    db.commit()
    db.refresh(svc)
    return svc
def delete_tm_service(db: Session, service_id: int) -> bool:
    """Delete the Team Manager service with the given id and commit.

    Returns:
        ``True`` when a service was found and deleted, ``False`` otherwise.
    """
    target = get_tm_service_by_id(db, service_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
def create_newapi_service(
    db: Session,
    name: str,
    api_url: str,
    api_key: str,
    enabled: bool = True,
    priority: int = 0,
) -> NewapiService:
    """Create a NEWAPI service configuration, commit it and return it.

    Returns:
        The persisted, refreshed ``NewapiService`` instance.
    """
    service_fields = {
        "name": name,
        "api_url": api_url,
        "api_key": api_key,
        "enabled": enabled,
        "priority": priority,
    }
    new_service = NewapiService(**service_fields)

    db.add(new_service)
    db.commit()
    db.refresh(new_service)
    return new_service
def get_newapi_service_by_id(db: Session, service_id: int) -> Optional[NewapiService]:
    """Fetch a single NEWAPI service by its primary key.

    Returns:
        The first matching ``NewapiService``, or ``None`` when none matches.
    """
    id_matches = NewapiService.id == service_id
    return db.query(NewapiService).filter(id_matches).first()
def get_newapi_services(db: Session, enabled=None):
    """List NEWAPI services ordered by priority, then id (both ascending).

    Args:
        enabled: When not ``None``, keep only services with this enabled flag.
    """
    base_query = db.query(NewapiService)
    filtered = (
        base_query
        if enabled is None
        else base_query.filter(NewapiService.enabled == enabled)
    )
    ordering = (NewapiService.priority.asc(), NewapiService.id.asc())
    return filtered.order_by(*ordering).all()
def update_newapi_service(db: Session, service_id: int, **kwargs):
    """Update a NEWAPI service configuration and commit the change.

    Every keyword argument naming an existing attribute of the service is
    assigned, including ``None`` values. Unknown names are skipped, matching
    the behaviour of ``update_cpa_service`` and ``update_proxy``.

    Args:
        db: Active database session.
        service_id: Primary key of the service to update.
        **kwargs: Attribute names mapped to their new values.

    Returns:
        The refreshed ``NewapiService``, or ``None`` when it does not exist.
    """
    svc = get_newapi_service_by_id(db, service_id)
    if not svc:
        return None
    for k, v in kwargs.items():
        # Guard against stray keys creating ad-hoc, non-persisted attributes.
        if hasattr(svc, k):
            setattr(svc, k, v)
    db.commit()
    db.refresh(svc)
    return svc
def delete_newapi_service(db: Session, service_id: int) -> bool:
    """Delete the NEWAPI service with the given id and commit.

    Returns:
        ``True`` when a service was found and deleted, ``False`` otherwise.
    """
    target = get_newapi_service_by_id(db, service_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
def update_outlook_refresh_token(db: Session, service_id: int, email: str, new_refresh_token: str):
    """Store a new refresh token for one mailbox inside ``EmailService.config``.

    Two config layouts are handled: a single-account config with top-level
    ``"email"`` / ``"refresh_token"`` keys, and a multi-account config whose
    ``"accounts"`` entry is a list of dicts carrying the same keys. Emails are
    compared case-insensitively with surrounding whitespace ignored.

    The function returns silently without touching the database when the
    service does not exist, its config is not a dict, the email or token is
    empty, or no entry matches the email.

    Args:
        db: Active database session.
        service_id: Primary key of the ``EmailService`` row.
        email: Mailbox address whose token should be replaced.
        new_refresh_token: Non-empty replacement refresh token.
    """
    service = db.query(EmailService).filter(EmailService.id == service_id).first()
    if not service or not isinstance(service.config, dict):
        return

    normalized_email = (email or "").strip().lower()
    if not normalized_email or not isinstance(new_refresh_token, str) or not new_refresh_token:
        return

    config = dict(service.config)
    updated = False

    # Single-account layout.
    if str(config.get("email", "")).strip().lower() == normalized_email:
        config["refresh_token"] = new_refresh_token
        updated = True

    # Multi-account layout; tolerate a missing, null or malformed "accounts".
    accounts = config.get("accounts")
    if not isinstance(accounts, (list, tuple)):
        accounts = []
    for acc in accounts:
        if not isinstance(acc, dict):
            continue
        if str(acc.get("email", "")).strip().lower() == normalized_email:
            acc["refresh_token"] = new_refresh_token
            updated = True

    if not updated:
        return

    service.config = config
    # Nested dicts were mutated in place, so flag the column explicitly to
    # make sure the change is included in the next flush.
    flag_modified(service, "config")
    db.commit()