mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-06-27 02:01:38 +08:00
2
This commit is contained in:
20
src/database/__init__.py
Normal file
20
src/database/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
数据库模块
|
||||
"""
|
||||
|
||||
from .models import Base, Account, EmailService, RegistrationTask, Setting
|
||||
from .session import get_db, init_database, get_session_manager, DatabaseSessionManager
|
||||
from . import crud
|
||||
|
||||
__all__ = [
|
||||
'Base',
|
||||
'Account',
|
||||
'EmailService',
|
||||
'RegistrationTask',
|
||||
'Setting',
|
||||
'get_db',
|
||||
'init_database',
|
||||
'get_session_manager',
|
||||
'DatabaseSessionManager',
|
||||
'crud',
|
||||
]
|
||||
372
src/database/crud.py
Normal file
372
src/database/crud.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
数据库 CRUD 操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, desc, asc, func
|
||||
|
||||
from .models import Account, EmailService, RegistrationTask, Setting
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 账户 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_account(
|
||||
db: Session,
|
||||
email: str,
|
||||
email_service: str,
|
||||
email_service_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
proxy_used: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Account:
|
||||
"""创建新账户"""
|
||||
db_account = Account(
|
||||
email=email,
|
||||
email_service=email_service,
|
||||
email_service_id=email_service_id,
|
||||
account_id=account_id,
|
||||
workspace_id=workspace_id,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
id_token=id_token,
|
||||
proxy_used=proxy_used,
|
||||
metadata=metadata or {},
|
||||
registered_at=datetime.utcnow()
|
||||
)
|
||||
db.add(db_account)
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
return db_account
|
||||
|
||||
|
||||
def get_account_by_id(db: Session, account_id: int) -> Optional[Account]:
|
||||
"""根据 ID 获取账户"""
|
||||
return db.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
|
||||
def get_account_by_email(db: Session, email: str) -> Optional[Account]:
|
||||
"""根据邮箱获取账户"""
|
||||
return db.query(Account).filter(Account.email == email).first()
|
||||
|
||||
|
||||
def get_accounts(
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
email_service: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[Account]:
|
||||
"""获取账户列表(支持分页、筛选)"""
|
||||
query = db.query(Account)
|
||||
|
||||
if email_service:
|
||||
query = query.filter(Account.email_service == email_service)
|
||||
|
||||
if status:
|
||||
query = query.filter(Account.status == status)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Account.email.ilike(f"%{search}%"),
|
||||
Account.account_id.ilike(f"%{search}%"),
|
||||
Account.workspace_id.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
|
||||
query = query.order_by(desc(Account.created_at)).offset(skip).limit(limit)
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_account(
|
||||
db: Session,
|
||||
account_id: int,
|
||||
**kwargs
|
||||
) -> Optional[Account]:
|
||||
"""更新账户信息"""
|
||||
db_account = get_account_by_id(db, account_id)
|
||||
if not db_account:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_account, key) and value is not None:
|
||||
setattr(db_account, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
return db_account
|
||||
|
||||
|
||||
def delete_account(db: Session, account_id: int) -> bool:
|
||||
"""删除账户"""
|
||||
db_account = get_account_by_id(db, account_id)
|
||||
if not db_account:
|
||||
return False
|
||||
|
||||
db.delete(db_account)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def delete_accounts_batch(db: Session, account_ids: List[int]) -> int:
|
||||
"""批量删除账户"""
|
||||
result = db.query(Account).filter(Account.id.in_(account_ids)).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return result
|
||||
|
||||
|
||||
def get_accounts_count(
|
||||
db: Session,
|
||||
email_service: Optional[str] = None,
|
||||
status: Optional[str] = None
|
||||
) -> int:
|
||||
"""获取账户数量"""
|
||||
query = db.query(func.count(Account.id))
|
||||
|
||||
if email_service:
|
||||
query = query.filter(Account.email_service == email_service)
|
||||
|
||||
if status:
|
||||
query = query.filter(Account.status == status)
|
||||
|
||||
return query.scalar()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 邮箱服务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_email_service(
|
||||
db: Session,
|
||||
service_type: str,
|
||||
name: str,
|
||||
config: Dict[str, Any],
|
||||
enabled: bool = True,
|
||||
priority: int = 0
|
||||
) -> EmailService:
|
||||
"""创建邮箱服务配置"""
|
||||
db_service = EmailService(
|
||||
service_type=service_type,
|
||||
name=name,
|
||||
config=config,
|
||||
enabled=enabled,
|
||||
priority=priority
|
||||
)
|
||||
db.add(db_service)
|
||||
db.commit()
|
||||
db.refresh(db_service)
|
||||
return db_service
|
||||
|
||||
|
||||
def get_email_service_by_id(db: Session, service_id: int) -> Optional[EmailService]:
|
||||
"""根据 ID 获取邮箱服务"""
|
||||
return db.query(EmailService).filter(EmailService.id == service_id).first()
|
||||
|
||||
|
||||
def get_email_services(
|
||||
db: Session,
|
||||
service_type: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[EmailService]:
|
||||
"""获取邮箱服务列表"""
|
||||
query = db.query(EmailService)
|
||||
|
||||
if service_type:
|
||||
query = query.filter(EmailService.service_type == service_type)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.filter(EmailService.enabled == enabled)
|
||||
|
||||
query = query.order_by(
|
||||
asc(EmailService.priority),
|
||||
desc(EmailService.last_used)
|
||||
).offset(skip).limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_email_service(
|
||||
db: Session,
|
||||
service_id: int,
|
||||
**kwargs
|
||||
) -> Optional[EmailService]:
|
||||
"""更新邮箱服务配置"""
|
||||
db_service = get_email_service_by_id(db, service_id)
|
||||
if not db_service:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_service, key) and value is not None:
|
||||
setattr(db_service, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_service)
|
||||
return db_service
|
||||
|
||||
|
||||
def delete_email_service(db: Session, service_id: int) -> bool:
|
||||
"""删除邮箱服务配置"""
|
||||
db_service = get_email_service_by_id(db, service_id)
|
||||
if not db_service:
|
||||
return False
|
||||
|
||||
db.delete(db_service)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 注册任务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_registration_task(
|
||||
db: Session,
|
||||
task_uuid: str,
|
||||
email_service_id: Optional[int] = None,
|
||||
proxy: Optional[str] = None
|
||||
) -> RegistrationTask:
|
||||
"""创建注册任务"""
|
||||
db_task = RegistrationTask(
|
||||
task_uuid=task_uuid,
|
||||
email_service_id=email_service_id,
|
||||
proxy=proxy,
|
||||
status='pending'
|
||||
)
|
||||
db.add(db_task)
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
|
||||
def get_registration_task_by_uuid(db: Session, task_uuid: str) -> Optional[RegistrationTask]:
|
||||
"""根据 UUID 获取注册任务"""
|
||||
return db.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
|
||||
|
||||
def get_registration_tasks(
|
||||
db: Session,
|
||||
status: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[RegistrationTask]:
|
||||
"""获取注册任务列表"""
|
||||
query = db.query(RegistrationTask)
|
||||
|
||||
if status:
|
||||
query = query.filter(RegistrationTask.status == status)
|
||||
|
||||
query = query.order_by(desc(RegistrationTask.created_at)).offset(skip).limit(limit)
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_registration_task(
|
||||
db: Session,
|
||||
task_uuid: str,
|
||||
**kwargs
|
||||
) -> Optional[RegistrationTask]:
|
||||
"""更新注册任务状态"""
|
||||
db_task = get_registration_task_by_uuid(db, task_uuid)
|
||||
if not db_task:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_task, key):
|
||||
setattr(db_task, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
|
||||
def append_task_log(db: Session, task_uuid: str, log_message: str) -> bool:
|
||||
"""追加任务日志"""
|
||||
db_task = get_registration_task_by_uuid(db, task_uuid)
|
||||
if not db_task:
|
||||
return False
|
||||
|
||||
if db_task.logs:
|
||||
db_task.logs += f"\n{log_message}"
|
||||
else:
|
||||
db_task.logs = log_message
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def delete_registration_task(db: Session, task_uuid: str) -> bool:
|
||||
"""删除注册任务"""
|
||||
db_task = get_registration_task_by_uuid(db, task_uuid)
|
||||
if not db_task:
|
||||
return False
|
||||
|
||||
db.delete(db_task)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# 为 API 路由添加别名
|
||||
get_account = get_account_by_id
|
||||
get_registration_task = get_registration_task_by_uuid
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 设置 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def get_setting(db: Session, key: str) -> Optional[Setting]:
|
||||
"""获取设置"""
|
||||
return db.query(Setting).filter(Setting.key == key).first()
|
||||
|
||||
|
||||
def get_settings_by_category(db: Session, category: str) -> List[Setting]:
|
||||
"""根据分类获取设置"""
|
||||
return db.query(Setting).filter(Setting.category == category).all()
|
||||
|
||||
|
||||
def set_setting(
|
||||
db: Session,
|
||||
key: str,
|
||||
value: str,
|
||||
description: Optional[str] = None,
|
||||
category: str = 'general'
|
||||
) -> Setting:
|
||||
"""设置或更新配置项"""
|
||||
db_setting = get_setting(db, key)
|
||||
if db_setting:
|
||||
db_setting.value = value
|
||||
db_setting.description = description or db_setting.description
|
||||
db_setting.category = category
|
||||
db_setting.updated_at = datetime.utcnow()
|
||||
else:
|
||||
db_setting = Setting(
|
||||
key=key,
|
||||
value=value,
|
||||
description=description,
|
||||
category=category
|
||||
)
|
||||
db.add(db_setting)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_setting)
|
||||
return db_setting
|
||||
|
||||
|
||||
def delete_setting(db: Session, key: str) -> bool:
|
||||
"""删除设置"""
|
||||
db_setting = get_setting(db, key)
|
||||
if not db_setting:
|
||||
return False
|
||||
|
||||
db.delete(db_setting)
|
||||
db.commit()
|
||||
return True
|
||||
133
src/database/init_db.py
Normal file
133
src/database/init_db.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
数据库初始化和初始化数据
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from .session import init_database
|
||||
from .crud import set_setting
|
||||
from .models import Base
|
||||
|
||||
|
||||
def init_default_settings(db):
|
||||
"""初始化默认设置"""
|
||||
# 通用设置
|
||||
default_settings = [
|
||||
("system.name", "OpenAI/Codex CLI 自动注册系统", "系统名称", "general"),
|
||||
("system.version", "2.0.0", "系统版本", "general"),
|
||||
("logs.retention_days", "30", "日志保留天数", "general"),
|
||||
|
||||
# OpenAI 配置
|
||||
("openai.client_id", "app_EMoamEEZ73f0CkXaXp7hrann", "OpenAI OAuth Client ID", "openai"),
|
||||
("openai.auth_url", "https://auth.openai.com/oauth/authorize", "OpenAI 认证地址", "openai"),
|
||||
("openai.token_url", "https://auth.openai.com/oauth/token", "OpenAI Token 地址", "openai"),
|
||||
("openai.redirect_uri", "http://localhost:1455/auth/callback", "OpenAI 回调地址", "openai"),
|
||||
("openai.scope", "openid email profile offline_access", "OpenAI 权限范围", "openai"),
|
||||
|
||||
# 代理设置
|
||||
("proxy.enabled", "false", "是否启用代理", "proxy"),
|
||||
("proxy.type", "http", "代理类型 (http/socks5)", "proxy"),
|
||||
("proxy.host", "127.0.0.1", "代理主机", "proxy"),
|
||||
("proxy.port", "7890", "代理端口", "proxy"),
|
||||
|
||||
# 注册设置
|
||||
("registration.max_retries", "3", "最大重试次数", "registration"),
|
||||
("registration.timeout", "120", "超时时间(秒)", "registration"),
|
||||
("registration.default_password_length", "12", "默认密码长度", "registration"),
|
||||
|
||||
# Web UI 设置
|
||||
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
|
||||
("webui.port", "8000", "Web UI 监听端口", "webui"),
|
||||
("webui.debug", "true", "调试模式", "webui"),
|
||||
]
|
||||
|
||||
for key, value, description, category in default_settings:
|
||||
set_setting(db, key, value, description, category)
|
||||
|
||||
|
||||
def init_default_email_services(db):
|
||||
"""初始化默认邮箱服务(仅模板,需要用户配置)"""
|
||||
# 这里只创建模板配置,实际配置需要用户通过 Web UI 设置
|
||||
pass
|
||||
|
||||
|
||||
def initialize_database(database_url: str = None):
|
||||
"""
|
||||
初始化数据库
|
||||
创建所有表并设置默认配置
|
||||
"""
|
||||
# 初始化数据库连接和表
|
||||
db_manager = init_database(database_url)
|
||||
|
||||
# 在事务中设置默认配置
|
||||
with db_manager.session_scope() as session:
|
||||
# 初始化默认设置
|
||||
init_default_settings(session)
|
||||
|
||||
# 初始化默认邮箱服务
|
||||
init_default_email_services(session)
|
||||
|
||||
print("数据库初始化完成")
|
||||
return db_manager
|
||||
|
||||
|
||||
def reset_database(database_url: str = None):
|
||||
"""
|
||||
重置数据库(删除所有表并重新创建)
|
||||
警告:会丢失所有数据!
|
||||
"""
|
||||
db_manager = init_database(database_url)
|
||||
|
||||
# 删除所有表
|
||||
db_manager.drop_tables()
|
||||
print("已删除所有表")
|
||||
|
||||
# 重新创建所有表
|
||||
db_manager.create_tables()
|
||||
print("已重新创建所有表")
|
||||
|
||||
# 初始化数据
|
||||
with db_manager.session_scope() as session:
|
||||
init_default_settings(session)
|
||||
|
||||
print("数据库重置完成")
|
||||
return db_manager
|
||||
|
||||
|
||||
def check_database_connection(database_url: str = None) -> bool:
|
||||
"""
|
||||
检查数据库连接是否正常
|
||||
"""
|
||||
try:
|
||||
db_manager = init_database(database_url)
|
||||
with db_manager.get_db() as db:
|
||||
# 尝试执行一个简单的查询
|
||||
db.execute("SELECT 1")
|
||||
print("数据库连接正常")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"数据库连接失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 当直接运行此脚本时,初始化数据库
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="数据库初始化脚本")
|
||||
parser.add_argument("--reset", action="store_true", help="重置数据库(删除所有数据)")
|
||||
parser.add_argument("--check", action="store_true", help="检查数据库连接")
|
||||
parser.add_argument("--url", help="数据库连接字符串")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.check:
|
||||
check_database_connection(args.url)
|
||||
elif args.reset:
|
||||
confirm = input("警告:这将删除所有数据!确认重置?(y/N): ")
|
||||
if confirm.lower() == 'y':
|
||||
reset_database(args.url)
|
||||
else:
|
||||
print("操作已取消")
|
||||
else:
|
||||
initialize_database(args.url)
|
||||
113
src/database/models.py
Normal file
113
src/database/models.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
SQLAlchemy ORM 模型定义
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class JSONEncodedDict(TypeDecorator):
|
||||
"""JSON 编码字典类型"""
|
||||
impl = Text
|
||||
|
||||
def process_bind_param(self, value: Optional[Dict[str, Any]], dialect):
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def process_result_value(self, value: Optional[str], dialect):
|
||||
if value is None:
|
||||
return None
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
class Account(Base):
|
||||
"""已注册账号表"""
|
||||
__tablename__ = 'accounts'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
email = Column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash = Column(String(255))
|
||||
access_token = Column(Text)
|
||||
refresh_token = Column(Text)
|
||||
id_token = Column(Text)
|
||||
account_id = Column(String(255))
|
||||
workspace_id = Column(String(255))
|
||||
email_service = Column(String(50), nullable=False) # 'tempmail', 'outlook', 'custom_domain'
|
||||
email_service_id = Column(String(255)) # 邮箱服务中的ID
|
||||
proxy_used = Column(String(255))
|
||||
registered_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_refresh = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
status = Column(String(20), default='active') # 'active', 'expired', 'banned', 'failed'
|
||||
extra_data = Column(JSONEncodedDict) # 额外信息存储
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'email': self.email,
|
||||
'email_service': self.email_service,
|
||||
'account_id': self.account_id,
|
||||
'workspace_id': self.workspace_id,
|
||||
'registered_at': self.registered_at.isoformat() if self.registered_at else None,
|
||||
'status': self.status,
|
||||
'proxy_used': self.proxy_used,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
|
||||
|
||||
class EmailService(Base):
|
||||
"""邮箱服务配置表"""
|
||||
__tablename__ = 'email_services'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
service_type = Column(String(50), nullable=False) # 'outlook', 'custom_domain'
|
||||
name = Column(String(100), nullable=False)
|
||||
config = Column(JSONEncodedDict, nullable=False) # 服务配置(加密存储)
|
||||
enabled = Column(Boolean, default=True)
|
||||
priority = Column(Integer, default=0) # 使用优先级
|
||||
last_used = Column(DateTime)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class RegistrationTask(Base):
|
||||
"""注册任务表"""
|
||||
__tablename__ = 'registration_tasks'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_uuid = Column(String(36), unique=True, nullable=False, index=True) # 任务唯一标识
|
||||
status = Column(String(20), default='pending') # 'pending', 'running', 'completed', 'failed', 'cancelled'
|
||||
email_service_id = Column(Integer, ForeignKey('email_services.id'), index=True) # 使用的邮箱服务
|
||||
proxy = Column(String(255)) # 使用的代理
|
||||
logs = Column(Text) # 注册过程日志
|
||||
result = Column(JSONEncodedDict) # 注册结果
|
||||
error_message = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
started_at = Column(DateTime)
|
||||
completed_at = Column(DateTime)
|
||||
|
||||
# 关系
|
||||
email_service = relationship('EmailService')
|
||||
|
||||
|
||||
class Setting(Base):
|
||||
"""系统设置表"""
|
||||
__tablename__ = 'settings'
|
||||
|
||||
key = Column(String(100), primary_key=True)
|
||||
value = Column(Text)
|
||||
description = Column(Text)
|
||||
category = Column(String(50), default='general') # 'general', 'email', 'proxy', 'openai'
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
115
src/database/session.py
Normal file
115
src/database/session.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
数据库会话管理
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
import os
|
||||
|
||||
from .models import Base
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
"""数据库会话管理器"""
|
||||
|
||||
def __init__(self, database_url: str = None):
|
||||
if database_url is None:
|
||||
# 默认使用项目根目录下的 SQLite 数据库
|
||||
db_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'data',
|
||||
'database.db'
|
||||
)
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
database_url = f"sqlite:///{db_path}"
|
||||
|
||||
self.database_url = database_url
|
||||
self.engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False} if database_url.startswith("sqlite") else {},
|
||||
echo=False, # 设置为 True 可以查看所有 SQL 语句
|
||||
pool_pre_ping=True # 连接池预检查
|
||||
)
|
||||
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
|
||||
|
||||
def get_db(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
获取数据库会话的上下文管理器
|
||||
使用示例:
|
||||
with get_db() as db:
|
||||
# 使用 db 进行数据库操作
|
||||
pass
|
||||
"""
|
||||
db = self.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@contextmanager
|
||||
def session_scope(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
事务作用域上下文管理器
|
||||
使用示例:
|
||||
with session_scope() as session:
|
||||
# 数据库操作
|
||||
pass
|
||||
"""
|
||||
session = self.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def create_tables(self):
|
||||
"""创建所有表"""
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
|
||||
def drop_tables(self):
|
||||
"""删除所有表(谨慎使用)"""
|
||||
Base.metadata.drop_all(bind=self.engine)
|
||||
|
||||
|
||||
# 全局数据库会话管理器实例
|
||||
_db_manager: DatabaseSessionManager = None
|
||||
|
||||
|
||||
def init_database(database_url: str = None) -> DatabaseSessionManager:
|
||||
"""
|
||||
初始化数据库会话管理器
|
||||
"""
|
||||
global _db_manager
|
||||
if _db_manager is None:
|
||||
_db_manager = DatabaseSessionManager(database_url)
|
||||
_db_manager.create_tables()
|
||||
return _db_manager
|
||||
|
||||
|
||||
def get_session_manager() -> DatabaseSessionManager:
|
||||
"""
|
||||
获取数据库会话管理器
|
||||
"""
|
||||
if _db_manager is None:
|
||||
raise RuntimeError("数据库未初始化,请先调用 init_database()")
|
||||
return _db_manager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
获取数据库会话的快捷函数
|
||||
"""
|
||||
manager = get_session_manager()
|
||||
db = manager.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
Reference in New Issue
Block a user