mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-07 07:22:58 +08:00
refactor: optimize backend module
This commit is contained in:
@@ -137,8 +137,8 @@ Install the following tooling first:
|
||||
|
||||
Storage adapters integrate new storage providers (for example S3, FTP, or Alist).
|
||||
|
||||
1. Create a new module under [`services/adapters/`](services/adapters/) (for example `my_new_adapter.py`).
|
||||
2. Implement a class that inherits from [`services.adapters.base.BaseAdapter`](services/adapters/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
|
||||
1. Create a new module under [`domain/adapters/providers/`](domain/adapters/providers/) (for example `my_new_adapter.py`).
|
||||
2. Implement a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
|
||||
|
||||
### Frontend Apps
|
||||
|
||||
|
||||
@@ -143,9 +143,9 @@
|
||||
|
||||
存储适配器是 Foxel 的核心扩展点,用于接入不同的存储后端 (如 S3, FTP, Alist 等)。
|
||||
|
||||
1. **创建适配器文件**: 在 [`services/adapters/`](services/adapters/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
|
||||
1. **创建适配器文件**: 在 [`domain/adapters/providers/`](domain/adapters/providers/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
|
||||
2. **实现适配器类**:
|
||||
- 创建一个类,继承自 [`services.adapters.base.BaseAdapter`](services/adapters/base.py)。
|
||||
- 创建一个类,继承自 [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py)。
|
||||
- 实现 `BaseAdapter` 中定义的所有抽象方法,如 `list_dir`, `get_meta`, `upload`, `download` 等。请仔细阅读基类中的文档注释以理解每个方法的作用和参数。
|
||||
|
||||
### 贡献前端应用 (App)
|
||||
|
||||
@@ -1,26 +1,37 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search, vector_db, offline_downloads, ai_providers, email
|
||||
from .routes import webdav, s3
|
||||
from .routes import plugins
|
||||
from domain.adapters import api as adapters
|
||||
from domain.auth import api as auth
|
||||
from domain.backup import api as backup
|
||||
from domain.config import api as config
|
||||
from domain.email import api as email
|
||||
from domain.offline_downloads import api as offline_downloads
|
||||
from domain.plugins import api as plugins
|
||||
from domain.processors import api as processors
|
||||
from domain.share import api as share
|
||||
from domain.tasks import api as tasks
|
||||
from domain.ai import api as ai
|
||||
from domain.virtual_fs import api as virtual_fs
|
||||
from domain.virtual_fs import s3_api, search_api, webdav_api
|
||||
from domain.audit import router as audit
|
||||
|
||||
|
||||
def include_routers(app: FastAPI):
|
||||
app.include_router(adapters.router)
|
||||
app.include_router(virtual_fs.router)
|
||||
app.include_router(search.router)
|
||||
app.include_router(search_api.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(config.router)
|
||||
app.include_router(processors.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(logs.router)
|
||||
app.include_router(share.router)
|
||||
app.include_router(share.public_router)
|
||||
app.include_router(backup.router)
|
||||
app.include_router(vector_db.router)
|
||||
app.include_router(ai_providers.router)
|
||||
app.include_router(ai.router_vector_db)
|
||||
app.include_router(ai.router_ai)
|
||||
app.include_router(plugins.router)
|
||||
app.include_router(webdav.router)
|
||||
app.include_router(s3.router)
|
||||
app.include_router(webdav_api.router)
|
||||
app.include_router(s3_api.router)
|
||||
app.include_router(offline_downloads.router)
|
||||
app.include_router(email.router)
|
||||
app.include_router(audit)
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from tortoise.transactions import in_transaction
|
||||
from typing import Annotated
|
||||
|
||||
from models import StorageAdapter
|
||||
from schemas import AdapterCreate, AdapterOut
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.adapters.registry import runtime_registry, get_config_schemas, normalize_adapter_type
|
||||
from api.response import success
|
||||
from services.logging import LogService
|
||||
|
||||
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
|
||||
|
||||
def validate_and_normalize_config(adapter_type: str, cfg):
|
||||
schemas = get_config_schemas()
|
||||
adapter_type = normalize_adapter_type(adapter_type)
|
||||
if not adapter_type:
|
||||
raise HTTPException(400, detail="不支持的适配器类型")
|
||||
if not isinstance(cfg, dict):
|
||||
raise HTTPException(400, detail="config 必须是对象")
|
||||
schema = schemas.get(adapter_type)
|
||||
if not schema:
|
||||
raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
|
||||
out = {}
|
||||
missing = []
|
||||
for f in schema:
|
||||
k = f["key"]
|
||||
if k in cfg and cfg[k] not in (None, ""):
|
||||
out[k] = cfg[k]
|
||||
elif "default" in f:
|
||||
out[k] = f["default"]
|
||||
elif f.get("required"):
|
||||
missing.append(k)
|
||||
if missing:
|
||||
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
|
||||
return out
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_adapter(
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
exists = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if exists:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
adapter_fields = {
|
||||
"name": data.name,
|
||||
"type": data.type,
|
||||
"config": validate_and_normalize_config(data.type, data.config or {}),
|
||||
"enabled": data.enabled,
|
||||
"path": norm_path,
|
||||
"sub_path": data.sub_path,
|
||||
}
|
||||
|
||||
rec = await StorageAdapter.create(**adapter_fields)
|
||||
await runtime_registry.upsert(rec)
|
||||
await LogService.action(
|
||||
"route:adapters",
|
||||
f"Created adapter {rec.name}",
|
||||
details=adapter_fields,
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
return success(rec)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_adapters(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapters = await StorageAdapter.all()
|
||||
out = [AdapterOut.model_validate(a) for a in adapters]
|
||||
return success(out)
|
||||
|
||||
|
||||
@router.get("/available")
|
||||
async def available_adapter_types(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
data = []
|
||||
for t, fields in get_config_schemas().items():
|
||||
data.append({
|
||||
"type": t,
|
||||
"config_schema": fields,
|
||||
})
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{adapter_id}")
|
||||
async def get_adapter(
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
return success(AdapterOut.model_validate(rec))
|
||||
|
||||
|
||||
@router.put("/{adapter_id}")
|
||||
async def update_adapter(
|
||||
adapter_id: int,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
existing = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if existing and existing.id != adapter_id:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
rec.name = data.name
|
||||
rec.type = data.type
|
||||
rec.config = validate_and_normalize_config(data.type, data.config or {})
|
||||
rec.enabled = data.enabled
|
||||
rec.path = norm_path
|
||||
rec.sub_path = data.sub_path
|
||||
await rec.save()
|
||||
|
||||
await runtime_registry.upsert(rec)
|
||||
await LogService.action(
|
||||
"route:adapters",
|
||||
f"Updated adapter {rec.name}",
|
||||
details=data.model_dump(),
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
return success(rec)
|
||||
|
||||
|
||||
@router.delete("/{adapter_id}")
|
||||
async def delete_adapter(
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
deleted = await StorageAdapter.filter(id=adapter_id).delete()
|
||||
if not deleted:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
runtime_registry.remove(adapter_id)
|
||||
await LogService.action(
|
||||
"route:adapters",
|
||||
f"Deleted adapter {adapter_id}",
|
||||
details={"adapter_id": adapter_id},
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
return success({"deleted": True})
|
||||
@@ -1,177 +0,0 @@
|
||||
from typing import Annotated, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
|
||||
from api.response import success
|
||||
from schemas.ai import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
)
|
||||
from services.ai_providers import AIProviderService
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.vector_db import VectorDBService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/ai", tags=["ai"])
|
||||
service = AIProviderService()
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def list_providers(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
providers = await service.list_providers()
|
||||
return success({"providers": providers})
|
||||
|
||||
|
||||
@router.post("/providers")
|
||||
async def create_provider(
|
||||
payload: AIProviderCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
provider = await service.create_provider(payload.dict())
|
||||
return success(provider)
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}")
|
||||
async def get_provider(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
provider = await service.get_provider(provider_id, with_models=True)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id}")
|
||||
async def update_provider(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIProviderUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
provider = await service.update_provider(provider_id, data)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await service.delete_provider(provider_id)
|
||||
return success({"id": provider_id})
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/sync-models")
|
||||
async def sync_models(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
result = await service.sync_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/remote-models")
|
||||
async def fetch_remote_models(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
models = await service.fetch_remote_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/models")
|
||||
async def list_models(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
models = await service.list_models(provider_id)
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/models")
|
||||
async def create_model(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
model = await service.create_model(provider_id, payload.dict())
|
||||
return success(model)
|
||||
|
||||
|
||||
@router.put("/models/{model_id}")
|
||||
async def update_model(
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
model = await service.update_model(model_id, data)
|
||||
return success(model)
|
||||
|
||||
|
||||
@router.delete("/models/{model_id}")
|
||||
async def delete_model(
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await service.delete_model(model_id)
|
||||
return success({"id": model_id})
|
||||
|
||||
|
||||
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
|
||||
if not entry:
|
||||
return None
|
||||
value = entry.get("embedding_dimensions")
|
||||
return int(value) if value is not None else None
|
||||
|
||||
|
||||
@router.get("/defaults")
|
||||
async def get_defaults(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
defaults = await service.get_default_models()
|
||||
return success(defaults)
|
||||
|
||||
|
||||
@router.put("/defaults")
|
||||
async def update_defaults(
|
||||
payload: AIDefaultsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
previous = await service.get_default_models()
|
||||
try:
|
||||
updated = await service.set_default_models(payload.as_mapping())
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
prev_dim = _get_embedding_dimension(previous.get("embedding"))
|
||||
next_dim = _get_embedding_dimension(updated.get("embedding"))
|
||||
|
||||
if prev_dim and next_dim and prev_dim != next_dim:
|
||||
try:
|
||||
await VectorDBService().clear_all_data()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
|
||||
|
||||
return success(updated)
|
||||
@@ -1,155 +0,0 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, HTTPException, Depends, Form
|
||||
import hashlib
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from services.auth import (
|
||||
authenticate_user_db,
|
||||
create_access_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
register_user,
|
||||
Token,
|
||||
get_current_active_user,
|
||||
User,
|
||||
request_password_reset,
|
||||
verify_password_reset_token,
|
||||
reset_password_with_token,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from datetime import timedelta
|
||||
from api.response import success
|
||||
from models.database import UserAccount
|
||||
from services.auth import verify_password, get_password_hash
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
|
||||
|
||||
@router.post("/register", summary="注册第一个管理员用户")
|
||||
async def register(data: RegisterRequest):
|
||||
"""
|
||||
仅当系统中没有用户时,才允许注册。
|
||||
"""
|
||||
user = await register_user(
|
||||
username=data.username,
|
||||
password=data.password,
|
||||
email=data.email,
|
||||
full_name=data.full_name,
|
||||
)
|
||||
return success({"username": user.username}, msg="初始用户注册成功")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login_for_access_token(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
user = await authenticate_user_db(form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = await create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
@router.get("/me", summary="获取当前登录用户信息")
|
||||
async def get_me(current_user: Annotated[User, Depends(get_current_active_user)]):
|
||||
"""
|
||||
返回当前登录用户的基本信息,并附带 gravatar 头像链接。
|
||||
"""
|
||||
email = (current_user.email or "").strip().lower()
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return success({
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"full_name": current_user.full_name,
|
||||
"gravatar_url": gravatar_url,
|
||||
})
|
||||
|
||||
|
||||
class UpdateMeRequest(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
old_password: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
token: str
|
||||
password: str
|
||||
|
||||
|
||||
@router.put("/me", summary="更新当前登录用户信息")
|
||||
async def update_me(
|
||||
payload: UpdateMeRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
db_user = await UserAccount.get_or_none(id=current_user.id)
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if payload.email is not None:
|
||||
exists = await UserAccount.filter(email=payload.email).exclude(id=db_user.id).exists()
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被占用")
|
||||
db_user.email = payload.email
|
||||
|
||||
if payload.full_name is not None:
|
||||
db_user.full_name = payload.full_name
|
||||
|
||||
if payload.new_password:
|
||||
if not payload.old_password:
|
||||
raise HTTPException(status_code=400, detail="请提供原密码")
|
||||
if not verify_password(payload.old_password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="原密码错误")
|
||||
db_user.hashed_password = get_password_hash(payload.new_password)
|
||||
|
||||
await db_user.save()
|
||||
|
||||
email = (db_user.email or "").strip().lower()
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return success({
|
||||
"id": db_user.id,
|
||||
"username": db_user.username,
|
||||
"email": db_user.email,
|
||||
"full_name": db_user.full_name,
|
||||
"gravatar_url": gravatar_url,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/password-reset/request", summary="请求密码重置邮件")
|
||||
async def password_reset_request_endpoint(payload: PasswordResetRequest):
|
||||
await request_password_reset(payload.email)
|
||||
return success(msg="如果邮箱存在,将发送重置邮件")
|
||||
|
||||
|
||||
@router.get("/password-reset/verify", summary="校验密码重置令牌")
|
||||
async def password_reset_verify(token: str):
|
||||
user = await verify_password_reset_token(token)
|
||||
return success({
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", summary="使用令牌重置密码")
|
||||
async def password_reset_confirm(payload: PasswordResetConfirm):
|
||||
await reset_password_with_token(payload.token, payload.password)
|
||||
return success(msg="密码已重置")
|
||||
@@ -1,50 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from services.auth import get_current_active_user
|
||||
from services.backup import BackupService
|
||||
from models.database import UserAccount
|
||||
import json
|
||||
import datetime
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/backup",
|
||||
tags=["Backup & Restore"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
)
|
||||
|
||||
@router.get("/export", summary="导出全站数据")
|
||||
async def export_backup():
|
||||
"""
|
||||
生成并下载一个包含所有关键数据的JSON文件。
|
||||
"""
|
||||
try:
|
||||
data = await BackupService.export_data()
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"
|
||||
}
|
||||
return JSONResponse(content=data, headers=headers)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/import", summary="导入数据")
|
||||
async def import_backup(file: UploadFile = File(...)):
|
||||
"""
|
||||
从上传的JSON文件恢复数据。
|
||||
**警告**: 这将会覆盖所有现有数据!
|
||||
"""
|
||||
|
||||
if not file.filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
|
||||
|
||||
try:
|
||||
contents = await file.read()
|
||||
data = json.loads(contents)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="无法解析JSON文件")
|
||||
|
||||
try:
|
||||
await BackupService.import_data(data)
|
||||
return {"message": "数据导入成功。"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"导入失败: {e}")
|
||||
@@ -1,83 +0,0 @@
|
||||
import httpx
|
||||
import time
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from typing import Annotated
|
||||
from services.config import ConfigCenter, VERSION
|
||||
from services.auth import get_current_active_user, User, has_users
|
||||
from api.response import success
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_config(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str
|
||||
):
|
||||
value = await ConfigCenter.get(key)
|
||||
return success({"key": key, "value": value})
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def set_config(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str = Form(...),
|
||||
value: str = Form(...)
|
||||
):
|
||||
await ConfigCenter.set(key, value)
|
||||
return success({"key": key, "value": value})
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
async def get_all_config(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
configs = await ConfigCenter.get_all()
|
||||
return success(configs)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_system_status():
|
||||
logo = await ConfigCenter.get("APP_LOGO", "/logo.svg")
|
||||
favicon = await ConfigCenter.get("APP_FAVICON", logo)
|
||||
system_info = {
|
||||
"version": VERSION,
|
||||
"title": await ConfigCenter.get("APP_NAME", "Foxel"),
|
||||
"logo": logo,
|
||||
"favicon": favicon,
|
||||
"is_initialized": await has_users(),
|
||||
"app_domain": await ConfigCenter.get("APP_DOMAIN"),
|
||||
"file_domain": await ConfigCenter.get("FILE_DOMAIN"),
|
||||
}
|
||||
return success(system_info)
|
||||
|
||||
|
||||
latest_version_cache = {
|
||||
"timestamp": 0,
|
||||
"data": None
|
||||
}
|
||||
|
||||
|
||||
@router.get("/latest-version")
|
||||
async def get_latest_version():
|
||||
current_time = time.time()
|
||||
if current_time - latest_version_cache["timestamp"] < 3600 and latest_version_cache["data"]:
|
||||
return success(latest_version_cache["data"])
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
version_info = {
|
||||
"latest_version": data.get("tag_name"),
|
||||
"body": data.get("body")
|
||||
}
|
||||
latest_version_cache["timestamp"] = current_time
|
||||
latest_version_cache["data"] = version_info
|
||||
return success(version_info)
|
||||
except httpx.RequestError as e:
|
||||
if latest_version_cache["data"]:
|
||||
return success(latest_version_cache["data"])
|
||||
return success({"latest_version": None, "body": None})
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from models.database import Log
|
||||
from api.response import page, success
|
||||
from tortoise.expressions import Q
|
||||
from datetime import datetime
|
||||
|
||||
router = APIRouter(prefix="/api/logs", tags=["Logs"])
|
||||
|
||||
@router.get("")
|
||||
async def get_logs(
|
||||
page_num: int = Query(1, alias="page"),
|
||||
page_size: int = Query(20, alias="page_size"),
|
||||
level: Optional[str] = Query(None),
|
||||
source: Optional[str] = Query(None),
|
||||
start_time: Optional[datetime] = Query(None),
|
||||
end_time: Optional[datetime] = Query(None),
|
||||
):
|
||||
"""获取日志列表,支持分页和筛选"""
|
||||
query = Log.all()
|
||||
if level:
|
||||
query = query.filter(level=level)
|
||||
if source:
|
||||
query = query.filter(source__icontains=source)
|
||||
if start_time:
|
||||
query = query.filter(timestamp__gte=start_time)
|
||||
if end_time:
|
||||
query = query.filter(timestamp__lte=end_time)
|
||||
|
||||
total = await query.count()
|
||||
logs = await query.order_by("-timestamp").offset((page_num - 1) * page_size).limit(page_size)
|
||||
|
||||
return success(page([log for log in logs], total, page_num, page_size))
|
||||
|
||||
@router.delete("")
|
||||
async def clear_logs(
|
||||
start_time: Optional[datetime] = Query(None),
|
||||
end_time: Optional[datetime] = Query(None),
|
||||
):
|
||||
"""清理指定时间范围内的日志"""
|
||||
query = Log.all()
|
||||
if start_time:
|
||||
query = query.filter(timestamp__gte=start_time)
|
||||
if end_time:
|
||||
query = query.filter(timestamp__lte=end_time)
|
||||
|
||||
deleted_count = await query.delete()
|
||||
return success({"deleted_count": deleted_count})
|
||||
@@ -1,79 +0,0 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.response import success
|
||||
from schemas.offline_downloads import OfflineDownloadCreate
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.logging import LogService
|
||||
from services.task_queue import task_queue_service, TaskProgress
|
||||
from services.virtual_fs import path_is_directory
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/offline-downloads",
|
||||
tags=["OfflineDownloads"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_offline_download(
|
||||
payload: OfflineDownloadCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
dest_dir = payload.dest_dir
|
||||
try:
|
||||
is_dir = await path_is_directory(dest_dir)
|
||||
except HTTPException:
|
||||
is_dir = False
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Destination directory not found")
|
||||
|
||||
task = await task_queue_service.add_task(
|
||||
"offline_http_download",
|
||||
{
|
||||
"url": str(payload.url),
|
||||
"dest_dir": dest_dir,
|
||||
"filename": payload.filename,
|
||||
},
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="queued",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="Waiting to start",
|
||||
),
|
||||
)
|
||||
|
||||
await LogService.action(
|
||||
"route:offline_downloads",
|
||||
f"Offline download task created {task.id}",
|
||||
details={"url": str(payload.url), "dest_dir": dest_dir, "filename": payload.filename},
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_offline_downloads(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
|
||||
data = [t.dict() for t in tasks]
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
async def get_offline_download(
|
||||
task_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task or task.name != "offline_http_download":
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return success(task.dict())
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import List, Any, Dict
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from models import database
|
||||
from schemas import PluginCreate, PluginOut
|
||||
|
||||
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
|
||||
|
||||
|
||||
@router.post("", response_model=PluginOut)
|
||||
async def create_plugin(payload: PluginCreate):
|
||||
rec = await database.Plugin.create(
|
||||
url=payload.url,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
|
||||
@router.get("", response_model=List[PluginOut])
|
||||
async def list_plugins():
|
||||
rows = await database.Plugin.all().order_by("-id")
|
||||
return [PluginOut.model_validate(r) for r in rows]
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}")
|
||||
async def delete_plugin(plugin_id: int):
|
||||
rec = await database.Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
await rec.delete()
|
||||
return {"code": 0, "msg": "ok"}
|
||||
|
||||
|
||||
@router.put("/{plugin_id}", response_model=PluginOut)
|
||||
async def update_plugin(plugin_id: int, payload: PluginCreate):
|
||||
rec = await database.Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
rec.url = payload.url
|
||||
rec.enabled = payload.enabled
|
||||
await rec.save()
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/metadata", response_model=PluginOut)
|
||||
async def update_manifest(plugin_id: int, manifest: Dict[str, Any] = Body(...)):
|
||||
rec = await database.Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
key_map = {
|
||||
'key': 'key',
|
||||
'name': 'name',
|
||||
'version': 'version',
|
||||
'supported_exts': 'supported_exts',
|
||||
'supportedExts': 'supported_exts',
|
||||
'default_bounds': 'default_bounds',
|
||||
'defaultBounds': 'default_bounds',
|
||||
'default_maximized': 'default_maximized',
|
||||
'defaultMaximized': 'default_maximized',
|
||||
'icon': 'icon',
|
||||
'description': 'description',
|
||||
'author': 'author',
|
||||
'website': 'website',
|
||||
'github': 'github',
|
||||
}
|
||||
for k, v in list(manifest.items()):
|
||||
if v is None:
|
||||
continue
|
||||
attr = key_map.get(k)
|
||||
if not attr:
|
||||
continue
|
||||
setattr(rec, attr, v)
|
||||
await rec.save()
|
||||
return PluginOut.model_validate(rec)
|
||||
@@ -1,250 +0,0 @@
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from typing import Annotated
|
||||
from services.processors.registry import (
|
||||
get,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
get_module_path,
|
||||
reload_processors,
|
||||
)
|
||||
from services.task_queue import task_queue_service
|
||||
from services.auth import get_current_active_user, User
|
||||
from api.response import success
|
||||
from pydantic import BaseModel
|
||||
from services.virtual_fs import path_is_directory, resolve_adapter_and_rel
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
router = APIRouter(prefix="/api/processors", tags=["processors"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_processors(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
schemas = get_config_schemas()
|
||||
out = []
|
||||
for t, meta in schemas.items():
|
||||
out.append({
|
||||
"type": meta["type"],
|
||||
"name": meta["name"],
|
||||
"supported_exts": meta.get("supported_exts", []),
|
||||
"config_schema": meta["config_schema"],
|
||||
"produces_file": meta.get("produces_file", False),
|
||||
"module_path": meta.get("module_path"),
|
||||
})
|
||||
return success(out)
|
||||
|
||||
|
||||
class ProcessRequest(BaseModel):
|
||||
path: str
|
||||
processor_type: str
|
||||
config: dict
|
||||
save_to: str | None = None
|
||||
overwrite: bool = False
|
||||
|
||||
|
||||
class ProcessDirectoryRequest(BaseModel):
|
||||
path: str
|
||||
processor_type: str
|
||||
config: dict
|
||||
overwrite: bool = True
|
||||
max_depth: Optional[int] = None
|
||||
suffix: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateSourceRequest(BaseModel):
|
||||
source: str
|
||||
|
||||
|
||||
@router.post("/process")
|
||||
async def process_file_with_processor(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessRequest = Body(...)
|
||||
):
|
||||
is_dir = await path_is_directory(req.path)
|
||||
if is_dir and not req.overwrite:
|
||||
raise HTTPException(400, detail="Directory processing requires overwrite")
|
||||
|
||||
save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
|
||||
task = await task_queue_service.add_task(
|
||||
"process_file",
|
||||
{
|
||||
"path": req.path,
|
||||
"processor_type": req.processor_type,
|
||||
"config": req.config,
|
||||
"save_to": save_to,
|
||||
"overwrite": req.overwrite,
|
||||
},
|
||||
)
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.post("/process-directory")
|
||||
async def process_directory_with_processor(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessDirectoryRequest = Body(...)
|
||||
):
|
||||
if req.max_depth is not None and req.max_depth < 0:
|
||||
raise HTTPException(400, detail="max_depth must be >= 0")
|
||||
|
||||
is_dir = await path_is_directory(req.path)
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Path must be a directory")
|
||||
|
||||
schema = get_config_schema(req.processor_type)
|
||||
_processor = get(req.processor_type)
|
||||
if not schema or not _processor:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
|
||||
produces_file = bool(schema.get("produces_file"))
|
||||
raw_suffix = req.suffix if req.suffix is not None else None
|
||||
if raw_suffix is not None and raw_suffix.strip() == "":
|
||||
raw_suffix = None
|
||||
suffix = raw_suffix
|
||||
overwrite = req.overwrite
|
||||
|
||||
if produces_file:
|
||||
if not overwrite and not suffix:
|
||||
raise HTTPException(400, detail="Suffix is required when not overwriting files")
|
||||
else:
|
||||
overwrite = False
|
||||
suffix = None
|
||||
|
||||
supported_exts = schema.get("supported_exts") or []
|
||||
allowed_exts = {
|
||||
ext.lower().lstrip('.')
|
||||
for ext in supported_exts
|
||||
if isinstance(ext, str)
|
||||
}
|
||||
|
||||
def matches_extension(file_rel: str) -> bool:
|
||||
if not allowed_exts:
|
||||
return True
|
||||
if '.' not in file_rel:
|
||||
return '' in allowed_exts
|
||||
ext = file_rel.rsplit('.', 1)[-1].lower()
|
||||
return ext in allowed_exts or f'.{ext}' in allowed_exts
|
||||
|
||||
adapter_instance, adapter_model, root, rel = await resolve_adapter_and_rel(req.path)
|
||||
rel = rel.rstrip('/')
|
||||
|
||||
list_dir = getattr(adapter_instance, "list_dir", None)
|
||||
if not callable(list_dir):
|
||||
raise HTTPException(501, detail="Adapter does not implement list_dir")
|
||||
|
||||
def build_absolute_path(mount_path: str, rel_path: str) -> str:
|
||||
rel_norm = rel_path.lstrip('/')
|
||||
mount_norm = mount_path.rstrip('/')
|
||||
if not mount_norm:
|
||||
return '/' + rel_norm if rel_norm else '/'
|
||||
return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm
|
||||
|
||||
def apply_suffix(path_str: str, suffix_str: str) -> str:
|
||||
path_obj = Path(path_str)
|
||||
name = path_obj.name
|
||||
if not name:
|
||||
return path_str
|
||||
if '.' in name:
|
||||
base, ext = name.rsplit('.', 1)
|
||||
new_name = f"{base}{suffix_str}.{ext}"
|
||||
else:
|
||||
new_name = f"{name}{suffix_str}"
|
||||
return str(path_obj.with_name(new_name))
|
||||
|
||||
scheduled_tasks: List[str] = []
|
||||
stack: List[Tuple[str, int]] = [(rel, 0)]
|
||||
page_size = 200
|
||||
|
||||
while stack:
|
||||
current_rel, depth = stack.pop()
|
||||
page = 1
|
||||
while True:
|
||||
entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
|
||||
entries = entries or []
|
||||
if not entries and (total or 0) == 0:
|
||||
break
|
||||
|
||||
for entry in entries:
|
||||
name = entry.get("name")
|
||||
if not name:
|
||||
continue
|
||||
child_rel = f"{current_rel}/{name}" if current_rel else name
|
||||
if entry.get("is_dir"):
|
||||
if req.max_depth is None or depth < req.max_depth:
|
||||
stack.append((child_rel.rstrip('/'), depth + 1))
|
||||
continue
|
||||
if not matches_extension(child_rel):
|
||||
continue
|
||||
absolute_path = build_absolute_path(adapter_model.path, child_rel)
|
||||
save_to = None
|
||||
if produces_file and not overwrite and suffix:
|
||||
save_to = apply_suffix(absolute_path, suffix)
|
||||
task = await task_queue_service.add_task(
|
||||
"process_file",
|
||||
{
|
||||
"path": absolute_path,
|
||||
"processor_type": req.processor_type,
|
||||
"config": req.config,
|
||||
"save_to": save_to,
|
||||
"overwrite": overwrite,
|
||||
},
|
||||
)
|
||||
scheduled_tasks.append(task.id)
|
||||
|
||||
if total is None or page * page_size >= total:
|
||||
break
|
||||
page += 1
|
||||
|
||||
return success({
|
||||
"task_ids": scheduled_tasks,
|
||||
"scheduled": len(scheduled_tasks),
|
||||
})
|
||||
|
||||
|
||||
@router.get("/source/{processor_type}")
|
||||
async def get_processor_source(
|
||||
processor_type: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
module_path = get_module_path(processor_type)
|
||||
if not module_path:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
path_obj = Path(module_path)
|
||||
if not path_obj.exists():
|
||||
raise HTTPException(404, detail="Processor source not found")
|
||||
try:
|
||||
content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"Failed to read source: {exc}")
|
||||
return success({"source": content, "module_path": str(path_obj)})
|
||||
|
||||
|
||||
@router.put("/source/{processor_type}")
|
||||
async def update_processor_source(
|
||||
processor_type: str,
|
||||
req: UpdateSourceRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
module_path = get_module_path(processor_type)
|
||||
if not module_path:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
path_obj = Path(module_path)
|
||||
if not path_obj.exists():
|
||||
raise HTTPException(404, detail="Processor source not found")
|
||||
try:
|
||||
await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"Failed to write source: {exc}")
|
||||
return success(True)
|
||||
|
||||
|
||||
@router.post("/reload")
|
||||
async def reload_processor_modules(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
errors = reload_processors()
|
||||
if errors:
|
||||
raise HTTPException(500, detail="; ".join(errors))
|
||||
return success(True)
|
||||
@@ -1,217 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from urllib.parse import quote
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.response import success
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.share import share_service
|
||||
from services.virtual_fs import stream_file, stat_file
|
||||
from models.database import ShareLink, UserAccount
|
||||
|
||||
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
|
||||
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
|
||||
|
||||
class ShareCreate(BaseModel):
|
||||
name: str
|
||||
paths: List[str]
|
||||
expires_in_days: Optional[int] = 7
|
||||
access_type: str = "public"
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
class ShareInfo(BaseModel):
|
||||
id: int
|
||||
token: str
|
||||
name: str
|
||||
paths: List[str]
|
||||
created_at: str
|
||||
expires_at: Optional[str] = None
|
||||
access_type: str
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: ShareLink):
|
||||
return cls(
|
||||
id=obj.id,
|
||||
token=obj.token,
|
||||
name=obj.name,
|
||||
paths=obj.paths,
|
||||
created_at=obj.created_at.isoformat(),
|
||||
expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
|
||||
access_type=obj.access_type,
|
||||
)
|
||||
|
||||
|
||||
class ShareInfoWithPassword(ShareInfo):
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
# --- Management Routes ---
|
||||
|
||||
@router.post("", response_model=ShareInfoWithPassword)
|
||||
async def create_share(
|
||||
payload: ShareCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
创建一个新的分享链接。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
share = await share_service.create_share_link(
|
||||
user=user_account,
|
||||
name=payload.name,
|
||||
paths=payload.paths,
|
||||
expires_in_days=payload.expires_in_days,
|
||||
access_type=payload.access_type,
|
||||
password=payload.password,
|
||||
)
|
||||
share_info_base = ShareInfo.from_orm(share)
|
||||
response_data = share_info_base.model_dump()
|
||||
if payload.access_type == "password" and payload.password:
|
||||
response_data['password'] = payload.password
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
@router.get("", response_model=List[ShareInfo])
|
||||
async def get_my_shares(current_user: User = Depends(get_current_active_user)):
|
||||
"""
|
||||
获取当前用户的所有分享链接。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
shares = await share_service.get_user_shares(user=user_account)
|
||||
return [ShareInfo.from_orm(s) for s in shares]
|
||||
|
||||
|
||||
@router.delete("/expired")
|
||||
async def delete_expired_shares(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
删除当前用户的所有已过期分享。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
deleted_count = await share_service.delete_expired_shares(user=user_account)
|
||||
return success({"deleted_count": deleted_count})
|
||||
|
||||
|
||||
@router.delete("/{share_id}")
|
||||
async def delete_share(
|
||||
share_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
删除一个分享链接。
|
||||
"""
|
||||
await share_service.delete_share_link(user=current_user, share_id=share_id)
|
||||
return success(msg="分享已取消")
|
||||
|
||||
|
||||
# --- Public Routes ---
|
||||
|
||||
class SharePassword(BaseModel):
|
||||
password: str
|
||||
|
||||
@public_router.post("/{token}/verify")
|
||||
async def verify_password(token: str, payload: SharePassword):
|
||||
"""
|
||||
验证分享链接的密码。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
if share.access_type != "password":
|
||||
raise HTTPException(status_code=400, detail="此分享不需要密码")
|
||||
|
||||
if not share_service._verify_password(payload.password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
# 在这里可以考虑返回一个有时效性的token用于后续访问,但为了简单起见,
|
||||
# 我们让前端在每次请求时都带上密码或一个会话标识。
|
||||
# 简单起见,我们只返回成功状态。
|
||||
return success(msg="验证成功")
|
||||
|
||||
|
||||
@public_router.get("/{token}/ls")
|
||||
async def list_share_content(token: str, path: str = "/", password: Optional[str] = None):
|
||||
"""
|
||||
列出分享链接中的文件和目录。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
|
||||
if share.access_type == "password":
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share_service._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
content = await share_service.get_shared_item_details(share, path)
|
||||
return success({
|
||||
"path": path,
|
||||
"entries": content.get("items", []),
|
||||
"pagination": {
|
||||
"total": content.get("total", 0),
|
||||
"page": content.get("page", 1),
|
||||
"page_size": content.get("page_size", 1),
|
||||
"pages": content.get("pages", 1),
|
||||
}
|
||||
})
|
||||
|
||||
@public_router.get("/{token}")
|
||||
async def get_share_info(token: str):
|
||||
"""
|
||||
获取分享链接的元数据信息。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
return success(ShareInfo.from_orm(share))
|
||||
|
||||
|
||||
|
||||
@public_router.get("/{token}/download")
|
||||
async def download_shared_file(token: str, path: str, request: Request, password: Optional[str] = None):
|
||||
"""
|
||||
下载分享链接中的单个文件。
|
||||
"""
|
||||
if not path or path == "/" or ".." in path.split('/'):
|
||||
raise HTTPException(status_code=400, detail="无效的文件路径")
|
||||
|
||||
share = await share_service.get_share_by_token(token)
|
||||
if share.access_type == "password":
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share_service._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
base_shared_path = share.paths[0]
|
||||
|
||||
# 判断分享的是文件还是目录
|
||||
is_dir = False
|
||||
try:
|
||||
stat = await stat_file(base_shared_path)
|
||||
if stat and stat.get("is_dir"):
|
||||
is_dir = True
|
||||
except HTTPException as e:
|
||||
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
|
||||
is_dir = True
|
||||
else:
|
||||
# The shared path itself doesn't exist, which is an issue.
|
||||
raise HTTPException(status_code=404, detail="分享的源文件不存在")
|
||||
|
||||
if is_dir:
|
||||
# 目录分享:拼接路径
|
||||
full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
|
||||
if not full_virtual_path.startswith(base_shared_path):
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
else:
|
||||
# 文件分享:路径应为分享的根路径
|
||||
shared_filename = base_shared_path.split('/')[-1]
|
||||
request_filename = path.lstrip('/')
|
||||
if shared_filename != request_filename:
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
full_virtual_path = base_shared_path
|
||||
|
||||
range_header = request.headers.get("Range")
|
||||
response = await stream_file(full_virtual_path, range_header)
|
||||
|
||||
# 设置 Content-Disposition 头来强制下载
|
||||
filename = full_virtual_path.split('/')[-1]
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
return response
|
||||
@@ -1,141 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Annotated
|
||||
|
||||
from models.database import AutomationTask
|
||||
from schemas.tasks import (
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
TaskQueueSettingsResponse,
|
||||
)
|
||||
from api.response import success
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.logging import LogService
|
||||
from services.task_queue import task_queue_service
|
||||
from services.config import ConfigCenter
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/tasks",
|
||||
tags=["Tasks"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/queue")
|
||||
async def get_task_queue_status(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
tasks = task_queue_service.get_all_tasks()
|
||||
return success([task.dict() for task in tasks])
|
||||
|
||||
|
||||
@router.get("/queue/settings")
|
||||
async def get_task_queue_settings(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
payload = TaskQueueSettingsResponse(
|
||||
concurrency=task_queue_service.get_concurrency(),
|
||||
active_workers=task_queue_service.get_active_worker_count(),
|
||||
)
|
||||
return success(payload.model_dump())
|
||||
|
||||
|
||||
@router.post("/queue/settings")
|
||||
async def update_task_queue_settings(
|
||||
settings: TaskQueueSettings,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await task_queue_service.set_concurrency(settings.concurrency)
|
||||
await ConfigCenter.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
"Updated task queue settings",
|
||||
details={"concurrency": settings.concurrency},
|
||||
user_id=getattr(current_user, "id", None),
|
||||
)
|
||||
payload = TaskQueueSettingsResponse(
|
||||
concurrency=task_queue_service.get_concurrency(),
|
||||
active_workers=task_queue_service.get_active_worker_count(),
|
||||
)
|
||||
return success(payload.model_dump())
|
||||
|
||||
|
||||
@router.get("/queue/{task_id}")
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return success(task.dict())
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_task(
|
||||
task_in: AutomationTaskCreate,
|
||||
user: User = Depends(get_current_active_user)
|
||||
):
|
||||
task = await AutomationTask.create(**task_in.model_dump())
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
f"Created task {task.name}",
|
||||
details=task_in.model_dump(),
|
||||
user_id=user.id if hasattr(user, "id") else None,
|
||||
)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
async def get_task(task_id: int):
|
||||
task = await AutomationTask.get_or_none(id=task_id)
|
||||
if not task:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Task {task_id} not found")
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_tasks():
|
||||
tasks = await AutomationTask.all()
|
||||
return success(tasks)
|
||||
|
||||
|
||||
@router.put("/{task_id}")
|
||||
async def update_task(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
task_id: int, task_in: AutomationTaskUpdate):
|
||||
task = await AutomationTask.get_or_none(id=task_id)
|
||||
if not task:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Task {task_id} not found")
|
||||
update_data = task_in.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(task, key, value)
|
||||
await task.save()
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
f"Updated task {task.name}",
|
||||
details=task_in.model_dump(),
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.delete("/{task_id}")
|
||||
async def delete_task(
|
||||
task_id: int,
|
||||
user: User = Depends(get_current_active_user)
|
||||
):
|
||||
deleted_count = await AutomationTask.filter(id=task_id).delete()
|
||||
if not deleted_count:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Task {task_id} not found")
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
f"Deleted task {task_id}",
|
||||
details={"task_id": task_id},
|
||||
user_id=user.id if hasattr(user, "id") else None,
|
||||
)
|
||||
return success(msg="Task deleted")
|
||||
@@ -1,91 +0,0 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.auth import get_current_active_user
|
||||
from models.database import UserAccount
|
||||
from services.vector_db import (
|
||||
VectorDBService,
|
||||
VectorDBConfigManager,
|
||||
list_providers,
|
||||
get_provider_entry,
|
||||
)
|
||||
from services.vector_db.providers import get_provider_class
|
||||
from api.response import success
|
||||
|
||||
router = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
||||
|
||||
|
||||
class VectorDBConfigPayload(BaseModel):
|
||||
type: str = Field(..., description="向量数据库提供者类型")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
|
||||
|
||||
|
||||
@router.post("/clear-all", summary="清空向量数据库")
|
||||
async def clear_vector_db(user: UserAccount = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
return success(msg="向量数据库已清空")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stats", summary="获取向量数据库统计")
|
||||
async def get_vector_db_stats(user: UserAccount = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
data = await service.get_all_stats()
|
||||
return success(data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/providers", summary="列出可用向量数据库提供者")
|
||||
async def list_vector_providers(user: UserAccount = Depends(get_current_active_user)):
|
||||
return success(list_providers())
|
||||
|
||||
|
||||
@router.get("/config", summary="获取当前向量数据库配置")
|
||||
async def get_vector_db_config(user: UserAccount = Depends(get_current_active_user)):
|
||||
service = VectorDBService()
|
||||
data = await service.current_provider()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/config", summary="更新向量数据库配置")
|
||||
async def update_vector_db_config(payload: VectorDBConfigPayload, user: UserAccount = Depends(get_current_active_user)):
|
||||
entry = get_provider_entry(payload.type)
|
||||
if not entry:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
||||
|
||||
provider_cls = get_provider_class(payload.type)
|
||||
if not provider_cls:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
||||
|
||||
# 先尝试建立连接,确保配置有效
|
||||
test_provider = provider_cls(payload.config)
|
||||
try:
|
||||
await test_provider.initialize()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
client = getattr(test_provider, "client", None)
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
||||
service = VectorDBService()
|
||||
await service.reload()
|
||||
config_data = await service.current_provider()
|
||||
stats = await service.get_all_stats()
|
||||
return success({"config": config_data, "stats": stats})
|
||||
@@ -1,376 +0,0 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Response, Query, Request, Depends
|
||||
import mimetypes
|
||||
import re
|
||||
from typing import Annotated
|
||||
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.virtual_fs import (
|
||||
list_virtual_dir,
|
||||
read_file,
|
||||
write_file,
|
||||
make_dir,
|
||||
delete_path,
|
||||
move_path,
|
||||
resolve_adapter_and_rel,
|
||||
stream_file,
|
||||
generate_temp_link_token,
|
||||
verify_temp_link_token,
|
||||
maybe_redirect_download,
|
||||
)
|
||||
from services.thumbnail import is_image_filename, get_or_create_thumb, is_raw_filename, is_video_filename
|
||||
from schemas import MkdirRequest, MoveRequest
|
||||
from api.response import success
|
||||
from services.config import ConfigCenter
|
||||
|
||||
router = APIRouter(prefix='/api/fs', tags=["virtual-fs"])
|
||||
|
||||
|
||||
@router.get("/file/{full_path:path}")
|
||||
async def get_file(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
|
||||
if is_raw_filename(full_path):
|
||||
import rawpy
|
||||
from PIL import Image
|
||||
import io
|
||||
try:
|
||||
raw_data = await read_file(full_path)
|
||||
with rawpy.imread(io.BytesIO(raw_data)) as raw:
|
||||
rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
|
||||
im = Image.fromarray(rgb)
|
||||
buf = io.BytesIO()
|
||||
im.save(buf, 'JPEG', quality=90)
|
||||
content = buf.getvalue()
|
||||
return Response(content=content, media_type='image/jpeg')
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"RAW file processing failed: {e}")
|
||||
|
||||
adapter_instance, adapter_model, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
|
||||
redirect_response = await maybe_redirect_download(adapter_instance, adapter_model, root, rel)
|
||||
if redirect_response is not None:
|
||||
return redirect_response
|
||||
|
||||
try:
|
||||
content = await read_file(full_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
|
||||
if not isinstance(content, (bytes, bytearray)):
|
||||
return Response(content=content, media_type="application/octet-stream")
|
||||
|
||||
content_length = len(content)
|
||||
content_type = mimetypes.guess_type(
|
||||
full_path)[0] or "application/octet-stream"
|
||||
|
||||
range_header = request.headers.get('Range')
|
||||
if range_header:
|
||||
range_match = re.match(r'bytes=(\d+)-(\d*)', range_header)
|
||||
if range_match:
|
||||
start = int(range_match.group(1))
|
||||
end = int(range_match.group(2)) if range_match.group(
|
||||
2) else content_length - 1
|
||||
|
||||
start = max(0, min(start, content_length - 1))
|
||||
end = max(start, min(end, content_length - 1))
|
||||
|
||||
chunk = content[start:end + 1]
|
||||
chunk_size = len(chunk)
|
||||
|
||||
headers = {
|
||||
'Content-Range': f'bytes {start}-{end}/{content_length}',
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Content-Length': str(chunk_size),
|
||||
'Content-Type': content_type,
|
||||
}
|
||||
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
headers = {
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Content-Length': str(content_length),
|
||||
'Content-Type': content_type,
|
||||
}
|
||||
|
||||
if content_type.startswith('video/'):
|
||||
headers['Cache-Control'] = 'public, max-age=3600'
|
||||
|
||||
return Response(content=content, headers=headers)
|
||||
|
||||
|
||||
@router.get("/thumb/{full_path:path}")
|
||||
async def get_thumb(
|
||||
full_path: str,
|
||||
w: int = Query(256, ge=8, le=1024),
|
||||
h: int = Query(256, ge=8, le=1024),
|
||||
fit: str = Query("cover"),
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
if fit not in ("cover", "contain"):
|
||||
raise HTTPException(400, detail="fit must be cover|contain")
|
||||
adapter, mount, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
if not rel or rel.endswith('/'):
|
||||
raise HTTPException(400, detail="Not a file")
|
||||
if not (is_image_filename(rel) or is_video_filename(rel)):
|
||||
raise HTTPException(404, detail="Not an image or video")
|
||||
# type: ignore
|
||||
data, mime, key = await get_or_create_thumb(adapter, mount.id, root, rel, w, h, fit)
|
||||
headers = {
|
||||
'Cache-Control': 'public, max-age=3600',
|
||||
'ETag': key,
|
||||
}
|
||||
return Response(content=data, media_type=mime, headers=headers)
|
||||
|
||||
|
||||
@router.get("/stream/{full_path:path}")
|
||||
async def stream_endpoint(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
):
|
||||
"""支持 Range 的视频/大文件流式读取,优先使用底层适配器 Range 能力。"""
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
range_header = request.headers.get('Range')
|
||||
try:
|
||||
return await stream_file(full_path, range_header)
|
||||
except HTTPException:
|
||||
raise
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"Stream error: {e}")
|
||||
|
||||
|
||||
@router.get("/temp-link/{full_path:path}")
|
||||
async def get_temp_link(
|
||||
full_path: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久")
|
||||
):
|
||||
"""获取文件的临时公开访问令牌"""
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
token = await generate_temp_link_token(full_path, expires_in=expires_in)
|
||||
file_domain = await ConfigCenter.get("FILE_DOMAIN")
|
||||
if file_domain:
|
||||
file_domain = file_domain.rstrip('/')
|
||||
url = f"{file_domain}/api/fs/public/{token}"
|
||||
else:
|
||||
url = f"/api/fs/public/{token}"
|
||||
return success({"token": token, "path": full_path, "url": url})
|
||||
|
||||
|
||||
@router.get("/public/{token}")
|
||||
async def access_public_file(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""通过令牌公开访问文件,支持 Range 请求"""
|
||||
try:
|
||||
path = await verify_temp_link_token(token)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
range_header = request.headers.get('Range')
|
||||
try:
|
||||
return await stream_file(path, range_header)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found via token")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"File access error: {e}")
|
||||
|
||||
|
||||
@router.get("/stat/{full_path:path}")
|
||||
async def get_file_stat(
|
||||
full_path: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
from services.virtual_fs import stat_file
|
||||
stat = await stat_file(full_path)
|
||||
return success(stat)
|
||||
|
||||
|
||||
@router.post("/file/{full_path:path}")
|
||||
async def put_file(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
data = await file.read()
|
||||
await write_file(full_path, data)
|
||||
return success({"written": True, "path": full_path, "size": len(data)})
|
||||
|
||||
|
||||
@router.post("/mkdir")
|
||||
async def api_mkdir(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MkdirRequest
|
||||
):
|
||||
path = body.path if body.path.startswith('/') else '/' + body.path
|
||||
if not path or path == '/':
|
||||
raise HTTPException(400, detail="Invalid path")
|
||||
await make_dir(path)
|
||||
return success({"created": True, "path": path})
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
async def api_move(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
):
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
debug_info = await move_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
|
||||
queued = bool(debug_info.get("queued"))
|
||||
response = {
|
||||
"moved": not queued,
|
||||
"queued": queued,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
}
|
||||
if queued:
|
||||
response["task_id"] = debug_info.get("task_id")
|
||||
response["task_name"] = debug_info.get("task_name")
|
||||
return success(response)
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
async def api_rename(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标")
|
||||
):
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
from services.virtual_fs import rename_path
|
||||
await rename_path(src, dst, overwrite=overwrite, return_debug=False)
|
||||
return success({
|
||||
"renamed": True,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/copy")
|
||||
async def api_copy(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
|
||||
):
|
||||
from services.virtual_fs import copy_path
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
debug_info = await copy_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
|
||||
queued = bool(debug_info.get("queued"))
|
||||
response = {
|
||||
"copied": not queued,
|
||||
"queued": queued,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
}
|
||||
if queued:
|
||||
response["task_id"] = debug_info.get("task_id")
|
||||
response["task_name"] = debug_info.get("task_name")
|
||||
return success(response)
|
||||
|
||||
|
||||
@router.post("/upload/{full_path:path}")
|
||||
async def upload_stream(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
|
||||
chunk_size: int = Query(1024 * 1024, ge=8 * 1024,
|
||||
le=8 * 1024 * 1024, description="单次读取块大小")
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
if full_path.endswith('/'):
|
||||
raise HTTPException(400, detail="Path must be a file")
|
||||
from services.virtual_fs import write_file_stream, resolve_adapter_and_rel
|
||||
adapter, _m, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
exists_func = getattr(adapter, "exists", None)
|
||||
if not overwrite and callable(exists_func):
|
||||
try:
|
||||
if await exists_func(root, rel):
|
||||
raise HTTPException(409, detail="Destination exists")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def gen():
|
||||
while True:
|
||||
chunk = await file.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
size = await write_file_stream(full_path, gen(), overwrite=overwrite)
|
||||
return success({"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite})
|
||||
|
||||
|
||||
@router.get("/{full_path:path}")
|
||||
async def browse_fs(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc")
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
result = await list_virtual_dir(full_path, page_num, page_size, sort_by, sort_order)
|
||||
return success({
|
||||
"path": full_path,
|
||||
"entries": result["items"],
|
||||
"pagination": {
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
"pages": result["pages"]
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@router.delete("/{full_path:path}")
|
||||
async def api_delete(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
await delete_path(full_path)
|
||||
return success({"deleted": True, "path": full_path})
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root_listing(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc")
|
||||
):
|
||||
result = await list_virtual_dir("/", page_num, page_size, sort_by, sort_order)
|
||||
return success({
|
||||
"path": "/",
|
||||
"entries": result["items"],
|
||||
"pagination": {
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
"pages": result["pages"]
|
||||
}
|
||||
})
|
||||
@@ -1,6 +1,6 @@
|
||||
from tortoise import Tortoise
|
||||
|
||||
from services.adapters.registry import runtime_registry
|
||||
from domain.adapters.registry import runtime_registry
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {"default": "sqlite://data/db/db.sqlite3"},
|
||||
|
||||
1
domain/adapters/__init__.py
Normal file
1
domain/adapters/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
85
domain/adapters/api.py
Normal file
85
domain/adapters/api.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.adapters.service import AdapterService
|
||||
from domain.adapters.types import AdapterCreate
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
|
||||
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
async def create_adapter(
|
||||
request: Request,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.create_adapter(data, current_user)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.get("")
|
||||
@audit(action=AuditAction.READ, description="获取适配器列表")
|
||||
async def list_adapters(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapters = await AdapterService.list_adapters()
|
||||
return success(adapters)
|
||||
|
||||
|
||||
@router.get("/available")
|
||||
@audit(action=AuditAction.READ, description="获取可用适配器类型")
|
||||
async def available_adapter_types(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
data = await AdapterService.available_adapter_types()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{adapter_id}")
|
||||
@audit(action=AuditAction.READ, description="获取适配器详情")
|
||||
async def get_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.get_adapter(adapter_id)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.put("/{adapter_id}")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
async def update_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.update_adapter(adapter_id, data, current_user)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.delete("/{adapter_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除存储适配器")
|
||||
async def delete_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
result = await AdapterService.delete_adapter(adapter_id, current_user)
|
||||
return success(result)
|
||||
3
domain/adapters/providers/__init__.py
Normal file
3
domain/adapters/providers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAdapter
|
||||
|
||||
__all__ = ["BaseAdapter"]
|
||||
@@ -10,7 +10,6 @@ from ftplib import FTP, error_perm
|
||||
import mimetypes
|
||||
|
||||
from models import StorageAdapter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
def _join_remote(root: str, rel: str) -> str:
|
||||
@@ -240,11 +239,6 @@ class FTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
await LogService.info(
|
||||
"adapter:ftp",
|
||||
f"Wrote file to {rel}",
|
||||
details={"adapter_id": self.record.id, "path": path, "size": len(data)},
|
||||
)
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
# KISS: 聚合后一次性写入
|
||||
@@ -276,7 +270,6 @@ class FTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_mkdir)
|
||||
await LogService.info("adapter:ftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
@@ -340,7 +333,6 @@ class FTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_delete)
|
||||
await LogService.info("adapter:ftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_remote(root, src_rel)
|
||||
@@ -367,7 +359,6 @@ class FTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_move)
|
||||
await LogService.info("adapter:ftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
@@ -402,10 +393,6 @@ class FTPAdapter:
|
||||
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
|
||||
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
|
||||
await self.copy(root, child_src, child_dst, overwrite)
|
||||
await LogService.info(
|
||||
"adapter:ftp", f"Copied directory {src_rel} to {dst_rel}",
|
||||
details={"adapter_id": self.record.id, "src": src, "dst": dst}
|
||||
)
|
||||
return
|
||||
|
||||
# file
|
||||
@@ -418,7 +405,6 @@ class FTPAdapter:
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
await self.write_file(root, dst_rel, data)
|
||||
await LogService.info("adapter:ftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
@@ -10,7 +10,6 @@ import mimetypes
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from models import StorageAdapter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
def _safe_join(root: str, rel: str) -> Path:
|
||||
@@ -115,11 +114,6 @@ class LocalAdapter:
|
||||
await asyncio.to_thread(fp.write_bytes, data)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Wrote file to {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp), "size": len(data)},
|
||||
)
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
fp = _safe_join(root, rel)
|
||||
@@ -140,21 +134,11 @@ class LocalAdapter:
|
||||
await asyncio.to_thread(f.close)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Wrote file stream to {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp), "size": size},
|
||||
)
|
||||
return size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
await asyncio.to_thread(os.makedirs, fp, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Created directory {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp)},
|
||||
)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
@@ -164,11 +148,6 @@ class LocalAdapter:
|
||||
await asyncio.to_thread(shutil.rmtree, fp)
|
||||
else:
|
||||
await asyncio.to_thread(fp.unlink)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Deleted {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp)},
|
||||
)
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
"""新增: 返回路径状态调试信息"""
|
||||
@@ -203,15 +182,6 @@ class LocalAdapter:
|
||||
except OSError:
|
||||
shutil.move(str(src), str(dst))
|
||||
await asyncio.to_thread(_do_move)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Moved {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src": str(src),
|
||||
"dst": str(dst),
|
||||
},
|
||||
)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _safe_join(root, src_rel)
|
||||
@@ -227,15 +197,6 @@ class LocalAdapter:
|
||||
except OSError:
|
||||
os.replace(src, dst)
|
||||
await asyncio.to_thread(_do_rename)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Renamed {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src": str(src),
|
||||
"dst": str(dst),
|
||||
},
|
||||
)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _safe_join(root, src_rel)
|
||||
@@ -258,15 +219,6 @@ class LocalAdapter:
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
await asyncio.to_thread(_do)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Copied {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src": str(src),
|
||||
"dst": str(dst),
|
||||
},
|
||||
)
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
fp = _safe_join(root, rel)
|
||||
@@ -452,7 +452,7 @@ CONFIG_SCHEMA = [
|
||||
{"key": "client_secret", "label": "Client Secret",
|
||||
"type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
|
||||
"required": True, "help_text": "可以通过运行 'python -m services.adapters.onedrive' 获取"},
|
||||
"required": True, "help_text": "可以通过运行 'python -m domain.adapters.providers.onedrive' 获取"},
|
||||
{"key": "root", "label": "根目录 (Root Path)", "type": "string",
|
||||
"required": False, "placeholder": "默认为根目录 /"},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
@@ -10,7 +10,6 @@ from botocore.exceptions import ClientError
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from models import StorageAdapter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
class S3Adapter:
|
||||
@@ -127,11 +126,6 @@ class S3Adapter:
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Wrote file to {rel}",
|
||||
details={"adapter_id": self.record.id,
|
||||
"bucket": self.bucket_name, "key": key, "size": len(data)}
|
||||
)
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
key = self._get_s3_key(rel)
|
||||
@@ -193,10 +187,6 @@ class S3Adapter:
|
||||
)
|
||||
raise IOError(f"S3 stream upload failed: {e}") from e
|
||||
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Wrote file stream to {rel}",
|
||||
details={"adapter_id": self.record.id, "bucket": self.bucket_name, "key": key, "size": total_size}
|
||||
)
|
||||
return total_size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
@@ -205,11 +195,6 @@ class S3Adapter:
|
||||
key += "/"
|
||||
async with self._get_client() as s3:
|
||||
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"")
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Created directory {rel}",
|
||||
details={"adapter_id": self.record.id,
|
||||
"bucket": self.bucket_name, "key": key}
|
||||
)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
@@ -237,20 +222,9 @@ class S3Adapter:
|
||||
else:
|
||||
await s3.delete_object(Bucket=self.bucket_name, Key=key)
|
||||
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Deleted {rel}",
|
||||
details={"adapter_id": self.record.id,
|
||||
"bucket": self.bucket_name, "key": key}
|
||||
)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.copy(root, src_rel, dst_rel, overwrite=True)
|
||||
await self.delete(root, src_rel)
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Moved {src_rel} to {dst_rel}",
|
||||
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
|
||||
"src_key": self._get_s3_key(src_rel), "dst_key": self._get_s3_key(dst_rel)}
|
||||
)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
@@ -270,11 +244,6 @@ class S3Adapter:
|
||||
|
||||
copy_source = {"Bucket": self.bucket_name, "Key": src_key}
|
||||
await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key)
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Copied {src_rel} to {dst_rel}",
|
||||
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
|
||||
"src_key": src_key, "dst_key": dst_key}
|
||||
)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
@@ -353,8 +322,7 @@ class S3Adapter:
|
||||
while chunk := await body.read(65536):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
LogService.error(
|
||||
"adapter:s3", f"Error streaming file {key}: {e}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
@@ -10,7 +10,6 @@ from fastapi.responses import StreamingResponse
|
||||
import paramiko
|
||||
|
||||
from models import StorageAdapter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
def _join_remote(root: str, rel: str) -> str:
|
||||
@@ -159,7 +158,6 @@ class SFTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
await LogService.info("adapter:sftp", f"Wrote file to {rel}", details={"adapter_id": self.record.id, "path": path, "size": len(data)})
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
buf = bytearray()
|
||||
@@ -190,7 +188,6 @@ class SFTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_mkdir)
|
||||
await LogService.info("adapter:sftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
@@ -228,7 +225,6 @@ class SFTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_delete)
|
||||
await LogService.info("adapter:sftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_remote(root, src_rel)
|
||||
@@ -255,7 +251,6 @@ class SFTPAdapter:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_move)
|
||||
await LogService.info("adapter:sftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
@@ -283,7 +278,6 @@ class SFTPAdapter:
|
||||
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
|
||||
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
|
||||
await self.copy(root, child_src, child_dst, overwrite)
|
||||
await LogService.info("adapter:sftp", f"Copied directory {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
|
||||
return
|
||||
|
||||
# file copy
|
||||
@@ -295,7 +289,6 @@ class SFTPAdapter:
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
await self.write_file(root, dst_rel, data)
|
||||
await LogService.info("adapter:sftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
@@ -9,7 +9,6 @@ import mimetypes
|
||||
import logging
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from services.logging import LogService
|
||||
|
||||
NS = {"d": "DAV:"}
|
||||
|
||||
@@ -148,15 +147,6 @@ class WebDAVAdapter:
|
||||
async with self._client() as client:
|
||||
resp = await client.put(url, content=data)
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Wrote file to {rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"url": url,
|
||||
"size": len(data),
|
||||
},
|
||||
)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
url = self._build_url(rel.rstrip('/') + '/')
|
||||
@@ -164,11 +154,6 @@ class WebDAVAdapter:
|
||||
resp = await client.request("MKCOL", url)
|
||||
if resp.status_code not in (201, 405):
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Created directory {rel}",
|
||||
details={"adapter_id": self.record.id, "url": url},
|
||||
)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
url = self._build_url(rel)
|
||||
@@ -176,11 +161,6 @@ class WebDAVAdapter:
|
||||
resp = await client.delete(url)
|
||||
if resp.status_code not in (204, 200, 404):
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Deleted {rel}",
|
||||
details={"adapter_id": self.record.id, "url": url},
|
||||
)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_url = self._build_url(src_rel)
|
||||
@@ -188,15 +168,6 @@ class WebDAVAdapter:
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Moved {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src_url": src_url,
|
||||
"dst_url": dst_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_url = self._build_url(src_rel)
|
||||
@@ -204,15 +175,6 @@ class WebDAVAdapter:
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Renamed {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src_url": src_url,
|
||||
"dst_url": dst_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def get_file_size(self, root: str, rel: str) -> int:
|
||||
"""获取文件大小"""
|
||||
@@ -518,15 +480,6 @@ class WebDAVAdapter:
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(src_rel)
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Copied {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src_url": src_url,
|
||||
"dst_url": dst_url,
|
||||
},
|
||||
)
|
||||
|
||||
ADAPTER_TYPE = "webdav"
|
||||
CONFIG_SCHEMA = [
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Dict, Callable
|
||||
import pkgutil
|
||||
import inspect
|
||||
import pkgutil
|
||||
from importlib import import_module
|
||||
from typing import Callable, Dict
|
||||
|
||||
from .base import BaseAdapter
|
||||
from models import StorageAdapter
|
||||
from domain.adapters.providers.base import BaseAdapter
|
||||
|
||||
AdapterFactory = Callable[[StorageAdapter], object]
|
||||
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
|
||||
|
||||
TYPE_MAP: Dict[str, AdapterFactory] = {}
|
||||
CONFIG_SCHEMAS: Dict[str, list] = {}
|
||||
@@ -20,8 +20,9 @@ def normalize_adapter_type(value: str | None) -> str | None:
|
||||
|
||||
|
||||
def discover_adapters():
|
||||
"""扫描 services.adapters 包, 自动注册适配器类型、工厂与配置 schema。"""
|
||||
from .. import adapters as adapters_pkg
|
||||
"""扫描 domain.adapters.providers 包, 自动注册适配器类型、工厂与配置 schema。"""
|
||||
from domain.adapters import providers as adapters_pkg
|
||||
|
||||
TYPE_MAP.clear()
|
||||
CONFIG_SCHEMAS.clear()
|
||||
for modinfo in pkgutil.iter_modules(adapters_pkg.__path__):
|
||||
@@ -64,7 +65,7 @@ def get_config_schema(adapter_type: str):
|
||||
|
||||
class RuntimeRegistry:
|
||||
def __init__(self):
|
||||
self._instances: Dict[int, object] = {}
|
||||
self._instances: Dict[int, BaseAdapter] = {}
|
||||
|
||||
async def refresh(self):
|
||||
discover_adapters()
|
||||
@@ -86,9 +87,9 @@ class RuntimeRegistry:
|
||||
try:
|
||||
self._instances[rec.id] = factory(rec)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
|
||||
def get(self, adapter_id: int):
|
||||
def get(self, adapter_id: int) -> BaseAdapter | None:
|
||||
return self._instances.get(adapter_id)
|
||||
|
||||
def snapshot(self) -> Dict[int, BaseAdapter]:
|
||||
@@ -104,7 +105,7 @@ class RuntimeRegistry:
|
||||
if not rec.enabled:
|
||||
self.remove(rec.id)
|
||||
return
|
||||
|
||||
|
||||
normalized_type = normalize_adapter_type(rec.type)
|
||||
if not normalized_type:
|
||||
self.remove(rec.id)
|
||||
111
domain/adapters/service.py
Normal file
111
domain/adapters/service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.adapters.registry import (
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from domain.adapters.types import AdapterCreate, AdapterOut
|
||||
from domain.auth.types import User
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
class AdapterService:
|
||||
@classmethod
|
||||
def _validate_and_normalize_config(cls, adapter_type: str, cfg):
|
||||
schemas = get_config_schemas()
|
||||
adapter_type = normalize_adapter_type(adapter_type)
|
||||
if not adapter_type:
|
||||
raise HTTPException(400, detail="不支持的适配器类型")
|
||||
if not isinstance(cfg, dict):
|
||||
raise HTTPException(400, detail="config 必须是对象")
|
||||
schema = schemas.get(adapter_type)
|
||||
if not schema:
|
||||
raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
|
||||
out = {}
|
||||
missing = []
|
||||
for f in schema:
|
||||
k = f["key"]
|
||||
if k in cfg and cfg[k] not in (None, ""):
|
||||
out[k] = cfg[k]
|
||||
elif "default" in f:
|
||||
out[k] = f["default"]
|
||||
elif f.get("required"):
|
||||
missing.append(k)
|
||||
if missing:
|
||||
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
async def create_adapter(cls, data: AdapterCreate, current_user: Optional[User]):
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
exists = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if exists:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
adapter_fields = {
|
||||
"name": data.name,
|
||||
"type": data.type,
|
||||
"config": cls._validate_and_normalize_config(data.type, data.config or {}),
|
||||
"enabled": data.enabled,
|
||||
"path": norm_path,
|
||||
"sub_path": data.sub_path,
|
||||
}
|
||||
|
||||
rec = await StorageAdapter.create(**adapter_fields)
|
||||
await runtime_registry.upsert(rec)
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def list_adapters(cls):
|
||||
adapters = await StorageAdapter.all()
|
||||
return [AdapterOut.model_validate(a) for a in adapters]
|
||||
|
||||
@classmethod
|
||||
async def available_adapter_types(cls):
|
||||
data = []
|
||||
for adapter_type, fields in get_config_schemas().items():
|
||||
data.append({
|
||||
"type": adapter_type,
|
||||
"config_schema": fields,
|
||||
})
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def get_adapter(cls, adapter_id: int):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def update_adapter(cls, adapter_id: int, data: AdapterCreate, current_user: Optional[User]):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
existing = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if existing and existing.id != adapter_id:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
rec.name = data.name
|
||||
rec.type = data.type
|
||||
rec.config = cls._validate_and_normalize_config(data.type, data.config or {})
|
||||
rec.enabled = data.enabled
|
||||
rec.path = norm_path
|
||||
rec.sub_path = data.sub_path
|
||||
await rec.save()
|
||||
|
||||
await runtime_registry.upsert(rec)
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def delete_adapter(cls, adapter_id: int, current_user: Optional[User]):
|
||||
deleted = await StorageAdapter.filter(id=adapter_id).delete()
|
||||
if not deleted:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
runtime_registry.remove(adapter_id)
|
||||
return {"deleted": True}
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
34
domain/ai/__init__.py
Normal file
34
domain/ai/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from .api import router_ai, router_vector_db
|
||||
from .service import (
|
||||
AIProviderService,
|
||||
VectorDBConfigManager,
|
||||
VectorDBService,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
ABILITIES,
|
||||
normalize_capabilities,
|
||||
)
|
||||
from .types import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"router_ai",
|
||||
"router_vector_db",
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"ABILITIES",
|
||||
"normalize_capabilities",
|
||||
"AIDefaultsUpdate",
|
||||
"AIModelCreate",
|
||||
"AIModelUpdate",
|
||||
"AIProviderCreate",
|
||||
"AIProviderUpdate",
|
||||
"VectorDBConfigPayload",
|
||||
]
|
||||
305
domain/ai/api.py
Normal file
305
domain/ai/api.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from typing import Annotated, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
|
||||
from domain.ai.types import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
)
|
||||
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
|
||||
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
|
||||
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取 AI 提供商列表")
|
||||
@router_ai.get("/providers")
|
||||
async def list_providers_endpoint(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
providers = await AIProviderService.list_providers()
|
||||
return success({"providers": providers})
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建 AI 提供商",
|
||||
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.post("/providers")
|
||||
async def create_provider(
|
||||
request: Request,
|
||||
payload: AIProviderCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
provider = await AIProviderService.create_provider(payload.dict())
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取 AI 提供商详情")
|
||||
@router_ai.get("/providers/{provider_id}")
|
||||
async def get_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
provider = await AIProviderService.get_provider(provider_id, with_models=True)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新 AI 提供商",
|
||||
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.put("/providers/{provider_id}")
|
||||
async def update_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIProviderUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
provider = await AIProviderService.update_provider(provider_id, data)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(action=AuditAction.DELETE, description="删除 AI 提供商")
|
||||
@router_ai.delete("/providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await AIProviderService.delete_provider(provider_id)
|
||||
return success({"id": provider_id})
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="同步模型列表")
|
||||
@router_ai.post("/providers/{provider_id}/sync-models")
|
||||
async def sync_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
result = await AIProviderService.sync_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success(result)
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取远程模型列表")
|
||||
@router_ai.get("/providers/{provider_id}/remote-models")
|
||||
async def fetch_remote_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
models = await AIProviderService.fetch_remote_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取模型列表")
|
||||
@router_ai.get("/providers/{provider_id}/models")
|
||||
async def list_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
models = await AIProviderService.list_models(provider_id)
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建模型",
|
||||
body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
|
||||
)
|
||||
@router_ai.post("/providers/{provider_id}/models")
|
||||
async def create_model(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
model = await AIProviderService.create_model(provider_id, payload.dict())
|
||||
return success(model)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新模型",
|
||||
body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
|
||||
)
|
||||
@router_ai.put("/models/{model_id}")
|
||||
async def update_model(
|
||||
request: Request,
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
model = await AIProviderService.update_model(model_id, data)
|
||||
return success(model)
|
||||
|
||||
|
||||
@audit(action=AuditAction.DELETE, description="删除模型")
|
||||
@router_ai.delete("/models/{model_id}")
|
||||
async def delete_model(
|
||||
request: Request,
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await AIProviderService.delete_model(model_id)
|
||||
return success({"id": model_id})
|
||||
|
||||
|
||||
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
|
||||
if not entry:
|
||||
return None
|
||||
value = entry.get("embedding_dimensions")
|
||||
return int(value) if value is not None else None
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取默认模型")
|
||||
@router_ai.get("/defaults")
|
||||
async def get_defaults(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
defaults = await AIProviderService.get_default_models()
|
||||
return success(defaults)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新默认模型",
|
||||
body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
|
||||
)
|
||||
@router_ai.put("/defaults")
|
||||
async def update_defaults(
|
||||
request: Request,
|
||||
payload: AIDefaultsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
previous = await AIProviderService.get_default_models()
|
||||
try:
|
||||
updated = await AIProviderService.set_default_models(payload.as_mapping())
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
prev_dim = _get_embedding_dimension(previous.get("embedding"))
|
||||
next_dim = _get_embedding_dimension(updated.get("embedding"))
|
||||
|
||||
if prev_dim and next_dim and prev_dim != next_dim:
|
||||
try:
|
||||
await VectorDBService().clear_all_data()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
|
||||
|
||||
return success(updated)
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="清空向量数据库")
|
||||
@router_vector_db.post("/clear-all", summary="清空向量数据库")
|
||||
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
return success(msg="向量数据库已清空")
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库统计")
|
||||
@router_vector_db.get("/stats", summary="获取向量数据库统计")
|
||||
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
data = await service.get_all_stats()
|
||||
return success(data=data)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
|
||||
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
|
||||
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
|
||||
return success(list_providers())
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库配置")
|
||||
@router_vector_db.get("/config", summary="获取当前向量数据库配置")
|
||||
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
|
||||
service = VectorDBService()
|
||||
data = await service.current_provider()
|
||||
return success(data)
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="更新向量数据库配置", body_fields=["type"])
|
||||
@router_vector_db.post("/config", summary="更新向量数据库配置")
|
||||
async def update_vector_db_config(
|
||||
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
|
||||
):
|
||||
entry = get_provider_entry(payload.type)
|
||||
if not entry:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
||||
|
||||
provider_cls = get_provider_class(payload.type)
|
||||
if not provider_cls:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
||||
|
||||
test_provider = provider_cls(payload.config)
|
||||
try:
|
||||
await test_provider.initialize()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
client = getattr(test_provider, "client", None)
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
||||
service = VectorDBService()
|
||||
await service.reload()
|
||||
config_data = await service.current_provider()
|
||||
stats = await service.get_all_stats()
|
||||
return success({"config": config_data, "stats": stats})
|
||||
|
||||
|
||||
__all__ = ["router_ai", "router_vector_db"]
|
||||
@@ -4,10 +4,10 @@ import httpx
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from models.database import AIModel, AIProvider
|
||||
from services.ai_providers import AIProviderService
|
||||
from domain.ai.service import AIProviderService
|
||||
|
||||
|
||||
provider_service = AIProviderService()
|
||||
provider_service = AIProviderService
|
||||
|
||||
|
||||
class MissingModelError(RuntimeError):
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -7,10 +9,18 @@ import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from models.database import AIDefaultModel, AIModel, AIProvider
|
||||
|
||||
from .types import ABILITIES, normalize_capabilities
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
|
||||
DEFAULT_VECTOR_DIMENSION = 4096
|
||||
|
||||
OPENAI_EMBEDDING_DIMS = {
|
||||
"text-embedding-3-large": 3072,
|
||||
@@ -19,6 +29,43 @@ OPENAI_EMBEDDING_DIMS = {
|
||||
}
|
||||
|
||||
|
||||
class VectorDBConfigManager:
|
||||
TYPE_KEY = "VECTOR_DB_TYPE"
|
||||
CONFIG_KEY = "VECTOR_DB_CONFIG"
|
||||
DEFAULT_TYPE = "milvus_lite"
|
||||
|
||||
@classmethod
|
||||
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
|
||||
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
|
||||
provider_type = str(raw_type or cls.DEFAULT_TYPE)
|
||||
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
config_dict: Dict[str, Any] = {}
|
||||
if isinstance(raw_config, str) and raw_config:
|
||||
try:
|
||||
config_dict = json.loads(raw_config)
|
||||
except json.JSONDecodeError:
|
||||
config_dict = {}
|
||||
elif isinstance(raw_config, dict):
|
||||
config_dict = raw_config
|
||||
return provider_type, config_dict
|
||||
|
||||
@classmethod
|
||||
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
|
||||
await ConfigService.set(cls.TYPE_KEY, provider_type)
|
||||
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
|
||||
|
||||
@classmethod
|
||||
async def get_type(cls) -> str:
|
||||
provider_type, _ = await cls.load_config()
|
||||
return provider_type
|
||||
|
||||
@classmethod
|
||||
async def get_config(cls) -> Dict[str, Any]:
|
||||
_, config = await cls.load_config()
|
||||
return config
|
||||
|
||||
|
||||
def _normalize_embedding_dim(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -47,17 +94,6 @@ def _apply_embedding_dim_to_metadata(
|
||||
return data
|
||||
|
||||
|
||||
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
if not items:
|
||||
return []
|
||||
normalized = []
|
||||
for cap in items:
|
||||
key = str(cap).strip().lower()
|
||||
if key in ABILITIES and key not in normalized:
|
||||
normalized.append(key)
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
|
||||
lower = model_id.lower()
|
||||
caps = set()
|
||||
@@ -139,40 +175,46 @@ def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = Non
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
async def list_providers(self) -> List[Dict[str, Any]]:
|
||||
@classmethod
|
||||
async def list_providers(cls) -> List[Dict[str, Any]]:
|
||||
providers = await AIProvider.all().order_by("id").prefetch_related("models")
|
||||
return [provider_to_dict(p, models=list(p.models)) for p in providers]
|
||||
|
||||
async def get_provider(self, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
|
||||
@classmethod
|
||||
async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
|
||||
if with_models:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
models = await provider.models.all()
|
||||
return provider_to_dict(provider, models=models)
|
||||
else:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return provider_to_dict(provider)
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
async def create_provider(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@classmethod
|
||||
async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data.setdefault("extra_config", {})
|
||||
provider = await AIProvider.create(**data)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
async def update_provider(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@classmethod
|
||||
async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
for field, value in payload.items():
|
||||
setattr(provider, field, value)
|
||||
await provider.save()
|
||||
return provider_to_dict(provider)
|
||||
|
||||
async def delete_provider(self, provider_id: int) -> None:
|
||||
@classmethod
|
||||
async def delete_provider(cls, provider_id: int) -> None:
|
||||
await AIProvider.filter(id=provider_id).delete()
|
||||
|
||||
async def list_models(self, provider_id: int) -> List[Dict[str, Any]]:
|
||||
@classmethod
|
||||
async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
|
||||
return [model_to_dict(m) for m in models]
|
||||
|
||||
async def create_model(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@classmethod
|
||||
async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data["provider_id"] = provider_id
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
@@ -182,7 +224,8 @@ class AIProviderService:
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
async def update_model(self, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@classmethod
|
||||
async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model = await AIModel.get(id=model_id)
|
||||
data = payload.copy()
|
||||
if "capabilities" in data:
|
||||
@@ -199,14 +242,17 @@ class AIProviderService:
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
async def delete_model(self, model_id: int) -> None:
|
||||
@classmethod
|
||||
async def delete_model(cls, model_id: int) -> None:
|
||||
await AIModel.filter(id=model_id).delete()
|
||||
|
||||
async def fetch_remote_models(self, provider_id: int) -> List[Dict[str, Any]]:
|
||||
@classmethod
|
||||
async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return await self._get_remote_models(provider)
|
||||
return await cls._get_remote_models(provider)
|
||||
|
||||
async def _get_remote_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
@classmethod
|
||||
async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
if not provider.base_url:
|
||||
raise ValueError("Provider base_url is required for syncing models")
|
||||
|
||||
@@ -215,12 +261,13 @@ class AIProviderService:
|
||||
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
|
||||
|
||||
if fmt == "openai":
|
||||
return await self._fetch_openai_models(provider)
|
||||
return await self._fetch_gemini_models(provider)
|
||||
return await cls._fetch_openai_models(provider)
|
||||
return await cls._fetch_gemini_models(provider)
|
||||
|
||||
async def sync_models(self, provider_id: int) -> Dict[str, int]:
|
||||
@classmethod
|
||||
async def sync_models(cls, provider_id: int) -> Dict[str, int]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
remote_models = await self._get_remote_models(provider)
|
||||
remote_models = await cls._get_remote_models(provider)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
@@ -247,14 +294,16 @@ class AIProviderService:
|
||||
|
||||
return {"created": created, "updated": updated}
|
||||
|
||||
async def get_default_models(self) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
@classmethod
|
||||
async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
|
||||
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
|
||||
for item in defaults:
|
||||
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
|
||||
return result
|
||||
|
||||
async def set_default_models(self, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
@classmethod
|
||||
async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
|
||||
async with in_transaction() as connection:
|
||||
for ability, model_id in normalized.items():
|
||||
@@ -271,9 +320,10 @@ class AIProviderService:
|
||||
await AIDefaultModel.create(ability=ability, model_id=model_id)
|
||||
elif record:
|
||||
await record.delete(using_db=connection)
|
||||
return await self.get_default_models()
|
||||
return await cls.get_default_models()
|
||||
|
||||
async def get_default_model(self, ability: str) -> Optional[AIModel]:
|
||||
@classmethod
|
||||
async def get_default_model(cls, ability: str) -> Optional[AIModel]:
|
||||
ability_key = ability.lower()
|
||||
if ability_key not in ABILITIES:
|
||||
return None
|
||||
@@ -285,7 +335,8 @@ class AIProviderService:
|
||||
await model.fetch_related("provider")
|
||||
return model
|
||||
|
||||
async def _fetch_openai_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
@classmethod
|
||||
async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
url = f"{base_url}/models"
|
||||
headers = {}
|
||||
@@ -315,7 +366,8 @@ class AIProviderService:
|
||||
})
|
||||
return entries
|
||||
|
||||
async def _fetch_gemini_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
@classmethod
|
||||
async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
suffix = "/models"
|
||||
if provider.api_key:
|
||||
@@ -345,3 +397,105 @@ class AIProviderService:
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
class VectorDBService:
|
||||
_instance: "VectorDBService" | None = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_provider"):
|
||||
self._provider: Optional[BaseVectorProvider] = None
|
||||
self._provider_type: Optional[str] = None
|
||||
self._provider_config: Dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_provider(self) -> BaseVectorProvider:
|
||||
if self._provider is None:
|
||||
await self.reload()
|
||||
assert self._provider is not None
|
||||
return self._provider
|
||||
|
||||
async def reload(self) -> BaseVectorProvider:
|
||||
async with self._lock:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
normalized_config = dict(provider_config or {})
|
||||
if (
|
||||
self._provider
|
||||
and self._provider_type == provider_type
|
||||
and self._provider_config == normalized_config
|
||||
):
|
||||
return self._provider
|
||||
|
||||
entry = get_provider_entry(provider_type)
|
||||
if not entry:
|
||||
raise RuntimeError(f"Unknown vector database provider: {provider_type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")
|
||||
|
||||
provider_cls = get_provider_class(provider_type)
|
||||
if not provider_cls:
|
||||
raise RuntimeError(f"Provider class not found for '{provider_type}'")
|
||||
|
||||
provider = provider_cls(provider_config)
|
||||
await provider.initialize()
|
||||
|
||||
self._provider = provider
|
||||
self._provider_type = provider_type
|
||||
self._provider_config = normalized_config
|
||||
return provider
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.ensure_collection(collection_name, vector, dim)
|
||||
|
||||
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.upsert_vector(collection_name, data)
|
||||
|
||||
async def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.delete_vector(collection_name, path)
|
||||
|
||||
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_vectors(collection_name, query_embedding, top_k)
|
||||
|
||||
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_by_path(collection_name, query_path, top_k)
|
||||
|
||||
async def get_all_stats(self) -> Dict[str, Any]:
|
||||
provider = await self._ensure_provider()
|
||||
return provider.get_all_stats()
|
||||
|
||||
async def clear_all_data(self) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.clear_all_data()
|
||||
|
||||
async def current_provider(self) -> Dict[str, Any]:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
entry = get_provider_entry(provider_type) or {}
|
||||
return {
|
||||
"type": provider_type,
|
||||
"config": provider_config,
|
||||
"label": entry.get("label"),
|
||||
"enabled": entry.get("enabled", True),
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"normalize_capabilities",
|
||||
"ABILITIES",
|
||||
]
|
||||
@@ -1,8 +1,19 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from services.ai_providers import ABILITIES, normalize_capabilities
|
||||
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
|
||||
|
||||
|
||||
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
if not items:
|
||||
return []
|
||||
normalized: List[str] = []
|
||||
for cap in items:
|
||||
key = str(cap).strip().lower()
|
||||
if key in ABILITIES and key not in normalized:
|
||||
normalized.append(key)
|
||||
return normalized
|
||||
|
||||
|
||||
class AIProviderBase(BaseModel):
|
||||
@@ -16,6 +27,7 @@ class AIProviderBase(BaseModel):
|
||||
extra_config: Optional[dict] = None
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def normalize_format(cls, value: str) -> str:
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
@@ -37,6 +49,7 @@ class AIProviderUpdate(BaseModel):
|
||||
extra_config: Optional[dict] = None
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def normalize_format(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
@@ -56,6 +69,7 @@ class AIModelBase(BaseModel):
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@field_validator("capabilities")
|
||||
@classmethod
|
||||
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if items is None:
|
||||
return None
|
||||
@@ -79,6 +93,7 @@ class AIModelUpdate(BaseModel):
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@field_validator("capabilities")
|
||||
@classmethod
|
||||
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if items is None:
|
||||
return None
|
||||
@@ -97,5 +112,10 @@ class AIDefaultsUpdate(BaseModel):
|
||||
voice: Optional[int] = None
|
||||
tools: Optional[int] = None
|
||||
|
||||
def as_mapping(self) -> dict:
|
||||
def as_mapping(self) -> Dict[str, Optional[int]]:
|
||||
return {ability: getattr(self, ability) for ability in ABILITIES}
|
||||
|
||||
|
||||
class VectorDBConfigPayload(BaseModel):
|
||||
type: str = Field(..., description="向量数据库提供者类型")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
|
||||
@@ -54,3 +54,14 @@ def get_provider_class(provider_type: str) -> Type[BaseVectorProvider] | None:
|
||||
if not entry:
|
||||
return None
|
||||
return entry.get("class") # type: ignore[return-value]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseVectorProvider",
|
||||
"MilvusLiteProvider",
|
||||
"MilvusServerProvider",
|
||||
"QdrantProvider",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
]
|
||||
@@ -155,7 +155,7 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like "%{escaped}%"'
|
||||
filter_expr = f'source_path like \"%{escaped}%\"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
@@ -232,7 +232,7 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
detail = client.describe_index(name, index_name) or {}
|
||||
detail = client.describe_index(name) or {}
|
||||
except Exception:
|
||||
detail = {}
|
||||
indexes.append(
|
||||
@@ -162,7 +162,7 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like "%{escaped}%"'
|
||||
filter_expr = f'source_path like \"%{escaped}%\"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
@@ -239,7 +239,7 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
detail = client.describe_index(name, index_name) or {}
|
||||
detail = client.describe_index(name) or {}
|
||||
except Exception:
|
||||
detail = {}
|
||||
indexes.append(
|
||||
@@ -42,7 +42,6 @@ class QdrantProvider(BaseVectorProvider):
|
||||
api_key = (self.config.get("api_key") or None) or None
|
||||
try:
|
||||
client = QdrantClient(url=url, api_key=api_key)
|
||||
# 简单连通性校验
|
||||
client.get_collections()
|
||||
self.client = client
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
@@ -70,7 +69,6 @@ class QdrantProvider(BaseVectorProvider):
|
||||
message = str(exc).lower()
|
||||
if "already exists" in message or "index exists" in message:
|
||||
continue
|
||||
# 旧版本 qdrant 可能返回带状态码的异常,这里容忍重复创建
|
||||
raise
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
5
domain/audit/__init__.py
Normal file
5
domain/audit/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from domain.audit.decorator import audit
|
||||
from domain.audit.types import AuditAction
|
||||
from domain.audit.api import router
|
||||
|
||||
__all__ = ["audit", "AuditAction", "router"]
|
||||
68
domain/audit/api.py
Normal file
68
domain/audit/api.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from api import response
|
||||
from domain.audit.service import AuditService
|
||||
from domain.audit.types import AuditAction
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
router = APIRouter(prefix="/api/audit", tags=["Audit"])
|
||||
|
||||
|
||||
def _parse_iso(value: Optional[str], field: str):
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
dt = datetime.fromisoformat(normalized)
|
||||
if dt.tzinfo:
|
||||
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
return dt
|
||||
except ValueError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail=f"invalid {field}") from exc
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
async def list_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
page_num: int = Query(1, ge=1, alias="page", description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=200, description="每页条数"),
|
||||
action: AuditAction | None = Query(None, description="操作类型"),
|
||||
success: bool | None = Query(None, description="是否成功"),
|
||||
username: str | None = Query(None, description="用户名模糊匹配"),
|
||||
path: str | None = Query(None, description="路径模糊匹配"),
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
|
||||
):
|
||||
start_dt = _parse_iso(start_time, "start_time")
|
||||
end_dt = _parse_iso(end_time, "end_time")
|
||||
items, total = await AuditService.list_logs(
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
action=str(action) if action else None,
|
||||
success=success,
|
||||
username=username,
|
||||
path=path,
|
||||
start_time=start_dt,
|
||||
end_time=end_dt,
|
||||
)
|
||||
return response.success(response.page(items, total, page_num, page_size))
|
||||
|
||||
|
||||
@router.delete("/logs")
|
||||
async def clear_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
|
||||
):
|
||||
start_dt = _parse_iso(start_time, "start_time")
|
||||
end_dt = _parse_iso(end_time, "end_time")
|
||||
if start_dt is None and end_dt is None:
|
||||
raise HTTPException(status_code=400, detail="start_time 或 end_time 至少提供一个")
|
||||
deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
|
||||
return response.success({"deleted_count": deleted_count})
|
||||
182
domain/audit/decorator.py
Normal file
182
domain/audit/decorator.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import inspect
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.audit.service import AuditService
|
||||
from domain.audit.types import AuditAction
|
||||
from domain.auth.service import ALGORITHM
|
||||
from domain.config.service import ConfigService
|
||||
from models.database import UserAccount
|
||||
|
||||
|
||||
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
|
||||
for value in bound_args.values():
|
||||
if isinstance(value, Request):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_user(request: Request | None, user_obj: Any | None) -> tuple[Optional[int], Optional[str]]:
|
||||
user_id: int | None = None
|
||||
username: str | None = None
|
||||
|
||||
if request:
|
||||
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1]
|
||||
try:
|
||||
payload = jwt.decode(token, await ConfigService.get_secret_key("SECRET_KEY"), algorithms=[ALGORITHM])
|
||||
username = payload.get("sub") or payload.get("username")
|
||||
if username:
|
||||
user = await UserAccount.get_or_none(username=username)
|
||||
user_id = user.id if user else None
|
||||
except (InvalidTokenError, Exception):
|
||||
pass
|
||||
|
||||
if user_id is None and username is None and user_obj is not None:
|
||||
user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
|
||||
username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
|
||||
if isinstance(user_obj, dict):
|
||||
user_id = user_obj.get("id", user_obj.get("user_id", user_id))
|
||||
username = user_obj.get("username", user_obj.get("name", username))
|
||||
|
||||
return user_id, username
|
||||
|
||||
|
||||
def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
|
||||
if not body_fields:
|
||||
return None
|
||||
body: Dict[str, Any] = {}
|
||||
redacts = set(redact_fields or [])
|
||||
for value in bound_args.values():
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
data = value.model_dump()
|
||||
except Exception:
|
||||
data = None
|
||||
elif hasattr(value, "dict"):
|
||||
try:
|
||||
data = value.dict()
|
||||
except Exception:
|
||||
data = None
|
||||
elif isinstance(value, dict):
|
||||
data = value
|
||||
elif hasattr(value, "__dict__"):
|
||||
data = dict(value.__dict__)
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
for field in body_fields:
|
||||
if field in data and field not in body:
|
||||
body[field] = data[field]
|
||||
if not body:
|
||||
return None
|
||||
for field in redacts:
|
||||
if field in body:
|
||||
body[field] = "<redacted>"
|
||||
return body
|
||||
|
||||
|
||||
def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
|
||||
if not request:
|
||||
return None
|
||||
params: Dict[str, Any] = {}
|
||||
query = dict(request.query_params)
|
||||
if query:
|
||||
params["query"] = query
|
||||
path_params = dict(request.path_params or {})
|
||||
if path_params:
|
||||
params["path"] = path_params
|
||||
return params or None
|
||||
|
||||
|
||||
def _status_code_from_response(response: Any) -> int:
|
||||
if hasattr(response, "status_code"):
|
||||
try:
|
||||
return int(getattr(response, "status_code"))
|
||||
except Exception:
|
||||
pass
|
||||
return 200
|
||||
|
||||
|
||||
def audit(
|
||||
*,
|
||||
action: AuditAction,
|
||||
description: str | None = None,
|
||||
body_fields: list[str] | None = None,
|
||||
redact_fields: list[str] | None = None,
|
||||
user_kw: str = "current_user",
|
||||
):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
request = _extract_request(bound.arguments)
|
||||
start = time.perf_counter()
|
||||
user_info = bound.arguments.get(user_kw)
|
||||
user_id, username = await _resolve_user(request, user_info)
|
||||
request_params = _build_request_params(request)
|
||||
request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
status_code = _status_code_from_response(result)
|
||||
success = True
|
||||
error = None
|
||||
except Exception as exc: # noqa: BLE001
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
success = False
|
||||
error = str(exc)
|
||||
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
||||
try:
|
||||
await AuditService.log(
|
||||
action=action,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=request.client.host if request and request.client else None,
|
||||
method=request.method if request else "",
|
||||
path=request.url.path if request else func.__name__,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
||||
try:
|
||||
await AuditService.log(
|
||||
action=action,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=request.client.host if request and request.client else None,
|
||||
method=request.method if request else "",
|
||||
path=request.url.path if request else func.__name__,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
124
domain/audit/service.py
Normal file
124
domain/audit/service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from models.database import AuditLog
|
||||
|
||||
from domain.audit.types import AuditAction
|
||||
|
||||
|
||||
class AuditService:
|
||||
@classmethod
|
||||
async def log(
|
||||
cls,
|
||||
*,
|
||||
action: AuditAction | str,
|
||||
description: Optional[str],
|
||||
user_id: Optional[int],
|
||||
username: Optional[str],
|
||||
client_ip: Optional[str],
|
||||
method: str,
|
||||
path: str,
|
||||
status_code: int,
|
||||
duration_ms: Optional[float],
|
||||
success: bool,
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
await AuditLog.create(
|
||||
action=str(action),
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=client_ip,
|
||||
method=method,
|
||||
path=path,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _serialize(cls, log: AuditLog) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": log.id,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"action": log.action,
|
||||
"description": log.description,
|
||||
"user_id": log.user_id,
|
||||
"username": log.username,
|
||||
"client_ip": log.client_ip,
|
||||
"method": log.method,
|
||||
"path": log.path,
|
||||
"status_code": log.status_code,
|
||||
"duration_ms": log.duration_ms,
|
||||
"success": log.success,
|
||||
"request_params": log.request_params,
|
||||
"request_body": log.request_body,
|
||||
"error": log.error,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _apply_filters(
|
||||
cls,
|
||||
*,
|
||||
action: str | None = None,
|
||||
success: bool | None = None,
|
||||
username: str | None = None,
|
||||
path: str | None = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
qs = AuditLog.all()
|
||||
if action:
|
||||
qs = qs.filter(action=action)
|
||||
if success is not None:
|
||||
qs = qs.filter(success=success)
|
||||
if username:
|
||||
qs = qs.filter(username__icontains=username)
|
||||
if path:
|
||||
qs = qs.filter(path__icontains=path)
|
||||
if start_time:
|
||||
qs = qs.filter(created_at__gte=start_time)
|
||||
if end_time:
|
||||
qs = qs.filter(created_at__lte=end_time)
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
async def list_logs(
|
||||
cls,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
action: str | None = None,
|
||||
success: bool | None = None,
|
||||
username: str | None = None,
|
||||
path: str | None = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
qs = cls._apply_filters(
|
||||
action=action,
|
||||
success=success,
|
||||
username=username,
|
||||
path=path,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
total = await qs.count()
|
||||
offset = (page - 1) * page_size
|
||||
items = await qs.order_by("-created_at").offset(offset).limit(page_size)
|
||||
return [cls._serialize(log) for log in items], total
|
||||
|
||||
@classmethod
|
||||
async def clear_logs(
|
||||
cls,
|
||||
*,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
) -> int:
|
||||
qs = cls._apply_filters(start_time=start_time, end_time=end_time)
|
||||
deleted_count = await qs.delete()
|
||||
return deleted_count
|
||||
16
domain/audit/types.py
Normal file
16
domain/audit/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AuditAction(StrEnum):
|
||||
LOGIN = "login"
|
||||
LOGOUT = "logout"
|
||||
REGISTER = "register"
|
||||
READ = "read"
|
||||
CREATE = "create"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
RESET_PASSWORD = "reset_password"
|
||||
SHARE = "share"
|
||||
DOWNLOAD = "download"
|
||||
UPLOAD = "upload"
|
||||
OTHER = "other"
|
||||
90
domain/auth/api.py
Normal file
90
domain/auth/api.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import AuthService, get_current_active_user
|
||||
from domain.auth.types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", summary="注册第一个管理员用户")
|
||||
@audit(
|
||||
action=AuditAction.REGISTER,
|
||||
description="注册管理员",
|
||||
body_fields=["username", "email", "full_name"],
|
||||
redact_fields=["password"],
|
||||
)
|
||||
async def register(request: Request, data: RegisterRequest):
|
||||
user = await AuthService.register_user(data)
|
||||
return success({"username": user.username}, msg="初始用户注册成功")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
@audit(action=AuditAction.LOGIN, description="用户登录", body_fields=["username"], redact_fields=["password"])
|
||||
async def login_for_access_token(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
return await AuthService.login(form_data)
|
||||
|
||||
|
||||
@router.get("/me", summary="获取当前登录用户信息")
|
||||
@audit(action=AuditAction.READ, description="获取当前用户信息")
|
||||
async def get_me(
|
||||
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
profile = AuthService.get_profile(current_user)
|
||||
return success(profile)
|
||||
|
||||
|
||||
@router.put("/me", summary="更新当前登录用户信息")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新当前用户信息",
|
||||
body_fields=["email", "full_name"],
|
||||
redact_fields=["old_password", "new_password"],
|
||||
)
|
||||
async def update_me(
|
||||
request: Request,
|
||||
payload: UpdateMeRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
profile = await AuthService.update_me(payload, current_user)
|
||||
return success(profile)
|
||||
|
||||
|
||||
@router.post("/password-reset/request", summary="请求密码重置邮件")
|
||||
@audit(action=AuditAction.RESET_PASSWORD, description="请求密码重置邮件", body_fields=["email"])
|
||||
async def password_reset_request_endpoint(request: Request, payload: PasswordResetRequest):
|
||||
await AuthService.request_password_reset(payload)
|
||||
return success(msg="如果邮箱存在,将发送重置邮件")
|
||||
|
||||
|
||||
@router.get("/password-reset/verify", summary="校验密码重置令牌")
|
||||
@audit(action=AuditAction.RESET_PASSWORD, description="校验密码重置令牌", redact_fields=["token"])
|
||||
async def password_reset_verify(request: Request, token: str):
|
||||
user = await AuthService.verify_password_reset_token(token)
|
||||
return success({"username": user.username, "email": user.email})
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", summary="使用令牌重置密码")
|
||||
@audit(
|
||||
action=AuditAction.RESET_PASSWORD,
|
||||
description="重置密码",
|
||||
body_fields=["token"],
|
||||
redact_fields=["token", "password"],
|
||||
)
|
||||
async def password_reset_confirm(request: Request, payload: PasswordResetConfirm):
|
||||
await AuthService.reset_password_with_token(payload)
|
||||
return success(msg="密码已重置")
|
||||
356
domain/auth/service.py
Normal file
356
domain/auth/service.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from domain.auth.types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
TokenData,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
from models.database import UserAccount
|
||||
from domain.config.service import ConfigService
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
|
||||
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PasswordResetEntry:
|
||||
user_id: int
|
||||
email: str
|
||||
username: str
|
||||
expires_at: datetime
|
||||
used: bool = False
|
||||
|
||||
|
||||
class PasswordResetStore:
|
||||
_tokens: dict[str, PasswordResetEntry] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def _cleanup(cls):
|
||||
now = _now()
|
||||
for token, record in list(cls._tokens.items()):
|
||||
if record.used or record.expires_at < now:
|
||||
cls._tokens.pop(token, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, user: UserAccount) -> str:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user.id:
|
||||
cls._tokens.pop(key, None)
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
|
||||
cls._tokens[token] = PasswordResetEntry(
|
||||
user_id=user.id,
|
||||
email=user.email or "",
|
||||
username=user.username,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
async def get(cls, token: str) -> PasswordResetEntry | None:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
record = cls._tokens.get(token)
|
||||
if not record or record.used:
|
||||
return None
|
||||
return record
|
||||
|
||||
@classmethod
|
||||
async def mark_used(cls, token: str) -> None:
|
||||
async with cls._lock:
|
||||
record = cls._tokens.get(token)
|
||||
if record:
|
||||
record.used = True
|
||||
cls._cleanup()
|
||||
|
||||
@classmethod
|
||||
async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
|
||||
async with cls._lock:
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user_id and key != except_token:
|
||||
cls._tokens.pop(key, None)
|
||||
cls._cleanup()
|
||||
|
||||
|
||||
class AuthService:
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
algorithm = ALGORITHM
|
||||
access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
@classmethod
|
||||
async def get_secret_key(cls) -> str:
|
||||
return await ConfigService.get_secret_key("SECRET_KEY", None)
|
||||
|
||||
@classmethod
|
||||
def _normalize_email(cls, email: str | None) -> str:
|
||||
return (email or "").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
|
||||
return cls.pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
@classmethod
|
||||
def get_password_hash(cls, password: str) -> str:
|
||||
return cls.pwd_context.hash(password)
|
||||
|
||||
@classmethod
|
||||
async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
|
||||
user = await UserAccount.get_or_none(username=username_or_email)
|
||||
if not user:
|
||||
user = await UserAccount.get_or_none(email=username_or_email)
|
||||
if user:
|
||||
return UserInDB(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
disabled=user.disabled,
|
||||
hashed_password=user.hashed_password,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def authenticate_user_db(cls, username_or_email: str, password: str) -> UserInDB | None:
|
||||
user = await cls.get_user_db(username_or_email)
|
||||
if not user:
|
||||
return None
|
||||
if not cls.verify_password(password, user.hashed_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def has_users(cls) -> bool:
|
||||
user_count = await UserAccount.all().count()
|
||||
return user_count > 0
|
||||
|
||||
@classmethod
|
||||
async def register_user(cls, payload: RegisterRequest):
|
||||
if await cls.has_users():
|
||||
raise HTTPException(status_code=403, detail="系统已初始化,不允许注册新用户")
|
||||
exists = await UserAccount.get_or_none(username=payload.username)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
hashed = cls.get_password_hash(payload.password)
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def create_access_token(cls, data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if "sub" not in to_encode and "username" in to_encode:
|
||||
to_encode["sub"] = to_encode["username"]
|
||||
expire = _now() + (expires_delta or timedelta(minutes=15))
|
||||
to_encode.update({"exp": expire})
|
||||
secret_key = await cls.get_secret_key()
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=cls.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
@classmethod
|
||||
async def login(cls, form: OAuth2PasswordRequestForm) -> Token:
|
||||
user = await cls.authenticate_user_db(form.username, form.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
|
||||
access_token = await cls.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@classmethod
|
||||
def _build_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
email = cls._normalize_email(getattr(user, "email", None))
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": getattr(user, "email", None),
|
||||
"full_name": getattr(user, "full_name", None),
|
||||
"gravatar_url": gravatar_url,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
return cls._build_profile(user)
|
||||
|
||||
@classmethod
|
||||
async def update_me(cls, payload: UpdateMeRequest, current_user: User) -> dict:
|
||||
db_user = await UserAccount.get_or_none(id=current_user.id)
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if payload.email is not None:
|
||||
exists = (
|
||||
await UserAccount.filter(email=payload.email)
|
||||
.exclude(id=db_user.id)
|
||||
.exists()
|
||||
)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被占用")
|
||||
db_user.email = payload.email
|
||||
|
||||
if payload.full_name is not None:
|
||||
db_user.full_name = payload.full_name
|
||||
|
||||
if payload.new_password:
|
||||
if not payload.old_password:
|
||||
raise HTTPException(status_code=400, detail="请提供原密码")
|
||||
if not cls.verify_password(payload.old_password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="原密码错误")
|
||||
db_user.hashed_password = cls.get_password_hash(payload.new_password)
|
||||
|
||||
await db_user.save()
|
||||
return cls._build_profile(db_user)
|
||||
|
||||
@classmethod
|
||||
async def request_password_reset(cls, payload: PasswordResetRequest) -> bool:
|
||||
normalized = cls._normalize_email(payload.email)
|
||||
if not normalized:
|
||||
return False
|
||||
user = await UserAccount.get_or_none(email=normalized)
|
||||
if not user or not user.email:
|
||||
return False
|
||||
|
||||
token = await PasswordResetStore.create(user)
|
||||
try:
|
||||
await cls._send_password_reset_email(user, token)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await PasswordResetStore.mark_used(token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
raise HTTPException(status_code=500, detail="邮件发送失败") from exc
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def verify_password_reset_token(cls, token: str) -> UserAccount:
|
||||
record = await PasswordResetStore.get(token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def reset_password_with_token(cls, payload: PasswordResetConfirm) -> None:
|
||||
record = await PasswordResetStore.get(payload.token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user.hashed_password = cls.get_password_hash(payload.password)
|
||||
await user.save(update_fields=["hashed_password"])
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
|
||||
@classmethod
|
||||
async def get_current_user(cls, token: str):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
secret_key = await cls.get_secret_key()
|
||||
payload = jwt.decode(token, secret_key, algorithms=[cls.algorithm])
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
user = await cls.get_user_db(token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def get_current_active_user(cls, current_user: User):
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
@classmethod
|
||||
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
|
||||
from domain.email.service import EmailService
|
||||
|
||||
app_domain = await ConfigService.get("APP_DOMAIN", None)
|
||||
base_url = (app_domain or "http://localhost:5173").rstrip("/")
|
||||
reset_link = f"{base_url}/reset-password?token={token}"
|
||||
await EmailService.enqueue_email(
|
||||
recipients=[user.email],
|
||||
subject="Foxel 密码重置",
|
||||
template="password_reset",
|
||||
context={
|
||||
"username": user.username,
|
||||
"reset_link": reset_link,
|
||||
"expire_minutes": cls.password_reset_token_expire_minutes,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _current_user_dep(token: Annotated[str, Depends(AuthService.oauth2_scheme)]):
|
||||
return await AuthService.get_current_user(token)
|
||||
|
||||
|
||||
async def _current_active_user_dep(
|
||||
current_user: Annotated[User, Depends(_current_user_dep)],
|
||||
):
|
||||
return await AuthService.get_current_active_user(current_user)
|
||||
|
||||
|
||||
# 方便依赖注入与外部使用
|
||||
get_current_user = _current_user_dep
|
||||
get_current_active_user = _current_active_user_dep
|
||||
authenticate_user_db = AuthService.authenticate_user_db
|
||||
create_access_token = AuthService.create_access_token
|
||||
register_user = AuthService.register_user
|
||||
request_password_reset = AuthService.request_password_reset
|
||||
verify_password_reset_token = AuthService.verify_password_reset_token
|
||||
reset_password_with_token = AuthService.reset_password_with_token
|
||||
has_users = AuthService.has_users
|
||||
verify_password = AuthService.verify_password
|
||||
get_password_hash = AuthService.get_password_hash
|
||||
45
domain/auth/types.py
Normal file
45
domain/auth/types.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool | None = None
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
|
||||
|
||||
class UpdateMeRequest(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
old_password: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
token: str
|
||||
password: str
|
||||
1
domain/backup/__init__.py
Normal file
1
domain/backup/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
30
domain/backup/api.py
Normal file
30
domain/backup/api.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.backup.service import BackupService
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/backup",
|
||||
tags=["Backup & Restore"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/export", summary="导出全站数据")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
|
||||
async def export_backup(request: Request):
|
||||
data = await BackupService.export_data()
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
|
||||
return JSONResponse(content=data.model_dump(), headers=headers)
|
||||
|
||||
|
||||
@router.post("/import", summary="导入数据")
|
||||
@audit(action=AuditAction.UPLOAD, description="导入备份")
|
||||
async def import_backup(request: Request, file: UploadFile = File(...)):
|
||||
await BackupService.import_from_bytes(file.filename, await file.read())
|
||||
return {"message": "数据导入成功。"}
|
||||
91
domain/backup/service.py
Normal file
91
domain/backup/service.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.backup.types import BackupData
|
||||
from domain.config.service import VERSION
|
||||
from models.database import (
|
||||
AutomationTask,
|
||||
Configuration,
|
||||
ShareLink,
|
||||
StorageAdapter,
|
||||
UserAccount,
|
||||
)
|
||||
|
||||
|
||||
class BackupService:
|
||||
@classmethod
|
||||
async def export_data(cls) -> BackupData:
|
||||
async with in_transaction():
|
||||
adapters = await StorageAdapter.all().values()
|
||||
users = await UserAccount.all().values()
|
||||
tasks = await AutomationTask.all().values()
|
||||
shares = await ShareLink.all().values()
|
||||
configs = await Configuration.all().values()
|
||||
|
||||
for share in shares:
|
||||
share["created_at"] = (
|
||||
share["created_at"].isoformat() if share.get("created_at") else None
|
||||
)
|
||||
share["expires_at"] = (
|
||||
share["expires_at"].isoformat() if share.get("expires_at") else None
|
||||
)
|
||||
|
||||
return BackupData(
|
||||
version=VERSION,
|
||||
storage_adapters=list(adapters),
|
||||
user_accounts=list(users),
|
||||
automation_tasks=list(tasks),
|
||||
share_links=list(shares),
|
||||
configurations=list(configs),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def import_from_bytes(cls, filename: str, content: bytes) -> None:
|
||||
if not filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
|
||||
try:
|
||||
raw_data = json.loads(content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="无法解析JSON文件")
|
||||
await cls.import_data(BackupData(**raw_data))
|
||||
|
||||
@classmethod
|
||||
async def import_data(cls, payload: BackupData) -> None:
|
||||
async with in_transaction() as conn:
|
||||
await ShareLink.all().using_db(conn).delete()
|
||||
await AutomationTask.all().using_db(conn).delete()
|
||||
await StorageAdapter.all().using_db(conn).delete()
|
||||
await UserAccount.all().using_db(conn).delete()
|
||||
await Configuration.all().using_db(conn).delete()
|
||||
|
||||
if payload.configurations:
|
||||
await Configuration.bulk_create(
|
||||
[Configuration(**config) for config in payload.configurations],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.user_accounts:
|
||||
await UserAccount.bulk_create(
|
||||
[UserAccount(**user) for user in payload.user_accounts],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.storage_adapters:
|
||||
await StorageAdapter.bulk_create(
|
||||
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.automation_tasks:
|
||||
await AutomationTask.bulk_create(
|
||||
[AutomationTask(**task) for task in payload.automation_tasks],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.share_links:
|
||||
await ShareLink.bulk_create(
|
||||
[ShareLink(**share) for share in payload.share_links],
|
||||
using_db=conn,
|
||||
)
|
||||
12
domain/backup/types.py
Normal file
12
domain/backup/types.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BackupData(BaseModel):
|
||||
version: str | None = None
|
||||
storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
|
||||
user_accounts: list[dict[str, Any]] = Field(default_factory=list)
|
||||
automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
share_links: list[dict[str, Any]] = Field(default_factory=list)
|
||||
configurations: list[dict[str, Any]] = Field(default_factory=list)
|
||||
59
domain/config/api.py
Normal file
59
domain/config/api.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.config.service import ConfigService
|
||||
from domain.config.types import ConfigItem
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取配置")
|
||||
async def get_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str,
|
||||
):
|
||||
value = await ConfigService.get(key)
|
||||
return success(ConfigItem(key=key, value=value).model_dump())
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
|
||||
async def set_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str = Form(...),
|
||||
value: str = Form(...),
|
||||
):
|
||||
await ConfigService.set(key, value)
|
||||
return success(ConfigItem(key=key, value=value).model_dump())
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
@audit(action=AuditAction.READ, description="获取全部配置")
|
||||
async def get_all_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
configs = await ConfigService.get_all()
|
||||
return success(configs)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@audit(action=AuditAction.READ, description="获取系统状态")
|
||||
async def get_system_status(request: Request):
|
||||
status_data = await ConfigService.get_system_status()
|
||||
return success(status_data.model_dump())
|
||||
|
||||
|
||||
@router.get("/latest-version")
|
||||
@audit(action=AuditAction.READ, description="获取最新版本")
|
||||
async def get_latest_version(request: Request):
|
||||
info = await ConfigService.get_latest_version()
|
||||
return success(info.model_dump())
|
||||
111
domain/config/service.py
Normal file
111
domain/config/service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from domain.config.types import LatestVersionInfo, SystemStatus
|
||||
from models.database import Configuration, UserAccount
|
||||
|
||||
load_dotenv(dotenv_path=".env")
|
||||
|
||||
VERSION = "v1.3.8"
|
||||
|
||||
|
||||
class ConfigService:
|
||||
_cache: Dict[str, Any] = {}
|
||||
_latest_version_cache: Dict[str, Any] = {"timestamp": 0.0, "data": None}
|
||||
|
||||
@classmethod
|
||||
async def get(cls, key: str, default: Optional[Any] = None) -> Any:
|
||||
if key in cls._cache:
|
||||
return cls._cache[key]
|
||||
try:
|
||||
config = await Configuration.get_or_none(key=key)
|
||||
if config:
|
||||
cls._cache[key] = config.value
|
||||
return config.value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
env_value = os.getenv(key)
|
||||
if env_value is not None:
|
||||
cls._cache[key] = env_value
|
||||
return env_value
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
async def get_secret_key(cls, key: str, default: Optional[Any] = None) -> bytes:
|
||||
value = await cls.get(key, default)
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.encode("utf-8")
|
||||
if value is None:
|
||||
raise ValueError(f"Secret key '{key}' not found in config or environment.")
|
||||
return str(value).encode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def set(cls, key: str, value: Any):
|
||||
obj, _ = await Configuration.get_or_create(key=key, defaults={"value": value})
|
||||
obj.value = value
|
||||
await obj.save()
|
||||
cls._cache[key] = value
|
||||
|
||||
@classmethod
|
||||
async def get_all(cls) -> Dict[str, Any]:
|
||||
try:
|
||||
configs = await Configuration.all()
|
||||
result = {}
|
||||
for config in configs:
|
||||
result[config.key] = config.value
|
||||
cls._cache[config.key] = config.value
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache.clear()
|
||||
|
||||
@classmethod
|
||||
async def get_system_status(cls) -> SystemStatus:
|
||||
logo = await cls.get("APP_LOGO", "/logo.svg")
|
||||
favicon = await cls.get("APP_FAVICON", logo)
|
||||
user_count = await UserAccount.all().count()
|
||||
return SystemStatus(
|
||||
version=VERSION,
|
||||
title=await cls.get("APP_NAME", "Foxel"),
|
||||
logo=logo,
|
||||
favicon=favicon,
|
||||
is_initialized=user_count > 0,
|
||||
app_domain=await cls.get("APP_DOMAIN"),
|
||||
file_domain=await cls.get("FILE_DOMAIN"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_latest_version(cls) -> LatestVersionInfo:
|
||||
current_time = time.time()
|
||||
cache = cls._latest_version_cache
|
||||
if current_time - cache["timestamp"] < 3600 and cache["data"]:
|
||||
return cache["data"]
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
version_info = LatestVersionInfo(
|
||||
latest_version=data.get("tag_name"),
|
||||
body=data.get("body"),
|
||||
)
|
||||
cache["timestamp"] = current_time
|
||||
cache["data"] = version_info
|
||||
return version_info
|
||||
except httpx.RequestError:
|
||||
if cache["data"]:
|
||||
return cache["data"]
|
||||
return LatestVersionInfo()
|
||||
23
domain/config/types.py
Normal file
23
domain/config/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfigItem(BaseModel):
|
||||
key: str
|
||||
value: Optional[Any] = None
|
||||
|
||||
|
||||
class SystemStatus(BaseModel):
|
||||
version: str
|
||||
title: str
|
||||
logo: str
|
||||
favicon: str
|
||||
is_initialized: bool
|
||||
app_domain: Optional[str] = None
|
||||
file_domain: Optional[str] = None
|
||||
|
||||
|
||||
class LatestVersionInfo(BaseModel):
|
||||
latest_version: Optional[str] = None
|
||||
body: Optional[str] = None
|
||||
@@ -1,20 +1,24 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.email import EmailService, EmailTemplateRenderer
|
||||
from schemas.email import EmailTestRequest, EmailTemplateUpdate, EmailTemplatePreviewPayload
|
||||
from api.response import success
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/email",
|
||||
tags=["email"],
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.email.service import EmailService, EmailTemplateRenderer
|
||||
from domain.email.types import (
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/email", tags=["email"])
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
@audit(action=AuditAction.CREATE, description="发送测试邮件")
|
||||
async def trigger_test_email(
|
||||
request: Request,
|
||||
payload: EmailTestRequest,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
@@ -27,17 +31,13 @@ async def trigger_test_email(
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
await LogService.action(
|
||||
"route:email",
|
||||
"Triggered email test",
|
||||
details={"task_id": task.id, "template": payload.template, "to": str(payload.to)},
|
||||
user_id=getattr(current_user, "id", None),
|
||||
)
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
@audit(action=AuditAction.READ, description="获取邮件模板列表")
|
||||
async def list_email_templates(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
templates = await EmailTemplateRenderer.list_templates()
|
||||
@@ -45,7 +45,9 @@ async def list_email_templates(
|
||||
|
||||
|
||||
@router.get("/templates/{name}")
|
||||
@audit(action=AuditAction.READ, description="查看邮件模板")
|
||||
async def get_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
@@ -59,7 +61,9 @@ async def get_email_template(
|
||||
|
||||
|
||||
@router.post("/templates/{name}")
|
||||
@audit(action=AuditAction.UPDATE, description="更新邮件模板")
|
||||
async def update_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
payload: EmailTemplateUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
@@ -68,17 +72,13 @@ async def update_email_template(
|
||||
await EmailTemplateRenderer.save(name, payload.content)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
await LogService.action(
|
||||
"route:email",
|
||||
"Updated email template",
|
||||
details={"template": name},
|
||||
user_id=getattr(current_user, "id", None),
|
||||
)
|
||||
return success({"name": name})
|
||||
|
||||
|
||||
@router.post("/templates/{name}/preview")
|
||||
@audit(action=AuditAction.READ, description="预览邮件模板")
|
||||
async def preview_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
payload: EmailTemplatePreviewPayload,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
@@ -1,42 +1,14 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import smtplib
|
||||
from email.message import EmailMessage
|
||||
from email.utils import formataddr
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, ValidationError
|
||||
|
||||
from services.config import ConfigCenter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
class EmailSecurity(str, Enum):
|
||||
NONE = "none"
|
||||
SSL = "ssl"
|
||||
STARTTLS = "starttls"
|
||||
|
||||
|
||||
class EmailConfig(BaseModel):
|
||||
host: str
|
||||
port: int = Field(..., gt=0)
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
sender_email: EmailStr
|
||||
sender_name: Optional[str] = None
|
||||
security: EmailSecurity = EmailSecurity.NONE
|
||||
timeout: float = Field(default=30.0, gt=0.0)
|
||||
|
||||
|
||||
class EmailSendPayload(BaseModel):
|
||||
recipients: List[EmailStr] = Field(..., min_length=1)
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(..., min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
from domain.config.service import ConfigService
|
||||
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload
|
||||
|
||||
|
||||
class EmailTemplateRenderer:
|
||||
@@ -52,9 +24,7 @@ class EmailTemplateRenderer:
|
||||
async def list_templates(cls) -> list[str]:
|
||||
cls.ROOT.mkdir(parents=True, exist_ok=True)
|
||||
return sorted(
|
||||
path.stem
|
||||
for path in cls.ROOT.glob("*.html")
|
||||
if path.is_file()
|
||||
path.stem for path in cls.ROOT.glob("*.html") if path.is_file()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -82,22 +52,8 @@ class EmailService:
|
||||
|
||||
@classmethod
|
||||
async def _load_config(cls) -> EmailConfig:
|
||||
raw_config = await ConfigCenter.get(cls.CONFIG_KEY)
|
||||
if raw_config is None:
|
||||
raise ValueError("Email configuration not found")
|
||||
|
||||
if isinstance(raw_config, str):
|
||||
raw_config = raw_config.strip()
|
||||
data: Any = json.loads(raw_config) if raw_config else {}
|
||||
elif isinstance(raw_config, dict):
|
||||
data = raw_config
|
||||
else:
|
||||
raise ValueError("Invalid email configuration format")
|
||||
|
||||
try:
|
||||
return EmailConfig(**data)
|
||||
except ValidationError as exc:
|
||||
raise ValueError(f"Invalid email configuration: {exc}") from exc
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
return EmailConfig.parse_config(raw_config)
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(html: str) -> str:
|
||||
@@ -108,7 +64,9 @@ class EmailService:
|
||||
async def _deliver(cls, config: EmailConfig, payload: EmailSendPayload, html_body: str):
|
||||
message = EmailMessage()
|
||||
message["Subject"] = payload.subject
|
||||
message["From"] = formataddr((config.sender_name or str(config.sender_email), str(config.sender_email)))
|
||||
message["From"] = formataddr(
|
||||
(config.sender_name or str(config.sender_email), str(config.sender_email))
|
||||
)
|
||||
message["To"] = ", ".join([str(addr) for addr in payload.recipients])
|
||||
|
||||
plain_body = cls._html_to_text(html_body)
|
||||
@@ -120,7 +78,9 @@ class EmailService:
|
||||
@staticmethod
|
||||
def _deliver_sync(config: EmailConfig, message: EmailMessage):
|
||||
if config.security == EmailSecurity.SSL:
|
||||
smtp: smtplib.SMTP = smtplib.SMTP_SSL(config.host, config.port, timeout=config.timeout)
|
||||
smtp: smtplib.SMTP = smtplib.SMTP_SSL(
|
||||
config.host, config.port, timeout=config.timeout
|
||||
)
|
||||
else:
|
||||
smtp = smtplib.SMTP(config.host, config.port, timeout=config.timeout)
|
||||
|
||||
@@ -144,7 +104,7 @@ class EmailService:
|
||||
template: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
from services.task_queue import TaskProgress, task_queue_service
|
||||
from domain.tasks.task_queue import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(
|
||||
recipients=recipients,
|
||||
@@ -162,16 +122,11 @@ class EmailService:
|
||||
task.id,
|
||||
TaskProgress(stage="queued", percent=0.0, detail="Waiting to send"),
|
||||
)
|
||||
await LogService.action(
|
||||
"email_service",
|
||||
"Email task enqueued",
|
||||
details={"task_id": task.id, "subject": subject, "template": template},
|
||||
)
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
|
||||
from services.task_queue import TaskProgress, task_queue_service
|
||||
from domain.tasks.task_queue import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(**data)
|
||||
|
||||
@@ -194,8 +149,3 @@ class EmailService:
|
||||
task_id,
|
||||
TaskProgress(stage="completed", percent=100.0, detail="Email sent"),
|
||||
)
|
||||
await LogService.info(
|
||||
"email_service",
|
||||
"Email sent",
|
||||
details={"task_id": task_id, "subject": payload.subject},
|
||||
)
|
||||
63
domain/email/types.py
Normal file
63
domain/email/types.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, ValidationError
|
||||
|
||||
|
||||
class EmailSecurity(str, Enum):
|
||||
NONE = "none"
|
||||
SSL = "ssl"
|
||||
STARTTLS = "starttls"
|
||||
|
||||
|
||||
class EmailConfig(BaseModel):
|
||||
host: str
|
||||
port: int = Field(..., gt=0)
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
sender_email: EmailStr
|
||||
sender_name: Optional[str] = None
|
||||
security: EmailSecurity = EmailSecurity.NONE
|
||||
timeout: float = Field(default=30.0, gt=0.0)
|
||||
|
||||
@classmethod
|
||||
def parse_config(cls, raw_config: Any) -> "EmailConfig":
|
||||
"""接受字符串或 dict 配置并解析为 EmailConfig。"""
|
||||
if raw_config is None:
|
||||
raise ValueError("Email configuration not found")
|
||||
|
||||
if isinstance(raw_config, str):
|
||||
raw_config = raw_config.strip()
|
||||
data: Any = json.loads(raw_config) if raw_config else {}
|
||||
elif isinstance(raw_config, dict):
|
||||
data = raw_config
|
||||
else:
|
||||
raise ValueError("Invalid email configuration format")
|
||||
|
||||
try:
|
||||
return cls(**data)
|
||||
except ValidationError as exc:
|
||||
raise ValueError(f"Invalid email configuration: {exc}") from exc
|
||||
|
||||
|
||||
class EmailSendPayload(BaseModel):
|
||||
recipients: List[EmailStr] = Field(..., min_length=1)
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(..., min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTestRequest(BaseModel):
|
||||
to: EmailStr
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(default="test", min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTemplateUpdate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class EmailTemplatePreviewPayload(BaseModel):
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
42
domain/offline_downloads/api.py
Normal file
42
domain/offline_downloads/api.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.offline_downloads.service import OfflineDownloadService
|
||||
from domain.offline_downloads.types import OfflineDownloadCreate
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/offline-downloads",
|
||||
tags=["OfflineDownloads"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建离线下载任务",
|
||||
body_fields=["url", "dest_dir", "filename"],
|
||||
)
|
||||
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
|
||||
data = await OfflineDownloadService.create_download(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取离线下载列表")
|
||||
async def list_offline_downloads(request: Request, current_user: CurrentUser):
|
||||
data = OfflineDownloadService.list_downloads()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
@audit(action=AuditAction.READ, description="获取离线下载详情")
|
||||
async def get_offline_download(task_id: str, request: Request, current_user: CurrentUser):
|
||||
data = OfflineDownloadService.get_download(task_id)
|
||||
return success(data)
|
||||
252
domain/offline_downloads/service.py
Normal file
252
domain/offline_downloads/service.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Annotated, AsyncIterator
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from fastapi import Depends, HTTPException
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.offline_downloads.types import OfflineDownloadCreate
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service
|
||||
|
||||
|
||||
class OfflineDownloadService:
|
||||
current_user_dep = Annotated[User, Depends(get_current_active_user)]
|
||||
temp_root = Path("data/tmp/offline_downloads")
|
||||
|
||||
@classmethod
|
||||
async def create_download(cls, payload: OfflineDownloadCreate, current_user: User) -> dict:
|
||||
await cls._ensure_destination(payload.dest_dir)
|
||||
task = await task_queue_service.add_task(
|
||||
"offline_http_download",
|
||||
{
|
||||
"url": str(payload.url),
|
||||
"dest_dir": payload.dest_dir,
|
||||
"filename": payload.filename,
|
||||
},
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="queued",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="Waiting to start",
|
||||
),
|
||||
)
|
||||
|
||||
return {"task_id": task.id}
|
||||
|
||||
@classmethod
|
||||
def list_downloads(cls) -> list[dict]:
|
||||
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
|
||||
return [t.dict() for t in tasks]
|
||||
|
||||
@classmethod
|
||||
def get_download(cls, task_id: str) -> dict:
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task or task.name != "offline_http_download":
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task.dict()
|
||||
|
||||
@classmethod
|
||||
async def run_http_download(cls, task: Task):
|
||||
params = task.task_info
|
||||
url = params.get("url")
|
||||
dest_dir = params.get("dest_dir")
|
||||
filename = params.get("filename")
|
||||
|
||||
if not url or not dest_dir or not filename:
|
||||
raise ValueError("Missing required parameters for offline download")
|
||||
|
||||
cls.temp_root.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir = cls.temp_root / task.id
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_file = temp_dir / "payload"
|
||||
|
||||
bytes_total: int | None = None
|
||||
bytes_done = 0
|
||||
last_update = time.monotonic()
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="downloading",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="HTTP downloading",
|
||||
),
|
||||
)
|
||||
|
||||
async def report_download(delta: int, total: int | None):
|
||||
nonlocal bytes_done, bytes_total, last_update
|
||||
if total is not None:
|
||||
bytes_total = total
|
||||
bytes_done += delta
|
||||
now = time.monotonic()
|
||||
if delta and now - last_update < 0.5:
|
||||
return
|
||||
last_update = now
|
||||
percent = None
|
||||
total_for_display = bytes_total if bytes_total is not None else None
|
||||
if bytes_total:
|
||||
percent = min(100.0, round(bytes_done / bytes_total * 100, 2))
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="downloading",
|
||||
percent=percent,
|
||||
bytes_total=total_for_display,
|
||||
bytes_done=bytes_done,
|
||||
detail="HTTP downloading",
|
||||
),
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=30)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError(f"HTTP {resp.status} for {url}")
|
||||
content_length = resp.headers.get("Content-Length")
|
||||
total_size = int(content_length) if content_length else None
|
||||
bytes_done = 0
|
||||
async with aiofiles.open(temp_file, "wb") as f:
|
||||
async for chunk in resp.content.iter_chunked(512 * 1024):
|
||||
if not chunk:
|
||||
continue
|
||||
await f.write(chunk)
|
||||
await report_download(len(chunk), total_size)
|
||||
await report_download(0, total_size)
|
||||
|
||||
file_size = os.path.getsize(temp_file)
|
||||
bytes_done_transfer = 0
|
||||
|
||||
async def report_transfer(delta: int):
|
||||
nonlocal bytes_done_transfer
|
||||
bytes_done_transfer += delta
|
||||
percent = min(100.0, round(bytes_done_transfer / file_size * 100, 2)) if file_size else None
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="transferring",
|
||||
percent=percent,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=bytes_done_transfer,
|
||||
detail="Saving to storage",
|
||||
),
|
||||
)
|
||||
|
||||
async def chunk_iter() -> AsyncIterator[bytes]:
|
||||
async for chunk in cls._iter_file(temp_file, 512 * 1024, report_transfer):
|
||||
yield chunk
|
||||
|
||||
final_path, resolved_name = await cls._allocate_destination(dest_dir, filename)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="transferring",
|
||||
percent=0.0,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=0,
|
||||
detail="Saving to storage",
|
||||
),
|
||||
)
|
||||
|
||||
await VirtualFSService.write_file_stream(final_path, chunk_iter())
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="completed",
|
||||
percent=100.0,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=file_size,
|
||||
detail="Completed",
|
||||
),
|
||||
)
|
||||
await task_queue_service.update_meta(task.id, {"final_path": final_path, "filename": resolved_name})
|
||||
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
temp_dir.rmdir()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return final_path
|
||||
|
||||
@classmethod
|
||||
async def _ensure_destination(cls, dest_dir: str) -> None:
|
||||
try:
|
||||
is_dir = await VirtualFSService.path_is_directory(dest_dir)
|
||||
except HTTPException:
|
||||
is_dir = False
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Destination directory not found")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_path(path: str) -> str:
|
||||
if not path:
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
@staticmethod
|
||||
async def _path_exists(full_path: str) -> bool:
|
||||
try:
|
||||
await VirtualFSService.stat_file(full_path)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except HTTPException as exc: # noqa: PERF203
|
||||
if exc.status_code == 404:
|
||||
return False
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def _allocate_destination(cls, dest_dir: str, filename: str) -> tuple[str, str]:
|
||||
dest_dir = cls._normalize_path(dest_dir)
|
||||
stem, suffix = cls._split_filename(filename)
|
||||
candidate = filename
|
||||
base = "" if dest_dir == "/" else dest_dir
|
||||
attempt = 0
|
||||
while await cls._path_exists(f"{base}/{candidate}" if base else f"/{candidate}"):
|
||||
attempt += 1
|
||||
if stem:
|
||||
candidate = f"{stem} ({attempt}){suffix}"
|
||||
else:
|
||||
candidate = f"file ({attempt}){suffix}" if suffix else f"file ({attempt})"
|
||||
full_path = f"{base}/{candidate}" if base else f"/{candidate}"
|
||||
return full_path, candidate
|
||||
|
||||
@staticmethod
|
||||
def _split_filename(filename: str) -> tuple[str, str]:
|
||||
if not filename:
|
||||
return "", ""
|
||||
if filename.startswith(".") and filename.count(".") == 1:
|
||||
return filename, ""
|
||||
if "." not in filename:
|
||||
return filename, ""
|
||||
stem, ext = filename.rsplit(".", 1)
|
||||
return stem, f".{ext}"
|
||||
|
||||
@staticmethod
|
||||
async def _iter_file(path: Path, chunk_size: int, report_cb) -> AsyncIterator[bytes]:
|
||||
async with aiofiles.open(path, "rb") as f:
|
||||
while True:
|
||||
chunk = await f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
await report_cb(len(chunk))
|
||||
yield chunk
|
||||
1
domain/plugins/__init__.py
Normal file
1
domain/plugins/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
66
domain/plugins/api.py
Normal file
66
domain/plugins/api.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Request
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.plugins.service import PluginService
|
||||
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
|
||||
|
||||
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
|
||||
|
||||
|
||||
@router.post("", response_model=PluginOut)
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建插件",
|
||||
body_fields=["url", "enabled"],
|
||||
)
|
||||
async def create_plugin(request: Request, payload: PluginCreate):
|
||||
return await PluginService.create(payload)
|
||||
|
||||
|
||||
@router.get("", response_model=List[PluginOut])
|
||||
@audit(action=AuditAction.READ, description="获取插件列表")
|
||||
async def list_plugins(request: Request):
|
||||
return await PluginService.list_plugins()
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除插件")
|
||||
async def delete_plugin(request: Request, plugin_id: int):
|
||||
await PluginService.delete(plugin_id)
|
||||
return {"code": 0, "msg": "ok"}
|
||||
|
||||
|
||||
@router.put("/{plugin_id}", response_model=PluginOut)
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新插件",
|
||||
body_fields=["url", "enabled"],
|
||||
)
|
||||
async def update_plugin(request: Request, plugin_id: int, payload: PluginCreate):
|
||||
return await PluginService.update(plugin_id, payload)
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/metadata", response_model=PluginOut)
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新插件 manifest",
|
||||
body_fields=[
|
||||
"key",
|
||||
"name",
|
||||
"version",
|
||||
"supported_exts",
|
||||
"default_bounds",
|
||||
"default_maximized",
|
||||
"icon",
|
||||
"description",
|
||||
"author",
|
||||
"website",
|
||||
"github",
|
||||
],
|
||||
)
|
||||
async def update_manifest(
|
||||
request: Request, plugin_id: int, manifest: PluginManifestUpdate = Body(...)
|
||||
):
|
||||
return await PluginService.update_manifest(plugin_id, manifest)
|
||||
48
domain/plugins/service.py
Normal file
48
domain/plugins/service.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
|
||||
from models.database import Plugin
|
||||
|
||||
|
||||
class PluginService:
|
||||
@classmethod
|
||||
async def create(cls, payload: PluginCreate) -> PluginOut:
|
||||
rec = await Plugin.create(**payload.model_dump())
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def list_plugins(cls) -> list[PluginOut]:
|
||||
rows = await Plugin.all().order_by("-id")
|
||||
return [PluginOut.model_validate(r) for r in rows]
|
||||
|
||||
@classmethod
|
||||
async def _get_or_404(cls, plugin_id: int) -> Plugin:
|
||||
rec = await Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
return rec
|
||||
|
||||
@classmethod
|
||||
async def delete(cls, plugin_id: int) -> None:
|
||||
rec = await cls._get_or_404(plugin_id)
|
||||
await rec.delete()
|
||||
|
||||
@classmethod
|
||||
async def update(cls, plugin_id: int, payload: PluginCreate) -> PluginOut:
|
||||
rec = await cls._get_or_404(plugin_id)
|
||||
rec.url = payload.url
|
||||
rec.enabled = payload.enabled
|
||||
await rec.save()
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def update_manifest(
|
||||
cls, plugin_id: int, manifest: PluginManifestUpdate
|
||||
) -> PluginOut:
|
||||
rec = await cls._get_or_404(plugin_id)
|
||||
updates = manifest.model_dump(exclude_none=True)
|
||||
if updates:
|
||||
for key, value in updates.items():
|
||||
setattr(rec, key, value)
|
||||
await rec.save()
|
||||
return PluginOut.model_validate(rec)
|
||||
52
domain/plugins/types.py
Normal file
52
domain/plugins/types.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PluginCreate(BaseModel):
|
||||
url: str = Field(min_length=1)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class PluginManifestUpdate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
key: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
supported_exts: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("supported_exts", "supportedExts"),
|
||||
)
|
||||
default_bounds: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("default_bounds", "defaultBounds"),
|
||||
)
|
||||
default_maximized: Optional[bool] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("default_maximized", "defaultMaximized"),
|
||||
)
|
||||
icon: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
github: Optional[str] = None
|
||||
|
||||
|
||||
class PluginOut(BaseModel):
|
||||
id: int
|
||||
url: str
|
||||
enabled: bool
|
||||
key: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
supported_exts: Optional[List[str]] = None
|
||||
default_bounds: Optional[Dict[str, Any]] = None
|
||||
default_maximized: Optional[bool] = None
|
||||
icon: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
github: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
89
domain/processors/api.py
Normal file
89
domain/processors/api.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.processors.service import ProcessorService
|
||||
from domain.processors.types import (
|
||||
ProcessDirectoryRequest,
|
||||
ProcessRequest,
|
||||
UpdateSourceRequest,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/processors", tags=["processors"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
@audit(action=AuditAction.READ, description="获取处理器列表")
|
||||
async def list_processors(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = ProcessorService.list_processors()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/process")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="处理单个文件",
|
||||
body_fields=["path", "processor_type", "save_to", "overwrite"],
|
||||
)
|
||||
async def process_file_with_processor(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessRequest = Body(...),
|
||||
):
|
||||
data = await ProcessorService.process_file(req)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/process-directory")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="批量处理目录",
|
||||
body_fields=["path", "processor_type", "overwrite", "max_depth", "suffix"],
|
||||
)
|
||||
async def process_directory_with_processor(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessDirectoryRequest = Body(...),
|
||||
):
|
||||
data = await ProcessorService.process_directory(req)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/source/{processor_type}")
|
||||
@audit(action=AuditAction.READ, description="获取处理器源码")
|
||||
async def get_processor_source(
|
||||
request: Request,
|
||||
processor_type: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = await ProcessorService.get_source(processor_type)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.put("/source/{processor_type}")
|
||||
@audit(action=AuditAction.UPDATE, description="更新处理器源码")
|
||||
async def update_processor_source(
|
||||
request: Request,
|
||||
processor_type: str,
|
||||
req: UpdateSourceRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = await ProcessorService.update_source(processor_type, req)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/reload")
|
||||
@audit(action=AuditAction.UPDATE, description="重载处理器模块")
|
||||
async def reload_processor_modules(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = ProcessorService.reload()
|
||||
return success(data)
|
||||
1
domain/processors/builtin/__init__.py
Normal file
1
domain/processors/builtin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# 内置处理器包
|
||||
@@ -1,9 +1,11 @@
|
||||
from .base import BaseProcessor
|
||||
from typing import Dict, Any
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from fastapi.responses import Response
|
||||
from services.logging import LogService
|
||||
|
||||
from ..base import BaseProcessor
|
||||
|
||||
|
||||
class ImageWatermarkProcessor:
|
||||
name = "图片水印"
|
||||
@@ -26,10 +28,11 @@ class ImageWatermarkProcessor:
|
||||
]
|
||||
produces_file = True
|
||||
|
||||
async def process(self, input_bytes: bytes,path: str, config: Dict[str, Any]) -> Response:
|
||||
async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
|
||||
text = config.get("text", "")
|
||||
position = config.get("position", "bottom-right")
|
||||
font_size = int(config.get("font_size", 24))
|
||||
|
||||
img = Image.open(BytesIO(input_bytes)).convert("RGBA")
|
||||
watermark = Image.new("RGBA", img.size)
|
||||
draw = ImageDraw.Draw(watermark)
|
||||
@@ -37,29 +40,29 @@ class ImageWatermarkProcessor:
|
||||
font = ImageFont.truetype("arial.ttf", font_size)
|
||||
except Exception:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
w, h = img.size
|
||||
try:
|
||||
text_w, text_h = font.getsize(text)
|
||||
except AttributeError:
|
||||
bbox = draw.textbbox((0, 0), text, font=font)
|
||||
text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
|
||||
if position == "bottom-right":
|
||||
xy = (w - text_w - 10, h - text_h - 10)
|
||||
elif position == "top-left":
|
||||
xy = (10, 10)
|
||||
else:
|
||||
xy = (w // 2 - text_w // 2, h // 2 - text_h // 2)
|
||||
|
||||
draw.text(xy, text, font=font, fill=(255, 255, 255, 128))
|
||||
out = Image.alpha_composite(img, watermark)
|
||||
buf = BytesIO()
|
||||
out.convert("RGB").save(buf, format="JPEG")
|
||||
await LogService.info(
|
||||
"processor:image_watermark",
|
||||
f"Watermarked image {path}",
|
||||
details={"path": path, "config": config},
|
||||
)
|
||||
|
||||
return Response(content=buf.getvalue(), media_type="image/jpeg")
|
||||
|
||||
|
||||
PROCESSOR_TYPE = "image_watermark"
|
||||
PROCESSOR_NAME = ImageWatermarkProcessor.name
|
||||
SUPPORTED_EXTS = ImageWatermarkProcessor.supported_exts
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from fastapi.responses import Response
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
from services.ai import describe_image_base64, get_text_embedding, provider_service
|
||||
from services.vector_db import VectorDBService, DEFAULT_VECTOR_DIMENSION
|
||||
from services.logging import LogService
|
||||
from fastapi.responses import Response
|
||||
from PIL import Image
|
||||
|
||||
from ..base import BaseProcessor
|
||||
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
|
||||
from domain.ai.service import VectorDBService, DEFAULT_VECTOR_DIMENSION
|
||||
|
||||
|
||||
CHUNK_SIZE = 800
|
||||
@@ -116,11 +116,6 @@ class VectorIndexProcessor:
|
||||
|
||||
if action == "destroy":
|
||||
await vector_db.delete_vector(collection_name, path)
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Destroyed {index_type} index for {path}",
|
||||
details={"path": path, "action": "destroy", "index_type": index_type},
|
||||
)
|
||||
return Response(content=f"文件 {path} 的 {index_type} 索引已销毁", media_type="text/plain")
|
||||
|
||||
mime_type = _guess_mime(path)
|
||||
@@ -136,11 +131,6 @@ class VectorIndexProcessor:
|
||||
"type": "filename",
|
||||
"name": os.path.basename(path),
|
||||
})
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Created simple index for {path}",
|
||||
details={"path": path, "action": "create", "index_type": "simple"},
|
||||
)
|
||||
return Response(content=f"文件 {path} 的普通索引已创建", media_type="text/plain")
|
||||
|
||||
file_ext = path.split('.')[-1].lower()
|
||||
@@ -177,11 +167,6 @@ class VectorIndexProcessor:
|
||||
details["description"] = description
|
||||
if compression:
|
||||
details["image_compression"] = compression
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Indexed image {path}",
|
||||
details=details,
|
||||
)
|
||||
return Response(content=f"图片已索引,描述:{description}", media_type="text/plain")
|
||||
|
||||
if file_ext in ["txt", "md"]:
|
||||
@@ -204,11 +189,6 @@ class VectorIndexProcessor:
|
||||
"end_offset": len(text),
|
||||
})
|
||||
details["chunks"] = 1
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Indexed text file {path}",
|
||||
details=details,
|
||||
)
|
||||
return Response(content="文本文件已索引", media_type="text/plain")
|
||||
|
||||
chunk_count = 0
|
||||
@@ -230,11 +210,6 @@ class VectorIndexProcessor:
|
||||
details["chunks"] = chunk_count
|
||||
sample = chunks[0][1]
|
||||
details["sample"] = sample[:120]
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Indexed text file {path}",
|
||||
details=details,
|
||||
)
|
||||
return Response(content="文本文件已索引", media_type="text/plain")
|
||||
|
||||
# 其他类型暂未支持向量索引,回退为文件名索引
|
||||
@@ -248,11 +223,6 @@ class VectorIndexProcessor:
|
||||
"name": os.path.basename(path),
|
||||
"embedding": [0.0] * vector_dim,
|
||||
})
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"File type fallback to simple index for {path}",
|
||||
details={"path": path, "action": "create", "index_type": "simple", "original_type": file_ext},
|
||||
)
|
||||
return Response(content="暂不支持该类型的向量索引,已创建文件名索引", media_type="text/plain")
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from .base import BaseProcessor
|
||||
from domain.processors.base import BaseProcessor
|
||||
|
||||
ProcessorFactory = Callable[[], BaseProcessor]
|
||||
TYPE_MAP: Dict[str, ProcessorFactory] = {}
|
||||
@@ -15,10 +15,9 @@ LAST_DISCOVERY_ERRORS: list[str] = []
|
||||
|
||||
|
||||
def discover_processors(force_reload: bool = False) -> list[str]:
|
||||
"""Discover available processor modules and cache their metadata."""
|
||||
import services.processors # 延迟导入以避免循环
|
||||
"""扫描并缓存可用的处理器模块。"""
|
||||
from domain.processors import builtin as processors_pkg
|
||||
|
||||
processors_pkg = services.processors
|
||||
TYPE_MAP.clear()
|
||||
CONFIG_SCHEMAS.clear()
|
||||
MODULE_MAP.clear()
|
||||
@@ -51,8 +50,10 @@ def discover_processors(force_reload: bool = False) -> list[str]:
|
||||
if factory is None:
|
||||
for attr in module.__dict__.values():
|
||||
if inspect.isclass(attr) and attr.__name__.endswith("Processor"):
|
||||
|
||||
def _mk(cls=attr):
|
||||
return lambda: cls()
|
||||
|
||||
factory = _mk()
|
||||
break
|
||||
|
||||
@@ -114,7 +115,7 @@ def get_config_schema(processor_type: str):
|
||||
return CONFIG_SCHEMAS.get(processor_type)
|
||||
|
||||
|
||||
def get(processor_type: str) -> BaseProcessor:
|
||||
def get(processor_type: str) -> BaseProcessor | None:
|
||||
factory = TYPE_MAP.get(processor_type)
|
||||
if factory:
|
||||
return factory()
|
||||
217
domain/processors/service.py
Normal file
217
domain/processors/service.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from domain.processors.registry import (
|
||||
get,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
get_module_path,
|
||||
reload_processors,
|
||||
)
|
||||
from domain.processors.types import (
|
||||
ProcessDirectoryRequest,
|
||||
ProcessRequest,
|
||||
UpdateSourceRequest,
|
||||
)
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.tasks.task_queue import task_queue_service
|
||||
|
||||
|
||||
class ProcessorService:
|
||||
@classmethod
|
||||
def get_processor(cls, processor_type: str):
|
||||
return get(processor_type)
|
||||
|
||||
@classmethod
|
||||
def list_processors(cls):
|
||||
schemas = get_config_schemas()
|
||||
out = []
|
||||
for t, meta in schemas.items():
|
||||
out.append({
|
||||
"type": meta["type"],
|
||||
"name": meta["name"],
|
||||
"supported_exts": meta.get("supported_exts", []),
|
||||
"config_schema": meta["config_schema"],
|
||||
"produces_file": meta.get("produces_file", False),
|
||||
"module_path": meta.get("module_path"),
|
||||
})
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
async def process_file(cls, req: ProcessRequest):
|
||||
is_dir = await VirtualFSService.path_is_directory(req.path)
|
||||
if is_dir and not req.overwrite:
|
||||
raise HTTPException(400, detail="Directory processing requires overwrite")
|
||||
|
||||
save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
|
||||
task = await task_queue_service.add_task(
|
||||
"process_file",
|
||||
{
|
||||
"path": req.path,
|
||||
"processor_type": req.processor_type,
|
||||
"config": req.config,
|
||||
"save_to": save_to,
|
||||
"overwrite": req.overwrite,
|
||||
},
|
||||
)
|
||||
return {"task_id": task.id}
|
||||
|
||||
@classmethod
|
||||
async def process_directory(cls, req: ProcessDirectoryRequest):
|
||||
if req.max_depth is not None and req.max_depth < 0:
|
||||
raise HTTPException(400, detail="max_depth must be >= 0")
|
||||
|
||||
is_dir = await VirtualFSService.path_is_directory(req.path)
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Path must be a directory")
|
||||
|
||||
schema = get_config_schema(req.processor_type)
|
||||
_processor = get(req.processor_type)
|
||||
if not schema or not _processor:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
|
||||
produces_file = bool(schema.get("produces_file"))
|
||||
raw_suffix = req.suffix if req.suffix is not None else None
|
||||
if raw_suffix is not None and raw_suffix.strip() == "":
|
||||
raw_suffix = None
|
||||
suffix = raw_suffix
|
||||
overwrite = req.overwrite
|
||||
|
||||
if produces_file:
|
||||
if not overwrite and not suffix:
|
||||
raise HTTPException(400, detail="Suffix is required when not overwriting files")
|
||||
else:
|
||||
overwrite = False
|
||||
suffix = None
|
||||
|
||||
supported_exts = schema.get("supported_exts") or []
|
||||
allowed_exts = {
|
||||
ext.lower().lstrip('.')
|
||||
for ext in supported_exts
|
||||
if isinstance(ext, str)
|
||||
}
|
||||
|
||||
def matches_extension(file_rel: str) -> bool:
|
||||
if not allowed_exts:
|
||||
return True
|
||||
if '.' not in file_rel:
|
||||
return '' in allowed_exts
|
||||
ext = file_rel.rsplit('.', 1)[-1].lower()
|
||||
return ext in allowed_exts or f'.{ext}' in allowed_exts
|
||||
|
||||
adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(req.path)
|
||||
rel = rel.rstrip('/')
|
||||
|
||||
list_dir = getattr(adapter_instance, "list_dir", None)
|
||||
if not callable(list_dir):
|
||||
raise HTTPException(501, detail="Adapter does not implement list_dir")
|
||||
|
||||
def build_absolute_path(mount_path: str, rel_path: str) -> str:
|
||||
rel_norm = rel_path.lstrip('/')
|
||||
mount_norm = mount_path.rstrip('/')
|
||||
if not mount_norm:
|
||||
return '/' + rel_norm if rel_norm else '/'
|
||||
return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm
|
||||
|
||||
def apply_suffix(path_str: str, suffix_str: str) -> str:
|
||||
path_obj = Path(path_str)
|
||||
name = path_obj.name
|
||||
if not name:
|
||||
return path_str
|
||||
if '.' in name:
|
||||
base, ext = name.rsplit('.', 1)
|
||||
new_name = f"{base}{suffix_str}.{ext}"
|
||||
else:
|
||||
new_name = f"{name}{suffix_str}"
|
||||
return str(path_obj.with_name(new_name))
|
||||
|
||||
scheduled_tasks: List[str] = []
|
||||
stack: List[Tuple[str, int]] = [(rel, 0)]
|
||||
page_size = 200
|
||||
|
||||
while stack:
|
||||
current_rel, depth = stack.pop()
|
||||
page = 1
|
||||
while True:
|
||||
entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
|
||||
entries = entries or []
|
||||
if not entries and (total or 0) == 0:
|
||||
break
|
||||
|
||||
for entry in entries:
|
||||
name = entry.get("name")
|
||||
if not name:
|
||||
continue
|
||||
child_rel = f"{current_rel}/{name}" if current_rel else name
|
||||
if entry.get("is_dir"):
|
||||
if req.max_depth is None or depth < req.max_depth:
|
||||
stack.append((child_rel.rstrip('/'), depth + 1))
|
||||
continue
|
||||
if not matches_extension(child_rel):
|
||||
continue
|
||||
absolute_path = build_absolute_path(adapter_model.path, child_rel)
|
||||
save_to = None
|
||||
if produces_file and not overwrite and suffix:
|
||||
save_to = apply_suffix(absolute_path, suffix)
|
||||
task = await task_queue_service.add_task(
|
||||
"process_file",
|
||||
{
|
||||
"path": absolute_path,
|
||||
"processor_type": req.processor_type,
|
||||
"config": req.config,
|
||||
"save_to": save_to,
|
||||
"overwrite": overwrite,
|
||||
},
|
||||
)
|
||||
scheduled_tasks.append(task.id)
|
||||
|
||||
if total is None or page * page_size >= total:
|
||||
break
|
||||
page += 1
|
||||
|
||||
return {
|
||||
"task_ids": scheduled_tasks,
|
||||
"scheduled": len(scheduled_tasks),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def get_source(cls, processor_type: str):
|
||||
module_path = get_module_path(processor_type)
|
||||
if not module_path:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
path_obj = Path(module_path)
|
||||
if not path_obj.exists():
|
||||
raise HTTPException(404, detail="Processor source not found")
|
||||
try:
|
||||
content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"Failed to read source: {exc}")
|
||||
return {"source": content, "module_path": str(path_obj)}
|
||||
|
||||
@classmethod
|
||||
async def update_source(cls, processor_type: str, req: UpdateSourceRequest):
|
||||
module_path = get_module_path(processor_type)
|
||||
if not module_path:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
path_obj = Path(module_path)
|
||||
if not path_obj.exists():
|
||||
raise HTTPException(404, detail="Processor source not found")
|
||||
try:
|
||||
await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"Failed to write source: {exc}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def reload(cls):
|
||||
errors = reload_processors()
|
||||
if errors:
|
||||
raise HTTPException(500, detail="; ".join(errors))
|
||||
return True
|
||||
|
||||
|
||||
get_processor = ProcessorService.get_processor
|
||||
list_processors = ProcessorService.list_processors
|
||||
reload_processor_modules = ProcessorService.reload
|
||||
24
domain/processors/types.py
Normal file
24
domain/processors/types.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProcessRequest(BaseModel):
|
||||
path: str
|
||||
processor_type: str
|
||||
config: Dict[str, Any]
|
||||
save_to: Optional[str] = None
|
||||
overwrite: bool = False
|
||||
|
||||
|
||||
class ProcessDirectoryRequest(BaseModel):
|
||||
path: str
|
||||
processor_type: str
|
||||
config: Dict[str, Any]
|
||||
overwrite: bool = True
|
||||
max_depth: Optional[int] = None
|
||||
suffix: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateSourceRequest(BaseModel):
|
||||
source: str
|
||||
129
domain/share/api.py
Normal file
129
domain/share/api.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.share.service import ShareService
|
||||
from domain.share.types import (
|
||||
ShareCreate,
|
||||
ShareInfo,
|
||||
ShareInfoWithPassword,
|
||||
SharePassword,
|
||||
)
|
||||
from models.database import UserAccount
|
||||
|
||||
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
|
||||
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
|
||||
|
||||
|
||||
@router.post("", response_model=ShareInfoWithPassword)
|
||||
@audit(
|
||||
action=AuditAction.SHARE,
|
||||
description="创建分享链接",
|
||||
body_fields=["name", "paths", "expires_in_days", "access_type"],
|
||||
)
|
||||
async def create_share(
|
||||
request: Request,
|
||||
payload: ShareCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
share = await ShareService.create_share_link(
|
||||
user=user_account,
|
||||
name=payload.name,
|
||||
paths=payload.paths,
|
||||
expires_in_days=payload.expires_in_days,
|
||||
access_type=payload.access_type,
|
||||
password=payload.password,
|
||||
)
|
||||
share_info = ShareInfo.from_orm(share).model_dump()
|
||||
if payload.access_type == "password" and payload.password:
|
||||
share_info["password"] = payload.password
|
||||
return share_info
|
||||
|
||||
|
||||
@router.get("", response_model=List[ShareInfo])
|
||||
@audit(action=AuditAction.READ, description="获取我的分享列表")
|
||||
async def get_my_shares(
|
||||
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
shares = await ShareService.get_user_shares(user=user_account)
|
||||
return [ShareInfo.from_orm(s) for s in shares]
|
||||
|
||||
|
||||
@router.delete("/expired")
|
||||
@audit(action=AuditAction.DELETE, description="删除已过期分享")
|
||||
async def delete_expired_shares(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
deleted_count = await ShareService.delete_expired_shares(user=user_account)
|
||||
return success({"deleted_count": deleted_count})
|
||||
|
||||
|
||||
@router.delete("/{share_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除分享链接")
|
||||
async def delete_share(
|
||||
share_id: int,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
await ShareService.delete_share_link(user=user_account, share_id=share_id)
|
||||
return success(msg="分享已取消")
|
||||
|
||||
|
||||
@public_router.post("/{token}/verify")
|
||||
@audit(
|
||||
action=AuditAction.SHARE,
|
||||
description="校验分享密码",
|
||||
body_fields=["password"],
|
||||
redact_fields=["password"],
|
||||
)
|
||||
async def verify_password(request: Request, token: str, payload: SharePassword):
|
||||
await ShareService.verify_share_password(token, payload.password)
|
||||
return success(msg="验证成功")
|
||||
|
||||
|
||||
@public_router.get("/{token}/ls")
|
||||
@audit(action=AuditAction.SHARE, description="浏览分享内容")
|
||||
async def list_share_content(
|
||||
request: Request, token: str, path: str = "/", password: Optional[str] = None
|
||||
):
|
||||
share = await ShareService.ensure_share_access(token, password)
|
||||
content = await ShareService.get_shared_item_details(share, path)
|
||||
return success(
|
||||
{
|
||||
"path": path,
|
||||
"entries": content.get("items", []),
|
||||
"pagination": {
|
||||
"total": content.get("total", 0),
|
||||
"page": content.get("page", 1),
|
||||
"page_size": content.get("page_size", 1),
|
||||
"pages": content.get("pages", 1),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@public_router.get("/{token}")
|
||||
@audit(action=AuditAction.SHARE, description="获取分享信息")
|
||||
async def get_share_info(request: Request, token: str):
|
||||
share = await ShareService.get_share_by_token(token)
|
||||
return success(ShareInfo.from_orm(share))
|
||||
|
||||
|
||||
@public_router.get("/{token}/download")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="下载分享文件")
|
||||
async def download_shared_file(
|
||||
token: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
password: Optional[str] = None,
|
||||
):
|
||||
return await ShareService.stream_shared_file(token, path, request.headers.get("Range"), password)
|
||||
187
domain/share/service.py
Normal file
187
domain/share/service.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import bcrypt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import Response
|
||||
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from models.database import ShareLink, UserAccount
|
||||
|
||||
|
||||
class ShareService:
|
||||
@classmethod
|
||||
def _hash_password(cls, password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _verify_password(cls, plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
@classmethod
|
||||
def _calc_expires_at(cls, expires_in_days: Optional[int]) -> Optional[datetime]:
|
||||
if expires_in_days is None or expires_in_days <= 0:
|
||||
return None
|
||||
return datetime.now(timezone.utc) + timedelta(days=expires_in_days)
|
||||
|
||||
@classmethod
|
||||
def _ensure_password_if_needed(cls, share: ShareLink, password: Optional[str]) -> None:
|
||||
if share.access_type != "password":
|
||||
return
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share.hashed_password:
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
if not cls._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
@classmethod
|
||||
async def create_share_link(
|
||||
cls,
|
||||
user: UserAccount,
|
||||
name: str,
|
||||
paths: List[str],
|
||||
expires_in_days: Optional[int] = 7,
|
||||
access_type: str = "public",
|
||||
password: Optional[str] = None,
|
||||
) -> ShareLink:
|
||||
if not paths:
|
||||
raise HTTPException(status_code=400, detail="分享路径不能为空")
|
||||
|
||||
if access_type == "password" and not password:
|
||||
raise HTTPException(status_code=400, detail="密码不能为空")
|
||||
|
||||
token = secrets.token_urlsafe(16)
|
||||
expires_at = cls._calc_expires_at(expires_in_days)
|
||||
|
||||
hashed_password = None
|
||||
if access_type == "password" and password:
|
||||
hashed_password = cls._hash_password(password)
|
||||
|
||||
share = await ShareLink.create(
|
||||
token=token,
|
||||
name=name,
|
||||
paths=paths,
|
||||
user=user,
|
||||
expires_at=expires_at,
|
||||
access_type=access_type,
|
||||
hashed_password=hashed_password,
|
||||
)
|
||||
return share
|
||||
|
||||
@classmethod
|
||||
async def get_share_by_token(cls, token: str) -> ShareLink:
|
||||
share = await ShareLink.get_or_none(token=token).prefetch_related("user")
|
||||
if not share:
|
||||
raise HTTPException(status_code=404, detail="分享链接不存在")
|
||||
|
||||
if share.expires_at and share.expires_at < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=410, detail="分享链接已过期")
|
||||
|
||||
return share
|
||||
|
||||
@classmethod
|
||||
async def verify_share_password(cls, token: str, password: str) -> ShareLink:
|
||||
share = await cls.get_share_by_token(token)
|
||||
if share.access_type != "password":
|
||||
raise HTTPException(status_code=400, detail="此分享不需要密码")
|
||||
cls._ensure_password_if_needed(share, password)
|
||||
return share
|
||||
|
||||
@classmethod
|
||||
async def ensure_share_access(cls, token: str, password: Optional[str]) -> ShareLink:
|
||||
share = await cls.get_share_by_token(token)
|
||||
cls._ensure_password_if_needed(share, password)
|
||||
return share
|
||||
|
||||
@classmethod
|
||||
async def get_user_shares(cls, user: UserAccount) -> List[ShareLink]:
|
||||
return await ShareLink.filter(user=user).order_by("-created_at")
|
||||
|
||||
@classmethod
|
||||
async def delete_share_link(cls, user: UserAccount, share_id: int) -> None:
|
||||
share = await ShareLink.get_or_none(id=share_id, user_id=user.id)
|
||||
if not share:
|
||||
raise HTTPException(status_code=404, detail="分享链接不存在")
|
||||
await share.delete()
|
||||
|
||||
@classmethod
|
||||
async def delete_expired_shares(cls, user: UserAccount) -> int:
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted_count = await ShareLink.filter(user=user, expires_at__lte=now).delete()
|
||||
return deleted_count
|
||||
|
||||
@classmethod
|
||||
async def get_shared_item_details(cls, share: ShareLink, sub_path: str = ""):
|
||||
if not share.paths:
|
||||
raise HTTPException(status_code=404, detail="分享内容为空")
|
||||
|
||||
base_shared_path = share.paths[0]
|
||||
|
||||
if sub_path and sub_path != "/":
|
||||
full_path = f"{base_shared_path.rstrip('/')}/{sub_path.lstrip('/')}".rstrip("/")
|
||||
if not full_path.startswith(base_shared_path):
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
try:
|
||||
return await VirtualFSService.list_virtual_dir(full_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="目录未找到")
|
||||
|
||||
try:
|
||||
stat = await VirtualFSService.stat_file(base_shared_path)
|
||||
if stat.get("is_dir"):
|
||||
return await VirtualFSService.list_virtual_dir(base_shared_path)
|
||||
|
||||
stat["name"] = base_shared_path.split("/")[-1]
|
||||
return {"items": [stat], "total": 1, "page": 1, "page_size": 1, "pages": 1}
|
||||
except HTTPException as e:
|
||||
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
|
||||
return await VirtualFSService.list_virtual_dir(base_shared_path)
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
async def stream_shared_file(
|
||||
cls,
|
||||
token: str,
|
||||
path: str,
|
||||
range_header: str | None,
|
||||
password: Optional[str] = None,
|
||||
) -> Response:
|
||||
if not path or path == "/" or ".." in path.split("/"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无效的文件路径")
|
||||
|
||||
share = await cls.ensure_share_access(token, password)
|
||||
if not share.paths:
|
||||
raise HTTPException(status_code=404, detail="分享的源文件不存在")
|
||||
base_shared_path = share.paths[0]
|
||||
|
||||
is_dir = False
|
||||
try:
|
||||
stat = await VirtualFSService.stat_file(base_shared_path)
|
||||
if stat and stat.get("is_dir"):
|
||||
is_dir = True
|
||||
except HTTPException as e:
|
||||
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
|
||||
is_dir = True
|
||||
elif e.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="分享的源文件不存在")
|
||||
else:
|
||||
raise
|
||||
|
||||
if is_dir:
|
||||
full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
|
||||
if not full_virtual_path.startswith(base_shared_path):
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
else:
|
||||
shared_filename = base_shared_path.split("/")[-1]
|
||||
request_filename = path.lstrip("/")
|
||||
if shared_filename != request_filename:
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
full_virtual_path = base_shared_path
|
||||
|
||||
response = await VirtualFSService.stream_file(full_virtual_path, range_header)
|
||||
filename = full_virtual_path.split("/")[-1]
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
|
||||
return response
|
||||
43
domain/share/types.py
Normal file
43
domain/share/types.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.database import ShareLink
|
||||
|
||||
|
||||
class ShareCreate(BaseModel):
|
||||
name: str
|
||||
paths: List[str]
|
||||
expires_in_days: Optional[int] = 7
|
||||
access_type: str = "public"
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
class SharePassword(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class ShareInfo(BaseModel):
|
||||
id: int
|
||||
token: str
|
||||
name: str
|
||||
paths: List[str]
|
||||
created_at: str
|
||||
expires_at: Optional[str] = None
|
||||
access_type: str
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: ShareLink):
|
||||
return cls(
|
||||
id=obj.id,
|
||||
token=obj.token,
|
||||
name=obj.name,
|
||||
paths=obj.paths,
|
||||
created_at=obj.created_at.isoformat(),
|
||||
expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
|
||||
access_type=obj.access_type,
|
||||
)
|
||||
|
||||
|
||||
class ShareInfoWithPassword(ShareInfo):
|
||||
password: Optional[str] = None
|
||||
112
domain/tasks/api.py
Normal file
112
domain/tasks/api.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.tasks.service import TaskService
|
||||
from domain.tasks.types import (
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
)
|
||||
|
||||
CurrentUser = TaskService.current_user_dep
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/tasks",
|
||||
tags=["Tasks"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/queue")
|
||||
@audit(action=AuditAction.READ, description="获取任务队列状态")
|
||||
async def get_task_queue_status(request: Request, current_user: CurrentUser):
|
||||
payload = TaskService.get_queue_tasks()
|
||||
return success(payload)
|
||||
|
||||
|
||||
@router.get("/queue/settings")
|
||||
@audit(action=AuditAction.READ, description="获取任务队列设置")
|
||||
async def get_task_queue_settings(request: Request, current_user: CurrentUser):
|
||||
payload = TaskService.get_queue_settings()
|
||||
return success(payload.model_dump())
|
||||
|
||||
|
||||
@router.post("/queue/settings")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新任务队列设置",
|
||||
body_fields=["concurrency"],
|
||||
)
|
||||
async def update_task_queue_settings(request: Request, settings: TaskQueueSettings, current_user: CurrentUser):
|
||||
payload = await TaskService.update_queue_settings(settings, getattr(current_user, "id", None))
|
||||
return success(payload.model_dump())
|
||||
|
||||
|
||||
@router.get("/queue/{task_id}")
|
||||
@audit(action=AuditAction.READ, description="获取队列任务状态")
|
||||
async def get_task_status(task_id: str, request: Request, current_user: CurrentUser):
|
||||
payload = TaskService.get_queue_task(task_id)
|
||||
return success(payload)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建自动化任务",
|
||||
body_fields=[
|
||||
"name",
|
||||
"event",
|
||||
"path_pattern",
|
||||
"filename_regex",
|
||||
"processor_type",
|
||||
"processor_config",
|
||||
"enabled",
|
||||
],
|
||||
user_kw="user",
|
||||
)
|
||||
async def create_task(request: Request, task_in: AutomationTaskCreate, user: CurrentUser):
|
||||
task = await TaskService.create_task(task_in, user)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
@audit(action=AuditAction.READ, description="获取自动化任务详情")
|
||||
async def get_task(task_id: int, request: Request, current_user: CurrentUser):
|
||||
task = await TaskService.get_task(task_id)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取自动化任务列表")
|
||||
async def list_tasks(request: Request, current_user: CurrentUser):
|
||||
tasks = await TaskService.list_tasks()
|
||||
return success(tasks)
|
||||
|
||||
|
||||
@router.put("/{task_id}")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新自动化任务",
|
||||
body_fields=[
|
||||
"name",
|
||||
"event",
|
||||
"path_pattern",
|
||||
"filename_regex",
|
||||
"processor_type",
|
||||
"processor_config",
|
||||
"enabled",
|
||||
],
|
||||
)
|
||||
async def update_task(request: Request, current_user: CurrentUser, task_id: int, task_in: AutomationTaskUpdate):
|
||||
task = await TaskService.update_task(task_id, task_in, current_user)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.delete("/{task_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除自动化任务", user_kw="user")
|
||||
async def delete_task(task_id: int, request: Request, user: CurrentUser):
|
||||
await TaskService.delete_task(task_id, user)
|
||||
return success(msg="Task deleted")
|
||||
109
domain/tasks/service.py
Normal file
109
domain/tasks/service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import re
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.config.service import ConfigService
|
||||
from domain.tasks.types import (
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
TaskQueueSettingsResponse,
|
||||
)
|
||||
from models.database import AutomationTask
|
||||
from domain.tasks.task_queue import task_queue_service
|
||||
|
||||
|
||||
class TaskService:
|
||||
current_user_dep = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
@classmethod
|
||||
def get_queue_tasks(cls) -> list[dict[str, Any]]:
|
||||
tasks = task_queue_service.get_all_tasks()
|
||||
return [task.dict() for task in tasks]
|
||||
|
||||
@classmethod
|
||||
def get_queue_settings(cls) -> TaskQueueSettingsResponse:
|
||||
return TaskQueueSettingsResponse(
|
||||
concurrency=task_queue_service.get_concurrency(),
|
||||
active_workers=task_queue_service.get_active_worker_count(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update_queue_settings(cls, settings: TaskQueueSettings, user_id: Optional[int]) -> TaskQueueSettingsResponse:
|
||||
await task_queue_service.set_concurrency(settings.concurrency)
|
||||
await ConfigService.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
|
||||
return cls.get_queue_settings()
|
||||
|
||||
@classmethod
|
||||
def get_queue_task(cls, task_id: str) -> dict[str, Any]:
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task.dict()
|
||||
|
||||
@classmethod
|
||||
async def create_task(cls, payload: AutomationTaskCreate, user: Optional[User]) -> AutomationTask:
|
||||
task = await AutomationTask.create(**payload.model_dump())
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
async def get_task(cls, task_id: int) -> AutomationTask:
|
||||
task = await AutomationTask.get_or_none(id=task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
async def list_tasks(cls) -> list[AutomationTask]:
|
||||
tasks = await AutomationTask.all()
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
async def update_task(cls, task_id: int, payload: AutomationTaskUpdate, current_user: User) -> AutomationTask:
|
||||
task = await AutomationTask.get_or_none(id=task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(task, key, value)
|
||||
await task.save()
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
async def delete_task(cls, task_id: int, user: Optional[User]) -> None:
|
||||
deleted_count = await AutomationTask.filter(id=task_id).delete()
|
||||
if not deleted_count:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
|
||||
@classmethod
|
||||
async def trigger_tasks(cls, event: str, path: str):
|
||||
tasks = await AutomationTask.filter(event=event, enabled=True)
|
||||
for task in tasks:
|
||||
if cls.match(task, path):
|
||||
await cls.execute(task, path)
|
||||
|
||||
@classmethod
|
||||
def match(cls, task: AutomationTask, path: str) -> bool:
|
||||
if task.path_pattern and not path.startswith(task.path_pattern):
|
||||
return False
|
||||
if task.filename_regex:
|
||||
filename = path.split("/")[-1]
|
||||
if not re.match(task.filename_regex, filename):
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, task: AutomationTask, path: str):
|
||||
await task_queue_service.add_task(
|
||||
task.processor_type,
|
||||
{
|
||||
"task_id": task.id,
|
||||
"path": path,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
task_service = TaskService
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid
|
||||
from services.logging import LogService
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -47,7 +46,6 @@ class TaskQueueService:
|
||||
task = Task(name=name, task_info=task_info)
|
||||
self._tasks[task.id] = task
|
||||
await self._queue.put(task)
|
||||
await LogService.info("task_queue", f"Task {name} ({task.id}) enqueued", {"task_id": task.id, "name": name})
|
||||
return task
|
||||
|
||||
def get_task(self, task_id: str) -> Task | None:
|
||||
@@ -72,15 +70,15 @@ class TaskQueueService:
|
||||
task.meta = (task.meta or {}) | meta
|
||||
|
||||
async def _execute_task(self, task: Task):
|
||||
from services.virtual_fs import process_file
|
||||
|
||||
task.status = TaskStatus.RUNNING
|
||||
await LogService.info("task_queue", f"Task {task.name} ({task.id}) started", {"task_id": task.id, "name": task.name})
|
||||
|
||||
try:
|
||||
# Local import to avoid circular dependency during module load.
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
|
||||
if task.name == "process_file":
|
||||
params = task.task_info
|
||||
result = await process_file(
|
||||
result = await VirtualFSService.process_file(
|
||||
path=params["path"],
|
||||
processor_type=params["processor_type"],
|
||||
config=params["config"],
|
||||
@@ -90,8 +88,7 @@ class TaskQueueService:
|
||||
task.result = result
|
||||
elif task.name == "automation_task" or self._is_processor_task(task.name):
|
||||
from models.database import AutomationTask
|
||||
from services.processors.registry import get as get_processor
|
||||
from services.virtual_fs import read_file, write_file
|
||||
from domain.processors.service import get_processor
|
||||
|
||||
params = task.task_info
|
||||
auto_task = await AutomationTask.get(id=params["task_id"])
|
||||
@@ -103,54 +100,45 @@ class TaskQueueService:
|
||||
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
|
||||
|
||||
if processor_type != auto_task.processor_type:
|
||||
await LogService.warning(
|
||||
"task_queue",
|
||||
"Processor type mismatch; falling back to stored type",
|
||||
{"task_id": auto_task.id, "expected": auto_task.processor_type, "got": processor_type},
|
||||
)
|
||||
processor_type = auto_task.processor_type
|
||||
processor = get_processor(processor_type)
|
||||
if not processor:
|
||||
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
|
||||
|
||||
file_content = await read_file(path)
|
||||
file_content = await VirtualFSService.read_file(path)
|
||||
result = await processor.process(file_content, path, auto_task.processor_config)
|
||||
|
||||
save_to = auto_task.processor_config.get("save_to")
|
||||
if save_to and getattr(processor, "produces_file", False):
|
||||
await write_file(save_to, result)
|
||||
await VirtualFSService.write_file(save_to, result)
|
||||
task.result = "Automation task completed"
|
||||
elif task.name == "offline_http_download":
|
||||
from services.offline_download import run_http_download
|
||||
from domain.offline_downloads.service import OfflineDownloadService
|
||||
|
||||
result_path = await run_http_download(task)
|
||||
result_path = await OfflineDownloadService.run_http_download(task)
|
||||
task.result = {"path": result_path}
|
||||
elif task.name == "cross_mount_transfer":
|
||||
from services.virtual_fs import run_cross_mount_transfer_task
|
||||
|
||||
result = await run_cross_mount_transfer_task(task)
|
||||
result = await VirtualFSService.run_cross_mount_transfer_task(task)
|
||||
task.result = result
|
||||
elif task.name == "send_email":
|
||||
from services.email import EmailService
|
||||
from domain.email.service import EmailService
|
||||
await EmailService.send_from_task(task.id, task.task_info)
|
||||
task.result = "Email sent"
|
||||
else:
|
||||
raise ValueError(f"Unknown task name: {task.name}")
|
||||
|
||||
task.status = TaskStatus.SUCCESS
|
||||
await LogService.info("task_queue", f"Task {task.name} ({task.id}) succeeded", {"task_id": task.id, "name": task.name})
|
||||
|
||||
except Exception as e:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = str(e)
|
||||
await LogService.error("task_queue", f"Task {task.name} ({task.id}) failed: {e}", {"task_id": task.id, "name": task.name})
|
||||
|
||||
def _cleanup_workers(self):
|
||||
self._worker_tasks = [task for task in self._worker_tasks if not task.done()]
|
||||
|
||||
def _is_processor_task(self, task_name: str) -> bool:
|
||||
try:
|
||||
from services.processors.registry import get as get_processor
|
||||
from domain.processors.service import get_processor
|
||||
|
||||
return get_processor(task_name) is not None
|
||||
except Exception:
|
||||
@@ -165,15 +153,12 @@ class TaskQueueService:
|
||||
worker_id = self._worker_seq
|
||||
worker_task = asyncio.create_task(self._worker_loop(worker_id))
|
||||
self._worker_tasks.append(worker_task)
|
||||
await LogService.info("task_queue", "Task workers adjusted", {"active_workers": len(self._worker_tasks), "target": self._concurrency})
|
||||
elif current > self._concurrency:
|
||||
for _ in range(current - self._concurrency):
|
||||
await self._queue.put(_SENTINEL)
|
||||
await LogService.info("task_queue", "Task workers scaling down", {"active_workers": len(self._worker_tasks), "target": self._concurrency})
|
||||
|
||||
async def _worker_loop(self, worker_id: int):
|
||||
current_task = asyncio.current_task()
|
||||
await LogService.info("task_queue", f"Worker {worker_id} started")
|
||||
try:
|
||||
while True:
|
||||
job = await self._queue.get()
|
||||
@@ -183,23 +168,18 @@ class TaskQueueService:
|
||||
try:
|
||||
await self._execute_task(job)
|
||||
except Exception as e:
|
||||
await LogService.error(
|
||||
"task_queue",
|
||||
f"Error executing task {job.id}: {e}",
|
||||
{"task_id": job.id, "name": job.name},
|
||||
)
|
||||
pass
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
finally:
|
||||
if current_task in self._worker_tasks:
|
||||
self._worker_tasks.remove(current_task) # type: ignore[arg-type]
|
||||
await LogService.info("task_queue", f"Worker {worker_id} stopped")
|
||||
|
||||
async def start_worker(self, concurrency: int | None = None):
|
||||
if concurrency is None:
|
||||
from services.config import ConfigCenter
|
||||
from domain.config.service import ConfigService
|
||||
|
||||
stored_value = await ConfigCenter.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
|
||||
stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
|
||||
try:
|
||||
concurrency = int(stored_value)
|
||||
except (TypeError, ValueError):
|
||||
@@ -219,7 +199,6 @@ class TaskQueueService:
|
||||
if self._worker_tasks:
|
||||
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
|
||||
self._worker_tasks.clear()
|
||||
await LogService.info("task_queue", "Task workers have been stopped.")
|
||||
|
||||
def get_concurrency(self) -> int:
|
||||
return self._concurrency
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class AutomationTaskBase(BaseModel):
|
||||
189
domain/virtual_fs/api.py
Normal file
189
domain/virtual_fs/api.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.virtual_fs.types import MkdirRequest, MoveRequest
|
||||
|
||||
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
|
||||
|
||||
|
||||
@router.get("/file/{full_path:path}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
|
||||
async def get_file(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
return await VirtualFSService.serve_file(full_path, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/thumb/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="获取缩略图")
|
||||
async def get_thumb(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
w: int = Query(256, ge=8, le=1024),
|
||||
h: int = Query(256, ge=8, le=1024),
|
||||
fit: str = Query("cover"),
|
||||
):
|
||||
return await VirtualFSService.get_thumbnail(full_path, w, h, fit)
|
||||
|
||||
|
||||
@router.get("/stream/{full_path:path}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="流式读取文件")
|
||||
async def stream_endpoint(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
):
|
||||
return await VirtualFSService.stream_response(full_path, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/temp-link/{full_path:path}")
|
||||
@audit(action=AuditAction.SHARE, description="创建临时链接")
|
||||
async def get_temp_link(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久"),
|
||||
):
|
||||
data = await VirtualFSService.create_temp_link(full_path, expires_in)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/public/{token}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
|
||||
async def access_public_file(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/stat/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="查看文件信息")
|
||||
async def get_file_stat(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
stat = await VirtualFSService.stat(full_path)
|
||||
return success(stat)
|
||||
|
||||
|
||||
@router.post("/file/{full_path:path}")
|
||||
@audit(action=AuditAction.UPLOAD, description="上传文件")
|
||||
async def put_file(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
data = await file.read()
|
||||
result = await VirtualFSService.write_uploaded_file(full_path, data)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/mkdir")
|
||||
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
|
||||
async def api_mkdir(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MkdirRequest,
|
||||
):
|
||||
result = await VirtualFSService.mkdir(body.path)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
|
||||
async def api_move(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
):
|
||||
result = await VirtualFSService.move(body.src, body.dst, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
|
||||
async def api_rename(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
):
|
||||
result = await VirtualFSService.rename(body.src, body.dst, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/copy")
|
||||
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
|
||||
async def api_copy(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
|
||||
):
|
||||
result = await VirtualFSService.copy(body.src, body.dst, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.post("/upload/{full_path:path}")
|
||||
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
|
||||
async def upload_stream(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
|
||||
chunk_size: int = Query(1024 * 1024, ge=8 * 1024, le=8 * 1024 * 1024, description="单次读取块大小"),
|
||||
):
|
||||
result = await VirtualFSService.upload_stream_from_upload_file(full_path, file, chunk_size, overwrite)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.get("/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="浏览目录")
|
||||
async def browse_fs(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
|
||||
):
|
||||
data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.delete("/{full_path:path}")
|
||||
@audit(action=AuditAction.DELETE, description="删除路径")
|
||||
async def api_delete(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
):
|
||||
result = await VirtualFSService.delete(full_path)
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="浏览根目录")
|
||||
async def root_listing(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
|
||||
):
|
||||
data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
|
||||
return success(data)
|
||||
@@ -10,14 +10,8 @@ from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.config import ConfigCenter
|
||||
from services.virtual_fs import (
|
||||
delete_path,
|
||||
list_virtual_dir,
|
||||
stat_file,
|
||||
stream_file,
|
||||
write_file_stream,
|
||||
)
|
||||
from domain.config.service import ConfigService
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/s3", tags=["s3"])
|
||||
@@ -69,18 +63,18 @@ def _s3_error(code: str, message: str, resource: str = "", status: int = 400) ->
|
||||
|
||||
|
||||
async def _ensure_enabled() -> Optional[Response]:
|
||||
flag = await ConfigCenter.get("S3_MAPPING_ENABLED", "1")
|
||||
flag = await ConfigService.get("S3_MAPPING_ENABLED", "1")
|
||||
if str(flag).strip().lower() in FALSEY:
|
||||
return _s3_error("ServiceUnavailable", "S3 mapping disabled", status=503)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_settings() -> Tuple[Optional[S3Settings], Optional[Response]]:
|
||||
bucket = (await ConfigCenter.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
|
||||
region = (await ConfigCenter.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
|
||||
base_path = (await ConfigCenter.get("S3_MAPPING_BASE_PATH", "/")) or "/"
|
||||
access_key = (await ConfigCenter.get("S3_MAPPING_ACCESS_KEY")) or ""
|
||||
secret_key = (await ConfigCenter.get("S3_MAPPING_SECRET_KEY")) or ""
|
||||
bucket = (await ConfigService.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
|
||||
region = (await ConfigService.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
|
||||
base_path = (await ConfigService.get("S3_MAPPING_BASE_PATH", "/")) or "/"
|
||||
access_key = (await ConfigService.get("S3_MAPPING_ACCESS_KEY")) or ""
|
||||
secret_key = (await ConfigService.get("S3_MAPPING_SECRET_KEY")) or ""
|
||||
if not access_key or not secret_key:
|
||||
return None, _s3_error(
|
||||
"InvalidAccessKeyId",
|
||||
@@ -221,7 +215,7 @@ async def _list_dir_all(path: str) -> List[Dict]:
|
||||
page_size = 1000
|
||||
while True:
|
||||
try:
|
||||
res = await list_virtual_dir(path, page_num=page_num, page_size=page_size)
|
||||
res = await VirtualFSService.list_virtual_dir(path, page_num=page_num, page_size=page_size)
|
||||
except HTTPException as exc: # directory missing
|
||||
if exc.status_code in (400, 404):
|
||||
return []
|
||||
@@ -376,7 +370,7 @@ async def list_objects(request: Request, bucket: str):
|
||||
prefixes: List[str] = []
|
||||
if prefix and not prefix.endswith("/"):
|
||||
try:
|
||||
info = await stat_file(_virtual_path(settings, prefix))
|
||||
info = await VirtualFSService.stat_file(_virtual_path(settings, prefix))
|
||||
if not info.get("is_dir"):
|
||||
files = [(prefix, info)]
|
||||
except HTTPException as exc:
|
||||
@@ -473,7 +467,7 @@ def _object_headers(meta: Dict, key: str) -> Dict[str, str]:
|
||||
|
||||
async def _stat_object(settings: S3Settings, key: str) -> Tuple[Optional[Dict], Optional[Response]]:
|
||||
try:
|
||||
info = await stat_file(_virtual_path(settings, key))
|
||||
info = await VirtualFSService.stat_file(_virtual_path(settings, key))
|
||||
if info.get("is_dir"):
|
||||
return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
|
||||
return info, None
|
||||
@@ -498,7 +492,7 @@ async def object_get_head(request: Request, bucket: str, object_path: str):
|
||||
base_headers.update(_object_headers(meta, key))
|
||||
if request.method == "HEAD":
|
||||
return Response(status_code=200, headers=base_headers)
|
||||
resp = await stream_file(_virtual_path(settings, key), request.headers.get("range"))
|
||||
resp = await VirtualFSService.stream_file(_virtual_path(settings, key), request.headers.get("range"))
|
||||
safe_merge_keys = {"ETag", "Last-Modified", "x-amz-version-id", "Accept-Ranges"}
|
||||
for hk, hv in base_headers.items():
|
||||
if hk in safe_merge_keys:
|
||||
@@ -514,7 +508,7 @@ async def put_object(request: Request, bucket: str, object_path: str):
|
||||
return error
|
||||
assert settings
|
||||
key = object_path.lstrip("/")
|
||||
await write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
|
||||
await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
|
||||
meta, err = await _stat_object(settings, key)
|
||||
if err:
|
||||
return err
|
||||
@@ -535,7 +529,7 @@ async def delete_object(request: Request, bucket: str, object_path: str):
|
||||
assert settings
|
||||
key = object_path.lstrip("/")
|
||||
try:
|
||||
await delete_path(_virtual_path(settings, key))
|
||||
await VirtualFSService.delete_path(_virtual_path(settings, key))
|
||||
except HTTPException as exc:
|
||||
if exc.status_code not in (400, 404):
|
||||
raise
|
||||
26
domain/virtual_fs/search_api.py
Normal file
26
domain/virtual_fs/search_api.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.virtual_fs.search_service import VirtualFSSearchService
|
||||
|
||||
router = APIRouter(prefix="/api/fs/search", tags=["search"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search_files(
|
||||
q: str = Query(..., description="搜索查询"),
|
||||
top_k: int = Query(10, description="返回结果数量"),
|
||||
mode: str = Query("vector", description="搜索模式: 'vector' 或 'filename'"),
|
||||
page: int = Query(1, description="分页页码,仅在文件名搜索模式下生效"),
|
||||
page_size: int = Query(10, description="分页大小,仅在文件名搜索模式下生效"),
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
if not q.strip():
|
||||
return {"items": [], "query": q}
|
||||
|
||||
top_k = max(top_k, 1)
|
||||
page = max(page, 1)
|
||||
page_size = max(min(page_size, 100), 1)
|
||||
|
||||
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
|
||||
@@ -1,13 +1,8 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from schemas.fs import SearchResultItem
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.ai import get_text_embedding
|
||||
from services.vector_db import VectorDBService
|
||||
|
||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||
from domain.virtual_fs.types import SearchResultItem
|
||||
from domain.ai.inference import get_text_embedding
|
||||
from domain.ai.service import VectorDBService
|
||||
|
||||
|
||||
def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
|
||||
@@ -101,37 +96,23 @@ async def _filename_search(query: str, page: int, page_size: int) -> Tuple[List[
|
||||
return page_items, has_more
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search_files(
|
||||
q: str = Query(..., description="搜索查询"),
|
||||
top_k: int = Query(10, description="返回结果数量"),
|
||||
mode: str = Query("vector", description="搜索模式: 'vector' 或 'filename'"),
|
||||
page: int = Query(1, description="分页页码,仅在文件名搜索模式下生效"),
|
||||
page_size: int = Query(10, description="分页大小,仅在文件名搜索模式下生效"),
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
if not q.strip():
|
||||
return {"items": [], "query": q}
|
||||
|
||||
top_k = max(top_k, 1)
|
||||
page = max(page, 1)
|
||||
page_size = max(min(page_size, 100), 1)
|
||||
|
||||
if mode == "vector":
|
||||
items = (await _vector_search(q, top_k))[:top_k]
|
||||
elif mode == "filename":
|
||||
items, has_more = await _filename_search(q, page, page_size)
|
||||
return {
|
||||
"items": items,
|
||||
"query": q,
|
||||
"mode": mode,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
},
|
||||
}
|
||||
else:
|
||||
items = (await _vector_search(q, top_k))[:top_k]
|
||||
|
||||
return {"items": items, "query": q, "mode": mode}
|
||||
class VirtualFSSearchService:
|
||||
@staticmethod
|
||||
async def search(query: str, top_k: int, mode: str, page: int, page_size: int):
|
||||
if mode == "vector":
|
||||
items = (await _vector_search(query, top_k))[:top_k]
|
||||
return {"items": items, "query": query, "mode": mode}
|
||||
if mode == "filename":
|
||||
items, has_more = await _filename_search(query, page, page_size)
|
||||
return {
|
||||
"items": items,
|
||||
"query": query,
|
||||
"mode": mode,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
},
|
||||
}
|
||||
items = (await _vector_search(query, top_k))[:top_k]
|
||||
return {"items": items, "query": query, "mode": mode}
|
||||
1360
domain/virtual_fs/service.py
Normal file
1360
domain/virtual_fs/service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VfsEntry(BaseModel):
|
||||
name: str
|
||||
@@ -9,25 +9,17 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, Depends
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from services.auth import authenticate_user_db, User, UserInDB
|
||||
from services.virtual_fs import (
|
||||
list_virtual_dir,
|
||||
stat_file,
|
||||
write_file_stream,
|
||||
make_dir,
|
||||
delete_path,
|
||||
move_path,
|
||||
copy_path,
|
||||
stream_file,
|
||||
)
|
||||
from services.config import ConfigCenter
|
||||
from domain.auth.service import AuthService
|
||||
from domain.auth.types import User, UserInDB
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.config.service import ConfigService
|
||||
|
||||
|
||||
_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
|
||||
|
||||
|
||||
async def _ensure_webdav_enabled() -> None:
|
||||
enabled = await ConfigCenter.get(_WEBDAV_ENABLED_KEY, "1")
|
||||
enabled = await ConfigService.get(_WEBDAV_ENABLED_KEY, "1")
|
||||
if str(enabled).strip().lower() in ("0", "false", "off", "no"):
|
||||
raise HTTPException(503, detail="WebDAV mapping disabled")
|
||||
|
||||
@@ -70,7 +62,7 @@ async def _get_basic_user(request: Request) -> User:
|
||||
username, _, password = decoded.partition(":")
|
||||
except Exception:
|
||||
raise HTTPException(401, detail="Invalid Basic auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
user_or_false: Optional[UserInDB] = await authenticate_user_db(username, password)
|
||||
user_or_false: Optional[UserInDB] = await AuthService.authenticate_user_db(username, password)
|
||||
if not user_or_false:
|
||||
raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
u: UserInDB = user_or_false
|
||||
@@ -170,7 +162,7 @@ async def propfind(
|
||||
|
||||
# 先获取当前路径信息
|
||||
try:
|
||||
st = await stat_file(full_path)
|
||||
st = await VirtualFSService.stat_file(full_path)
|
||||
is_dir = bool(st.get("is_dir"))
|
||||
name = st.get("name") or full_path.rsplit("/", 1)[-1] or "/"
|
||||
size = None if is_dir else int(st.get("size", 0))
|
||||
@@ -182,7 +174,7 @@ async def propfind(
|
||||
|
||||
if depth in ("1", "infinity"):
|
||||
try:
|
||||
listing = await list_virtual_dir(full_path, page_num=1, page_size=1000)
|
||||
listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
|
||||
for ent in listing["items"]:
|
||||
is_dir = bool(ent.get("is_dir"))
|
||||
name = ent.get("name")
|
||||
@@ -210,7 +202,7 @@ async def dav_get(
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
range_header = request.headers.get("Range")
|
||||
return await stream_file(full_path, range_header)
|
||||
return await VirtualFSService.stream_file(full_path, range_header)
|
||||
|
||||
|
||||
@router.head("/{path:path}")
|
||||
@@ -221,7 +213,7 @@ async def dav_head(
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
try:
|
||||
st = await stat_file(full_path)
|
||||
st = await VirtualFSService.stat_file(full_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
is_dir = bool(st.get("is_dir"))
|
||||
@@ -251,7 +243,7 @@ async def dav_put(
|
||||
async for chunk in request.stream():
|
||||
if chunk:
|
||||
yield chunk
|
||||
size = await write_file_stream(full_path, body_iter(), overwrite=True)
|
||||
size = await VirtualFSService.write_file_stream(full_path, body_iter(), overwrite=True)
|
||||
return Response(status_code=201, headers=_dav_headers({"Content-Length": "0"}))
|
||||
|
||||
|
||||
@@ -262,7 +254,7 @@ async def dav_delete(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await delete_path(full_path)
|
||||
await VirtualFSService.delete_path(full_path)
|
||||
return Response(status_code=204, headers=_dav_headers())
|
||||
|
||||
|
||||
@@ -273,7 +265,7 @@ async def dav_mkcol(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await make_dir(full_path)
|
||||
await VirtualFSService.make_dir(full_path)
|
||||
return Response(status_code=201, headers=_dav_headers())
|
||||
|
||||
|
||||
@@ -295,7 +287,7 @@ async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_
|
||||
dest_header = request.headers.get("Destination")
|
||||
dst = _parse_destination(dest_header or "")
|
||||
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
|
||||
await move_path(full_src, dst, overwrite=overwrite)
|
||||
await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
|
||||
return Response(status_code=204, headers=_dav_headers())
|
||||
|
||||
|
||||
@@ -305,5 +297,5 @@ async def dav_copy(path: str, request: Request, user: User = Depends(_get_basic_
|
||||
dest_header = request.headers.get("Destination")
|
||||
dst = _parse_destination(dest_header or "")
|
||||
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
|
||||
await copy_path(full_src, dst, overwrite=overwrite)
|
||||
await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
|
||||
return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())
|
||||
12
main.py
12
main.py
@@ -1,15 +1,14 @@
|
||||
import os
|
||||
from services.config import VERSION, ConfigCenter
|
||||
from services.adapters.registry import runtime_registry
|
||||
from domain.config.service import ConfigService, VERSION
|
||||
from domain.adapters.registry import runtime_registry
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from db.session import close_db, init_db
|
||||
from api.routers import include_routers
|
||||
from fastapi import FastAPI
|
||||
from services.middleware.logging_middleware import LoggingMiddleware
|
||||
from services.middleware.exception_handler import global_exception_handler
|
||||
from middleware.exception_handler import global_exception_handler
|
||||
from dotenv import load_dotenv
|
||||
from services.task_queue import task_queue_service
|
||||
from domain.tasks.task_queue import task_queue_service
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -19,7 +18,7 @@ async def lifespan(app: FastAPI):
|
||||
os.makedirs("data/db", exist_ok=True)
|
||||
await init_db()
|
||||
await runtime_registry.refresh()
|
||||
await ConfigCenter.set("APP_VERSION", VERSION)
|
||||
await ConfigService.set("APP_VERSION", VERSION)
|
||||
await task_queue_service.start_worker()
|
||||
try:
|
||||
yield
|
||||
@@ -35,7 +34,6 @@ def create_app() -> FastAPI:
|
||||
lifespan=lifespan,
|
||||
)
|
||||
include_routers(app)
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
app.add_exception_handler(Exception, global_exception_handler)
|
||||
return app
|
||||
|
||||
|
||||
1
middleware/__init__.py
Normal file
1
middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Middleware package for FastAPI app."""
|
||||
11
middleware/exception_handler.py
Normal file
11
middleware/exception_handler.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""
|
||||
全局异常处理
|
||||
"""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"error": "Internal Server Error", "detail": str(exc)},
|
||||
)
|
||||
@@ -128,17 +128,25 @@ class AutomationTask(Model):
|
||||
table = "automation_tasks"
|
||||
|
||||
|
||||
class Log(Model):
|
||||
class AuditLog(Model):
|
||||
id = fields.IntField(pk=True)
|
||||
timestamp = fields.DatetimeField(auto_now_add=True)
|
||||
level = fields.CharField(max_length=50)
|
||||
source = fields.CharField(max_length=100)
|
||||
message = fields.TextField()
|
||||
details = fields.JSONField(null=True)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
action = fields.CharField(max_length=50)
|
||||
description = fields.TextField(null=True)
|
||||
user_id = fields.IntField(null=True)
|
||||
username = fields.CharField(max_length=100, null=True)
|
||||
client_ip = fields.CharField(max_length=64, null=True)
|
||||
method = fields.CharField(max_length=10)
|
||||
path = fields.CharField(max_length=1024)
|
||||
status_code = fields.IntField()
|
||||
duration_ms = fields.FloatField(null=True)
|
||||
success = fields.BooleanField(default=True)
|
||||
request_params = fields.JSONField(null=True)
|
||||
request_body = fields.JSONField(null=True)
|
||||
error = fields.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table = "logs"
|
||||
table = "audit_logs"
|
||||
|
||||
|
||||
class ShareLink(Model):
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from schemas.plugins import PluginCreate,PluginOut
|
||||
from .adapters import AdapterCreate, AdapterOut
|
||||
from .fs import MkdirRequest, MoveRequest
|
||||
|
||||
__all__ = [
|
||||
"PluginOut"
|
||||
"PluginCreate"
|
||||
"AdapterCreate",
|
||||
"AdapterOut",
|
||||
"MkdirRequest",
|
||||
"MoveRequest",
|
||||
]
|
||||
@@ -1,18 +0,0 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class EmailTestRequest(BaseModel):
|
||||
to: EmailStr
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(default="test", min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTemplateUpdate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class EmailTemplatePreviewPayload(BaseModel):
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -1,27 +0,0 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PluginCreate(BaseModel):
|
||||
url: str = Field(min_length=1)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class PluginOut(BaseModel):
|
||||
id: int
|
||||
url: str
|
||||
enabled: bool
|
||||
key: Optional[str]
|
||||
name: Optional[str]
|
||||
version: Optional[str]
|
||||
supported_exts: Optional[List[str]]
|
||||
default_bounds: Optional[Dict[str, Any]]
|
||||
default_maximized: Optional[bool]
|
||||
icon: Optional[str]
|
||||
description: Optional[str]
|
||||
author: Optional[str]
|
||||
website: Optional[str]
|
||||
github: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
321
services/auth.py
321
services/auth.py
@@ -1,321 +0,0 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
import secrets
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.database import UserAccount
|
||||
from services.config import ConfigCenter
|
||||
from services.logging import LogService
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
|
||||
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PasswordResetEntry:
|
||||
user_id: int
|
||||
email: str
|
||||
username: str
|
||||
expires_at: datetime
|
||||
used: bool = False
|
||||
|
||||
|
||||
class PasswordResetStore:
|
||||
_tokens: dict[str, PasswordResetEntry] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def _cleanup(cls):
|
||||
now = _now()
|
||||
for token, record in list(cls._tokens.items()):
|
||||
if record.used or record.expires_at < now:
|
||||
cls._tokens.pop(token, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, user: UserAccount) -> str:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user.id:
|
||||
cls._tokens.pop(key, None)
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
|
||||
cls._tokens[token] = PasswordResetEntry(
|
||||
user_id=user.id,
|
||||
email=user.email or "",
|
||||
username=user.username,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
async def get(cls, token: str) -> PasswordResetEntry | None:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
record = cls._tokens.get(token)
|
||||
if not record or record.used:
|
||||
return None
|
||||
return record
|
||||
|
||||
@classmethod
|
||||
async def mark_used(cls, token: str) -> None:
|
||||
async with cls._lock:
|
||||
record = cls._tokens.get(token)
|
||||
if record:
|
||||
record.used = True
|
||||
cls._cleanup()
|
||||
|
||||
@classmethod
|
||||
async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
|
||||
async with cls._lock:
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user_id and key != except_token:
|
||||
cls._tokens.pop(key, None)
|
||||
cls._cleanup()
|
||||
|
||||
|
||||
async def get_secret_key():
|
||||
return await ConfigCenter.get_secret_key(
|
||||
"SECRET_KEY", None
|
||||
)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id:int
|
||||
username: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool | None = None
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def get_user(db, username: str):
|
||||
if username in db:
|
||||
user_dict = db[username]
|
||||
return UserInDB(**user_dict)
|
||||
|
||||
|
||||
async def get_user_db(username_or_email: str):
|
||||
user = await UserAccount.get_or_none(username=username_or_email)
|
||||
if not user:
|
||||
user = await UserAccount.get_or_none(email=username_or_email)
|
||||
if user:
|
||||
return UserInDB(
|
||||
id= user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
disabled=user.disabled,
|
||||
hashed_password=user.hashed_password,
|
||||
)
|
||||
|
||||
|
||||
def authenticate_user(fake_db, username: str, password: str):
|
||||
user = get_user(fake_db, username)
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
|
||||
async def authenticate_user_db(username_or_email: str, password: str):
|
||||
user = await get_user_db(username_or_email)
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
|
||||
async def register_user(username: str, password: str, email: str = None, full_name: str = None):
|
||||
if await has_users():
|
||||
raise HTTPException(status_code=403, detail="系统已初始化,不允许注册新用户")
|
||||
exists = await UserAccount.get_or_none(username=username)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
hashed = get_password_hash(password)
|
||||
user = await UserAccount.create(
|
||||
username=username,
|
||||
email=email,
|
||||
full_name=full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def has_users() -> bool:
|
||||
"""
|
||||
检查数据库中是否存在任何用户
|
||||
"""
|
||||
user_count = await UserAccount.all().count()
|
||||
return user_count > 0
|
||||
|
||||
|
||||
async def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if "sub" not in to_encode and "username" in to_encode:
|
||||
to_encode["sub"] = to_encode["username"]
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
secret_key = await get_secret_key()
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def _normalize_email(email: str | None) -> str:
|
||||
return (email or "").strip().lower()
|
||||
|
||||
|
||||
async def _send_password_reset_email(user: UserAccount, token: str) -> None:
|
||||
from services.email import EmailService
|
||||
|
||||
app_domain = await ConfigCenter.get("APP_DOMAIN", None)
|
||||
base_url = (app_domain or "http://localhost:5173").rstrip("/")
|
||||
reset_link = f"{base_url}/reset-password?token={token}"
|
||||
await EmailService.enqueue_email(
|
||||
recipients=[user.email],
|
||||
subject="Foxel 密码重置",
|
||||
template="password_reset",
|
||||
context={
|
||||
"username": user.username,
|
||||
"reset_link": reset_link,
|
||||
"expire_minutes": PASSWORD_RESET_TOKEN_EXPIRE_MINUTES,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def request_password_reset(email: str) -> bool:
|
||||
normalized = _normalize_email(email)
|
||||
if not normalized:
|
||||
return False
|
||||
user = await UserAccount.get_or_none(email=normalized)
|
||||
if not user or not user.email:
|
||||
return False
|
||||
|
||||
token = await PasswordResetStore.create(user)
|
||||
try:
|
||||
await _send_password_reset_email(user, token)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await PasswordResetStore.mark_used(token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
await LogService.error(
|
||||
"auth",
|
||||
f"Failed to enqueue password reset email: {exc}",
|
||||
details={"user_id": user.id},
|
||||
user_id=user.id,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="邮件发送失败") from exc
|
||||
await LogService.action(
|
||||
"auth",
|
||||
"Password reset requested",
|
||||
details={"user_id": user.id},
|
||||
user_id=user.id,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def verify_password_reset_token(token: str) -> UserAccount:
|
||||
record = await PasswordResetStore.get(token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
return user
|
||||
|
||||
|
||||
async def reset_password_with_token(token: str, new_password: str) -> None:
|
||||
record = await PasswordResetStore.get(token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user.hashed_password = get_password_hash(new_password)
|
||||
await user.save(update_fields=["hashed_password"])
|
||||
await PasswordResetStore.mark_used(token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
await LogService.action(
|
||||
"auth",
|
||||
"Password reset via email",
|
||||
details={"user_id": user.id},
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
secret_key = await get_secret_key()
|
||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
user = await get_user_db(token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
):
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
@@ -1,80 +0,0 @@
|
||||
from tortoise.transactions import in_transaction
|
||||
from models.database import (
|
||||
StorageAdapter,
|
||||
UserAccount,
|
||||
AutomationTask,
|
||||
ShareLink,
|
||||
Configuration,
|
||||
)
|
||||
from services.config import VERSION
|
||||
|
||||
|
||||
class BackupService:
|
||||
@staticmethod
|
||||
async def export_data():
|
||||
"""
|
||||
导出所有相关数据到JSON格式。
|
||||
"""
|
||||
async with in_transaction() as conn:
|
||||
adapters = await StorageAdapter.all().values()
|
||||
users = await UserAccount.all().values()
|
||||
tasks = await AutomationTask.all().values()
|
||||
shares = await ShareLink.all().values()
|
||||
configs = await Configuration.all().values()
|
||||
|
||||
for share in shares:
|
||||
share["created_at"] = share["created_at"].isoformat(
|
||||
) if share.get("created_at") else None
|
||||
share["expires_at"] = share["expires_at"].isoformat(
|
||||
) if share.get("expires_at") else None
|
||||
|
||||
return {
|
||||
"version": VERSION,
|
||||
"storage_adapters": list(adapters),
|
||||
"user_accounts": list(users),
|
||||
"automation_tasks": list(tasks),
|
||||
"share_links": list(shares),
|
||||
"configurations": list(configs),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def import_data(data: dict):
|
||||
"""
|
||||
从JSON数据导入到数据库。
|
||||
"""
|
||||
async with in_transaction() as conn:
|
||||
await ShareLink.all().using_db(conn).delete()
|
||||
await AutomationTask.all().using_db(conn).delete()
|
||||
await StorageAdapter.all().using_db(conn).delete()
|
||||
await UserAccount.all().using_db(conn).delete()
|
||||
await Configuration.all().using_db(conn).delete()
|
||||
|
||||
if data.get("configurations"):
|
||||
await Configuration.bulk_create(
|
||||
[Configuration(**c) for c in data["configurations"]],
|
||||
using_db=conn
|
||||
)
|
||||
|
||||
if data.get("user_accounts"):
|
||||
await UserAccount.bulk_create(
|
||||
[UserAccount(**u) for u in data["user_accounts"]],
|
||||
using_db=conn
|
||||
)
|
||||
|
||||
if data.get("storage_adapters"):
|
||||
await StorageAdapter.bulk_create(
|
||||
[StorageAdapter(**a) for a in data["storage_adapters"]],
|
||||
using_db=conn
|
||||
)
|
||||
|
||||
if data.get("automation_tasks"):
|
||||
await AutomationTask.bulk_create(
|
||||
[AutomationTask(**t) for t in data["automation_tasks"]],
|
||||
using_db=conn
|
||||
)
|
||||
|
||||
if data.get("share_links"):
|
||||
await ShareLink.bulk_create(
|
||||
[ShareLink(**s) for s in data["share_links"]],
|
||||
using_db=conn
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user