mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-15 04:18:51 +08:00
增加为需要输入密码才能访问,同时支持远程PGSQL
This commit is contained in:
@@ -87,6 +87,13 @@ SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
|
||||
description="Web UI 密钥",
|
||||
is_secret=True
|
||||
),
|
||||
"webui_access_password": SettingDefinition(
|
||||
db_key="webui.access_password",
|
||||
default_value="admin123",
|
||||
category=SettingCategory.WEBUI,
|
||||
description="Web UI 访问密码",
|
||||
is_secret=True
|
||||
),
|
||||
|
||||
# 日志配置
|
||||
"log_level": SettingDefinition(
|
||||
@@ -434,6 +441,14 @@ def _convert_value(attr_name: str, value: str) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_database_url(url: str) -> str:
|
||||
if url.startswith("postgres://"):
|
||||
return "postgresql+psycopg://" + url[len("postgres://"):]
|
||||
if url.startswith("postgresql://"):
|
||||
return "postgresql+psycopg://" + url[len("postgresql://"):]
|
||||
return url
|
||||
|
||||
|
||||
def _value_to_string(value: Any) -> str:
|
||||
"""将值转换为数据库存储的字符串"""
|
||||
if isinstance(value, SecretStr):
|
||||
@@ -462,7 +477,12 @@ def init_default_settings() -> None:
|
||||
for attr_name, defn in SETTING_DEFINITIONS.items():
|
||||
existing = get_setting(db, defn.db_key)
|
||||
if not existing:
|
||||
default_value = _value_to_string(defn.default_value)
|
||||
default_value = defn.default_value
|
||||
if attr_name == "database_url":
|
||||
env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL")
|
||||
if env_url:
|
||||
default_value = _normalize_database_url(env_url)
|
||||
default_value = _value_to_string(default_value)
|
||||
set_setting(
|
||||
db,
|
||||
defn.db_key,
|
||||
@@ -490,6 +510,9 @@ def _load_settings_from_db() -> Dict[str, Any]:
|
||||
else:
|
||||
# 数据库中没有此设置,使用默认值
|
||||
settings_dict[attr_name] = _convert_value(attr_name, _value_to_string(defn.default_value))
|
||||
env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL")
|
||||
if env_url:
|
||||
settings_dict["database_url"] = _normalize_database_url(env_url)
|
||||
return settings_dict
|
||||
except Exception as e:
|
||||
print(f"[Settings] 从数据库加载设置失败: {e},使用默认值")
|
||||
@@ -534,9 +557,14 @@ class Settings(BaseModel):
|
||||
@field_validator('database_url', mode='before')
|
||||
@classmethod
|
||||
def validate_database_url(cls, v):
|
||||
if isinstance(v, str):
|
||||
if v.startswith(("postgres://", "postgresql://")):
|
||||
return _normalize_database_url(v)
|
||||
if v.startswith(("postgresql+psycopg://", "postgresql+psycopg2://")):
|
||||
return v
|
||||
if isinstance(v, str) and v.startswith("sqlite:///"):
|
||||
return v
|
||||
if isinstance(v, str) and not v.startswith(("sqlite:///", "postgresql://", "mysql://")):
|
||||
if isinstance(v, str) and not v.startswith(("sqlite:///", "postgresql://", "postgresql+psycopg://", "postgresql+psycopg2://", "mysql://")):
|
||||
# 如果是文件路径,转换为 SQLite URL
|
||||
if os.path.isabs(v) or ":/" not in v:
|
||||
return f"sqlite:///{v}"
|
||||
@@ -546,6 +574,7 @@ class Settings(BaseModel):
|
||||
webui_host: str = "0.0.0.0"
|
||||
webui_port: int = 8000
|
||||
webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
|
||||
webui_access_password: SecretStr = SecretStr("admin123")
|
||||
|
||||
# 日志配置
|
||||
log_level: str = "INFO"
|
||||
|
||||
@@ -393,6 +393,10 @@ def get_data_dir() -> Path:
|
||||
数据目录 Path 对象
|
||||
"""
|
||||
settings = get_settings()
|
||||
if not settings.database_url.startswith("sqlite"):
|
||||
data_dir = Path(os.environ.get("APP_DATA_DIR", "data"))
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
return data_dir
|
||||
data_dir = Path(settings.database_url).parent
|
||||
|
||||
# 如果 database_url 是 SQLite URL,提取路径
|
||||
@@ -563,4 +567,4 @@ class Timer:
|
||||
return self.elapsed
|
||||
if self.start_time is not None:
|
||||
return time.time() - self.start_time
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
@@ -15,25 +15,37 @@ from .models import Base
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_sqlalchemy_url(database_url: str) -> str:
|
||||
if database_url.startswith("postgresql://"):
|
||||
return "postgresql+psycopg://" + database_url[len("postgresql://"):]
|
||||
if database_url.startswith("postgres://"):
|
||||
return "postgresql+psycopg://" + database_url[len("postgres://"):]
|
||||
return database_url
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
"""数据库会话管理器"""
|
||||
|
||||
def __init__(self, database_url: str = None):
|
||||
if database_url is None:
|
||||
# 优先使用 APP_DATA_DIR 环境变量(PyInstaller 打包后由 webui.py 设置)
|
||||
data_dir = os.environ.get('APP_DATA_DIR') or os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'data'
|
||||
)
|
||||
db_path = os.path.join(data_dir, 'database.db')
|
||||
# 确保目录存在
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
database_url = f"sqlite:///{db_path}"
|
||||
env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL")
|
||||
if env_url:
|
||||
database_url = env_url
|
||||
else:
|
||||
# 优先使用 APP_DATA_DIR 环境变量(PyInstaller 打包后由 webui.py 设置)
|
||||
data_dir = os.environ.get('APP_DATA_DIR') or os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'data'
|
||||
)
|
||||
db_path = os.path.join(data_dir, 'database.db')
|
||||
# 确保目录存在
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
database_url = f"sqlite:///{db_path}"
|
||||
|
||||
self.database_url = database_url
|
||||
self.database_url = _build_sqlalchemy_url(database_url)
|
||||
self.engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False} if database_url.startswith("sqlite") else {},
|
||||
self.database_url,
|
||||
connect_args={"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
|
||||
echo=False, # 设置为 True 可以查看所有 SQL 语句
|
||||
pool_pre_ping=True # 连接池预检查
|
||||
)
|
||||
@@ -152,4 +164,4 @@ def get_db() -> Generator[Session, None, None]:
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
||||
@@ -5,14 +5,17 @@ FastAPI 应用主文件
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
|
||||
from ..config.settings import get_settings
|
||||
from .routes import api_router
|
||||
@@ -78,24 +81,74 @@ def create_app() -> FastAPI:
|
||||
# 模板引擎
|
||||
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
|
||||
|
||||
def _auth_token(password: str) -> str:
|
||||
secret = get_settings().webui_secret_key.get_secret_value().encode("utf-8")
|
||||
return hmac.new(secret, password.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
|
||||
def _is_authenticated(request: Request) -> bool:
|
||||
cookie = request.cookies.get("webui_auth")
|
||||
expected = _auth_token(get_settings().webui_access_password.get_secret_value())
|
||||
return bool(cookie) and secrets.compare_digest(cookie, expected)
|
||||
|
||||
def _redirect_to_login(request: Request) -> RedirectResponse:
|
||||
return RedirectResponse(url=f"/login?next={request.url.path}", status_code=302)
|
||||
|
||||
@app.get("/login", response_class=HTMLResponse)
|
||||
async def login_page(request: Request, next: Optional[str] = "/"):
|
||||
"""登录页面"""
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "error": "", "next": next or "/"}
|
||||
)
|
||||
|
||||
@app.post("/login")
|
||||
async def login_submit(request: Request, password: str = Form(...), next: Optional[str] = "/"):
|
||||
"""处理登录提交"""
|
||||
expected = get_settings().webui_access_password.get_secret_value()
|
||||
if not secrets.compare_digest(password, expected):
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "error": "密码错误", "next": next or "/"},
|
||||
status_code=401
|
||||
)
|
||||
|
||||
response = RedirectResponse(url=next or "/", status_code=302)
|
||||
response.set_cookie("webui_auth", _auth_token(expected), httponly=True, samesite="lax")
|
||||
return response
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(request: Request, next: Optional[str] = "/login"):
|
||||
"""退出登录"""
|
||||
response = RedirectResponse(url=next or "/login", status_code=302)
|
||||
response.delete_cookie("webui_auth")
|
||||
return response
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
"""首页 - 注册页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
|
||||
@app.get("/accounts", response_class=HTMLResponse)
|
||||
async def accounts_page(request: Request):
|
||||
"""账号管理页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("accounts.html", {"request": request})
|
||||
|
||||
@app.get("/email-services", response_class=HTMLResponse)
|
||||
async def email_services_page(request: Request):
|
||||
"""邮箱服务管理页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("email_services.html", {"request": request})
|
||||
|
||||
@app.get("/settings", response_class=HTMLResponse)
|
||||
async def settings_page(request: Request):
|
||||
"""设置页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("settings.html", {"request": request})
|
||||
|
||||
@app.on_event("startup")
|
||||
|
||||
@@ -52,9 +52,10 @@ class RegistrationSettings(BaseModel):
|
||||
|
||||
class WebUISettings(BaseModel):
|
||||
"""Web UI 设置"""
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
debug: bool = False
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
debug: Optional[bool] = None
|
||||
access_password: Optional[str] = None
|
||||
|
||||
|
||||
class AllSettings(BaseModel):
|
||||
@@ -96,6 +97,7 @@ async def get_all_settings():
|
||||
"host": settings.webui_host,
|
||||
"port": settings.webui_port,
|
||||
"debug": settings.debug,
|
||||
"has_access_password": bool(settings.webui_access_password and settings.webui_access_password.get_secret_value()),
|
||||
},
|
||||
"tempmail": {
|
||||
"base_url": settings.tempmail_base_url,
|
||||
@@ -317,6 +319,23 @@ async def update_registration_settings(request: RegistrationSettings):
|
||||
return {"success": True, "message": "注册设置已更新"}
|
||||
|
||||
|
||||
@router.post("/webui")
|
||||
async def update_webui_settings(request: WebUISettings):
|
||||
"""更新 Web UI 设置"""
|
||||
update_dict = {}
|
||||
if request.host is not None:
|
||||
update_dict["webui_host"] = request.host
|
||||
if request.port is not None:
|
||||
update_dict["webui_port"] = request.port
|
||||
if request.debug is not None:
|
||||
update_dict["debug"] = request.debug
|
||||
if request.access_password:
|
||||
update_dict["webui_access_password"] = request.access_password
|
||||
|
||||
update_settings(**update_dict)
|
||||
return {"success": True, "message": "Web UI 设置已更新"}
|
||||
|
||||
|
||||
@router.get("/database")
|
||||
async def get_database_info():
|
||||
"""获取数据库信息"""
|
||||
|
||||
Reference in New Issue
Block a user