refactor: optimize backend module

This commit is contained in:
shiyu
2025-12-08 17:46:45 +08:00
parent cf8d10f71c
commit 8f515aaaf4
124 changed files with 6884 additions and 6390 deletions

View File

@@ -137,8 +137,8 @@ Install the following tooling first:
Storage adapters integrate new storage providers (for example S3, FTP, or Alist).
1. Create a new module under [`services/adapters/`](services/adapters/) (for example `my_new_adapter.py`).
2. Implement a class that inherits from [`services.adapters.base.BaseAdapter`](services/adapters/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
1. Create a new module under [`domain/adapters/providers/`](domain/adapters/providers/) (for example `my_new_adapter.py`).
2. Implement a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
### Frontend Apps

View File

@@ -143,9 +143,9 @@
存储适配器是 Foxel 的核心扩展点,用于接入不同的存储后端 (如 S3, FTP, Alist 等)。
1. **创建适配器文件**: 在 [`services/adapters/`](services/adapters/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
1. **创建适配器文件**: 在 [`domain/adapters/providers/`](domain/adapters/providers/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
2. **实现适配器类**:
- 创建一个类,继承自 [`services.adapters.base.BaseAdapter`](services/adapters/base.py)。
- 创建一个类,继承自 [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py)。
- 实现 `BaseAdapter` 中定义的所有抽象方法,如 `list_dir`, `get_meta`, `upload`, `download` 等。请仔细阅读基类中的文档注释以理解每个方法的作用和参数。
### 贡献前端应用 (App)

View File

@@ -1,26 +1,37 @@
from fastapi import FastAPI
from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search, vector_db, offline_downloads, ai_providers, email
from .routes import webdav, s3
from .routes import plugins
from domain.adapters import api as adapters
from domain.auth import api as auth
from domain.backup import api as backup
from domain.config import api as config
from domain.email import api as email
from domain.offline_downloads import api as offline_downloads
from domain.plugins import api as plugins
from domain.processors import api as processors
from domain.share import api as share
from domain.tasks import api as tasks
from domain.ai import api as ai
from domain.virtual_fs import api as virtual_fs
from domain.virtual_fs import s3_api, search_api, webdav_api
from domain.audit import router as audit
def include_routers(app: FastAPI):
    # Attach every domain router to the FastAPI application.
    # NOTE(review): both pre- and post-refactor routers appear to be
    # registered here (e.g. `search` vs `search_api`, `webdav` vs
    # `webdav_api`, `vector_db`/`ai_providers` vs `ai.router_*`, `s3` vs
    # `s3_api`) — this looks like a merge/diff artifact; confirm which set
    # of registrations is actually intended before shipping.
    app.include_router(adapters.router)
    app.include_router(virtual_fs.router)
    app.include_router(search.router)
    app.include_router(search_api.router)
    app.include_router(auth.router)
    app.include_router(config.router)
    app.include_router(processors.router)
    app.include_router(tasks.router)
    app.include_router(logs.router)
    app.include_router(share.router)
    # Public (unauthenticated) share endpoints.
    app.include_router(share.public_router)
    app.include_router(backup.router)
    app.include_router(vector_db.router)
    app.include_router(ai_providers.router)
    app.include_router(ai.router_vector_db)
    app.include_router(ai.router_ai)
    app.include_router(plugins.router)
    app.include_router(webdav.router)
    app.include_router(s3.router)
    app.include_router(webdav_api.router)
    app.include_router(s3_api.router)
    app.include_router(offline_downloads.router)
    app.include_router(email.router)
    app.include_router(audit)

View File

@@ -1,151 +0,0 @@
from fastapi import APIRouter, HTTPException, Depends
from tortoise.transactions import in_transaction
from typing import Annotated
from models import StorageAdapter
from schemas import AdapterCreate, AdapterOut
from services.auth import get_current_active_user, User
from services.adapters.registry import runtime_registry, get_config_schemas, normalize_adapter_type
from api.response import success
from services.logging import LogService
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
def validate_and_normalize_config(adapter_type: str, cfg):
    """Validate a raw adapter config against its schema and fill defaults.

    Raises HTTPException(400) for an unknown adapter type, a non-dict
    config, or missing required fields. Returns the normalized dict.
    """
    schemas = get_config_schemas()
    adapter_type = normalize_adapter_type(adapter_type)
    if not adapter_type:
        raise HTTPException(400, detail="不支持的适配器类型")
    if not isinstance(cfg, dict):
        raise HTTPException(400, detail="config 必须是对象")
    schema = schemas.get(adapter_type)
    if not schema:
        raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
    normalized = {}
    missing_required = []
    for field in schema:
        key = field["key"]
        if key in cfg and cfg[key] not in (None, ""):
            normalized[key] = cfg[key]
        elif "default" in field:
            normalized[key] = field["default"]
        elif field.get("required"):
            missing_required.append(key)
    if missing_required:
        raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing_required))
    return normalized
@router.post("")
async def create_adapter(
data: AdapterCreate,
current_user: Annotated[User, Depends(get_current_active_user)]
):
norm_path = AdapterCreate.normalize_mount_path(data.path)
exists = await StorageAdapter.get_or_none(path=norm_path)
if exists:
raise HTTPException(400, detail="Mount path already exists")
adapter_fields = {
"name": data.name,
"type": data.type,
"config": validate_and_normalize_config(data.type, data.config or {}),
"enabled": data.enabled,
"path": norm_path,
"sub_path": data.sub_path,
}
rec = await StorageAdapter.create(**adapter_fields)
await runtime_registry.upsert(rec)
await LogService.action(
"route:adapters",
f"Created adapter {rec.name}",
details=adapter_fields,
user_id=current_user.id if hasattr(current_user, "id") else None,
)
return success(rec)
@router.get("")
async def list_adapters(
current_user: Annotated[User, Depends(get_current_active_user)]
):
adapters = await StorageAdapter.all()
out = [AdapterOut.model_validate(a) for a in adapters]
return success(out)
@router.get("/available")
async def available_adapter_types(
current_user: Annotated[User, Depends(get_current_active_user)]
):
data = []
for t, fields in get_config_schemas().items():
data.append({
"type": t,
"config_schema": fields,
})
return success(data)
@router.get("/{adapter_id}")
async def get_adapter(
adapter_id: int,
current_user: Annotated[User, Depends(get_current_active_user)]
):
rec = await StorageAdapter.get_or_none(id=adapter_id)
if not rec:
raise HTTPException(404, detail="Not found")
return success(AdapterOut.model_validate(rec))
@router.put("/{adapter_id}")
async def update_adapter(
adapter_id: int,
data: AdapterCreate,
current_user: Annotated[User, Depends(get_current_active_user)]
):
rec = await StorageAdapter.get_or_none(id=adapter_id)
if not rec:
raise HTTPException(404, detail="Not found")
norm_path = AdapterCreate.normalize_mount_path(data.path)
existing = await StorageAdapter.get_or_none(path=norm_path)
if existing and existing.id != adapter_id:
raise HTTPException(400, detail="Mount path already exists")
rec.name = data.name
rec.type = data.type
rec.config = validate_and_normalize_config(data.type, data.config or {})
rec.enabled = data.enabled
rec.path = norm_path
rec.sub_path = data.sub_path
await rec.save()
await runtime_registry.upsert(rec)
await LogService.action(
"route:adapters",
f"Updated adapter {rec.name}",
details=data.model_dump(),
user_id=current_user.id if hasattr(current_user, "id") else None,
)
return success(rec)
@router.delete("/{adapter_id}")
async def delete_adapter(
adapter_id: int,
current_user: Annotated[User, Depends(get_current_active_user)]
):
deleted = await StorageAdapter.filter(id=adapter_id).delete()
if not deleted:
raise HTTPException(404, detail="Not found")
runtime_registry.remove(adapter_id)
await LogService.action(
"route:adapters",
f"Deleted adapter {adapter_id}",
details={"adapter_id": adapter_id},
user_id=current_user.id if hasattr(current_user, "id") else None,
)
return success({"deleted": True})

View File

@@ -1,177 +0,0 @@
from typing import Annotated, Dict, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Path
from api.response import success
from schemas.ai import (
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
AIProviderCreate,
AIProviderUpdate,
)
from services.ai_providers import AIProviderService
from services.auth import User, get_current_active_user
from services.vector_db import VectorDBService
router = APIRouter(prefix="/api/ai", tags=["ai"])
service = AIProviderService()
@router.get("/providers")
async def list_providers(
current_user: Annotated[User, Depends(get_current_active_user)]
):
providers = await service.list_providers()
return success({"providers": providers})
@router.post("/providers")
async def create_provider(
payload: AIProviderCreate,
current_user: Annotated[User, Depends(get_current_active_user)]
):
provider = await service.create_provider(payload.dict())
return success(provider)
@router.get("/providers/{provider_id}")
async def get_provider(
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
provider = await service.get_provider(provider_id, with_models=True)
return success(provider)
@router.put("/providers/{provider_id}")
async def update_provider(
provider_id: Annotated[int, Path(..., gt=0)],
payload: AIProviderUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = {k: v for k, v in payload.dict().items() if v is not None}
if not data:
raise HTTPException(status_code=400, detail="No fields to update")
provider = await service.update_provider(provider_id, data)
return success(provider)
@router.delete("/providers/{provider_id}")
async def delete_provider(
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
await service.delete_provider(provider_id)
return success({"id": provider_id})
@router.post("/providers/{provider_id}/sync-models")
async def sync_models(
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
try:
result = await service.sync_models(provider_id)
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return success(result)
@router.get("/providers/{provider_id}/remote-models")
async def fetch_remote_models(
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
try:
models = await service.fetch_remote_models(provider_id)
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return success({"models": models})
@router.get("/providers/{provider_id}/models")
async def list_models(
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
models = await service.list_models(provider_id)
return success({"models": models})
@router.post("/providers/{provider_id}/models")
async def create_model(
provider_id: Annotated[int, Path(..., gt=0)],
payload: AIModelCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
model = await service.create_model(provider_id, payload.dict())
return success(model)
@router.put("/models/{model_id}")
async def update_model(
model_id: Annotated[int, Path(..., gt=0)],
payload: AIModelUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = {k: v for k, v in payload.dict().items() if v is not None}
if not data:
raise HTTPException(status_code=400, detail="No fields to update")
model = await service.update_model(model_id, data)
return success(model)
@router.delete("/models/{model_id}")
async def delete_model(
model_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
await service.delete_model(model_id)
return success({"id": model_id})
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
if not entry:
return None
value = entry.get("embedding_dimensions")
return int(value) if value is not None else None
@router.get("/defaults")
async def get_defaults(
current_user: Annotated[User, Depends(get_current_active_user)],
):
defaults = await service.get_default_models()
return success(defaults)
@router.put("/defaults")
async def update_defaults(
payload: AIDefaultsUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
previous = await service.get_default_models()
try:
updated = await service.set_default_models(payload.as_mapping())
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
prev_dim = _get_embedding_dimension(previous.get("embedding"))
next_dim = _get_embedding_dimension(updated.get("embedding"))
if prev_dim and next_dim and prev_dim != next_dim:
try:
await VectorDBService().clear_all_data()
except Exception as exc: # noqa: BLE001
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
return success(updated)

View File

@@ -1,155 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, HTTPException, Depends, Form
import hashlib
from fastapi.security import OAuth2PasswordRequestForm
from services.auth import (
authenticate_user_db,
create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES,
register_user,
Token,
get_current_active_user,
User,
request_password_reset,
verify_password_reset_token,
reset_password_with_token,
)
from pydantic import BaseModel
from datetime import timedelta
from api.response import success
from models.database import UserAccount
from services.auth import verify_password, get_password_hash
router = APIRouter(prefix="/api/auth", tags=["auth"])
# Request body for the one-time initial-admin registration endpoint.
class RegisterRequest(BaseModel):
    username: str
    password: str
    # Optional profile fields; may be omitted at registration time.
    email: str | None = None
    full_name: str | None = None
@router.post("/register", summary="注册第一个管理员用户")
async def register(data: RegisterRequest):
"""
仅当系统中没有用户时,才允许注册。
"""
user = await register_user(
username=data.username,
password=data.password,
email=data.email,
full_name=data.full_name,
)
return success({"username": user.username}, msg="初始用户注册成功")
@router.post("/login")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
user = await authenticate_user_db(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=401,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = await create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")
@router.get("/me", summary="获取当前登录用户信息")
async def get_me(current_user: Annotated[User, Depends(get_current_active_user)]):
"""
返回当前登录用户的基本信息,并附带 gravatar 头像链接。
"""
email = (current_user.email or "").strip().lower()
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
return success({
"id": current_user.id,
"username": current_user.username,
"email": current_user.email,
"full_name": current_user.full_name,
"gravatar_url": gravatar_url,
})
# Request body for PUT /me; every field is optional and only provided
# fields are applied. new_password requires old_password.
class UpdateMeRequest(BaseModel):
    email: str | None = None
    full_name: str | None = None
    old_password: str | None = None
    new_password: str | None = None
# Request body for initiating a password reset by email.
class PasswordResetRequest(BaseModel):
    email: str
# Request body for completing a password reset with an emailed token.
class PasswordResetConfirm(BaseModel):
    token: str
    password: str
@router.put("/me", summary="更新当前登录用户信息")
async def update_me(
payload: UpdateMeRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
db_user = await UserAccount.get_or_none(id=current_user.id)
if not db_user:
raise HTTPException(status_code=404, detail="用户不存在")
if payload.email is not None:
exists = await UserAccount.filter(email=payload.email).exclude(id=db_user.id).exists()
if exists:
raise HTTPException(status_code=400, detail="邮箱已被占用")
db_user.email = payload.email
if payload.full_name is not None:
db_user.full_name = payload.full_name
if payload.new_password:
if not payload.old_password:
raise HTTPException(status_code=400, detail="请提供原密码")
if not verify_password(payload.old_password, db_user.hashed_password):
raise HTTPException(status_code=400, detail="原密码错误")
db_user.hashed_password = get_password_hash(payload.new_password)
await db_user.save()
email = (db_user.email or "").strip().lower()
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
return success({
"id": db_user.id,
"username": db_user.username,
"email": db_user.email,
"full_name": db_user.full_name,
"gravatar_url": gravatar_url,
})
@router.post("/password-reset/request", summary="请求密码重置邮件")
async def password_reset_request_endpoint(payload: PasswordResetRequest):
await request_password_reset(payload.email)
return success(msg="如果邮箱存在,将发送重置邮件")
@router.get("/password-reset/verify", summary="校验密码重置令牌")
async def password_reset_verify(token: str):
user = await verify_password_reset_token(token)
return success({
"username": user.username,
"email": user.email,
})
@router.post("/password-reset/confirm", summary="使用令牌重置密码")
async def password_reset_confirm(payload: PasswordResetConfirm):
await reset_password_with_token(payload.token, payload.password)
return success(msg="密码已重置")

View File

@@ -1,50 +0,0 @@
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from services.auth import get_current_active_user
from services.backup import BackupService
from models.database import UserAccount
import json
import datetime
router = APIRouter(
prefix="/api/backup",
tags=["Backup & Restore"],
dependencies=[Depends(get_current_active_user)],
)
@router.get("/export", summary="导出全站数据")
async def export_backup():
"""
生成并下载一个包含所有关键数据的JSON文件。
"""
try:
data = await BackupService.export_data()
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
headers = {
"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"
}
return JSONResponse(content=data, headers=headers)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/import", summary="导入数据")
async def import_backup(file: UploadFile = File(...)):
"""
从上传的JSON文件恢复数据。
**警告**: 这将会覆盖所有现有数据!
"""
if not file.filename.endswith(".json"):
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
try:
contents = await file.read()
data = json.loads(contents)
except Exception:
raise HTTPException(status_code=400, detail="无法解析JSON文件")
try:
await BackupService.import_data(data)
return {"message": "数据导入成功。"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"导入失败: {e}")

View File

@@ -1,83 +0,0 @@
import httpx
import time
from fastapi import APIRouter, Depends, Form
from typing import Annotated
from services.config import ConfigCenter, VERSION
from services.auth import get_current_active_user, User, has_users
from api.response import success
router = APIRouter(prefix="/api/config", tags=["config"])
@router.get("/")
async def get_config(
current_user: Annotated[User, Depends(get_current_active_user)],
key: str
):
value = await ConfigCenter.get(key)
return success({"key": key, "value": value})
@router.post("/")
async def set_config(
current_user: Annotated[User, Depends(get_current_active_user)],
key: str = Form(...),
value: str = Form(...)
):
await ConfigCenter.set(key, value)
return success({"key": key, "value": value})
@router.get("/all")
async def get_all_config(
current_user: Annotated[User, Depends(get_current_active_user)]
):
configs = await ConfigCenter.get_all()
return success(configs)
@router.get("/status")
async def get_system_status():
logo = await ConfigCenter.get("APP_LOGO", "/logo.svg")
favicon = await ConfigCenter.get("APP_FAVICON", logo)
system_info = {
"version": VERSION,
"title": await ConfigCenter.get("APP_NAME", "Foxel"),
"logo": logo,
"favicon": favicon,
"is_initialized": await has_users(),
"app_domain": await ConfigCenter.get("APP_DOMAIN"),
"file_domain": await ConfigCenter.get("FILE_DOMAIN"),
}
return success(system_info)
# In-process cache for the GitHub "latest release" lookup.
# timestamp: epoch seconds of the last successful fetch; data: cached payload.
latest_version_cache = {
    "timestamp": 0,
    "data": None
}
@router.get("/latest-version")
async def get_latest_version():
current_time = time.time()
if current_time - latest_version_cache["timestamp"] < 3600 and latest_version_cache["data"]:
return success(latest_version_cache["data"])
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
follow_redirects=True,
)
resp.raise_for_status()
data = resp.json()
version_info = {
"latest_version": data.get("tag_name"),
"body": data.get("body")
}
latest_version_cache["timestamp"] = current_time
latest_version_cache["data"] = version_info
return success(version_info)
except httpx.RequestError as e:
if latest_version_cache["data"]:
return success(latest_version_cache["data"])
return success({"latest_version": None, "body": None})

View File

@@ -1,48 +0,0 @@
from typing import Optional
from fastapi import APIRouter, Query
from models.database import Log
from api.response import page, success
from tortoise.expressions import Q
from datetime import datetime
router = APIRouter(prefix="/api/logs", tags=["Logs"])
@router.get("")
async def get_logs(
page_num: int = Query(1, alias="page"),
page_size: int = Query(20, alias="page_size"),
level: Optional[str] = Query(None),
source: Optional[str] = Query(None),
start_time: Optional[datetime] = Query(None),
end_time: Optional[datetime] = Query(None),
):
"""获取日志列表,支持分页和筛选"""
query = Log.all()
if level:
query = query.filter(level=level)
if source:
query = query.filter(source__icontains=source)
if start_time:
query = query.filter(timestamp__gte=start_time)
if end_time:
query = query.filter(timestamp__lte=end_time)
total = await query.count()
logs = await query.order_by("-timestamp").offset((page_num - 1) * page_size).limit(page_size)
return success(page([log for log in logs], total, page_num, page_size))
@router.delete("")
async def clear_logs(
start_time: Optional[datetime] = Query(None),
end_time: Optional[datetime] = Query(None),
):
"""清理指定时间范围内的日志"""
query = Log.all()
if start_time:
query = query.filter(timestamp__gte=start_time)
if end_time:
query = query.filter(timestamp__lte=end_time)
deleted_count = await query.delete()
return success({"deleted_count": deleted_count})

View File

@@ -1,79 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from api.response import success
from schemas.offline_downloads import OfflineDownloadCreate
from services.auth import User, get_current_active_user
from services.logging import LogService
from services.task_queue import task_queue_service, TaskProgress
from services.virtual_fs import path_is_directory
router = APIRouter(
prefix="/api/offline-downloads",
tags=["OfflineDownloads"],
)
@router.post("/")
async def create_offline_download(
payload: OfflineDownloadCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
dest_dir = payload.dest_dir
try:
is_dir = await path_is_directory(dest_dir)
except HTTPException:
is_dir = False
if not is_dir:
raise HTTPException(400, detail="Destination directory not found")
task = await task_queue_service.add_task(
"offline_http_download",
{
"url": str(payload.url),
"dest_dir": dest_dir,
"filename": payload.filename,
},
)
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="queued",
percent=0.0,
bytes_total=None,
bytes_done=0,
detail="Waiting to start",
),
)
await LogService.action(
"route:offline_downloads",
f"Offline download task created {task.id}",
details={"url": str(payload.url), "dest_dir": dest_dir, "filename": payload.filename},
user_id=current_user.id if hasattr(current_user, "id") else None,
)
return success({"task_id": task.id})
@router.get("/")
async def list_offline_downloads(
current_user: Annotated[User, Depends(get_current_active_user)],
):
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
data = [t.dict() for t in tasks]
return success(data)
@router.get("/{task_id}")
async def get_offline_download(
task_id: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
task = task_queue_service.get_task(task_id)
if not task or task.name != "offline_http_download":
raise HTTPException(status_code=404, detail="Task not found")
return success(task.dict())

View File

@@ -1,73 +0,0 @@
from typing import List, Any, Dict
from fastapi import APIRouter, HTTPException, Body
from models import database
from schemas import PluginCreate, PluginOut
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
@router.post("", response_model=PluginOut)
async def create_plugin(payload: PluginCreate):
rec = await database.Plugin.create(
url=payload.url,
enabled=payload.enabled,
)
return PluginOut.model_validate(rec)
@router.get("", response_model=List[PluginOut])
async def list_plugins():
rows = await database.Plugin.all().order_by("-id")
return [PluginOut.model_validate(r) for r in rows]
@router.delete("/{plugin_id}")
async def delete_plugin(plugin_id: int):
rec = await database.Plugin.get_or_none(id=plugin_id)
if not rec:
raise HTTPException(status_code=404, detail="Plugin not found")
await rec.delete()
return {"code": 0, "msg": "ok"}
@router.put("/{plugin_id}", response_model=PluginOut)
async def update_plugin(plugin_id: int, payload: PluginCreate):
rec = await database.Plugin.get_or_none(id=plugin_id)
if not rec:
raise HTTPException(status_code=404, detail="Plugin not found")
rec.url = payload.url
rec.enabled = payload.enabled
await rec.save()
return PluginOut.model_validate(rec)
@router.post("/{plugin_id}/metadata", response_model=PluginOut)
async def update_manifest(plugin_id: int, manifest: Dict[str, Any] = Body(...)):
rec = await database.Plugin.get_or_none(id=plugin_id)
if not rec:
raise HTTPException(status_code=404, detail="Plugin not found")
key_map = {
'key': 'key',
'name': 'name',
'version': 'version',
'supported_exts': 'supported_exts',
'supportedExts': 'supported_exts',
'default_bounds': 'default_bounds',
'defaultBounds': 'default_bounds',
'default_maximized': 'default_maximized',
'defaultMaximized': 'default_maximized',
'icon': 'icon',
'description': 'description',
'author': 'author',
'website': 'website',
'github': 'github',
}
for k, v in list(manifest.items()):
if v is None:
continue
attr = key_map.get(k)
if not attr:
continue
setattr(rec, attr, v)
await rec.save()
return PluginOut.model_validate(rec)

View File

@@ -1,250 +0,0 @@
from pathlib import Path
from fastapi import APIRouter, Depends, Body, HTTPException
from fastapi.concurrency import run_in_threadpool
from typing import Annotated
from services.processors.registry import (
get,
get_config_schema,
get_config_schemas,
get_module_path,
reload_processors,
)
from services.task_queue import task_queue_service
from services.auth import get_current_active_user, User
from api.response import success
from pydantic import BaseModel
from services.virtual_fs import path_is_directory, resolve_adapter_and_rel
from typing import List, Optional, Tuple
router = APIRouter(prefix="/api/processors", tags=["processors"])
@router.get("")
async def list_processors(
current_user: Annotated[User, Depends(get_current_active_user)]
):
schemas = get_config_schemas()
out = []
for t, meta in schemas.items():
out.append({
"type": meta["type"],
"name": meta["name"],
"supported_exts": meta.get("supported_exts", []),
"config_schema": meta["config_schema"],
"produces_file": meta.get("produces_file", False),
"module_path": meta.get("module_path"),
})
return success(out)
# Request body for processing a single path.
class ProcessRequest(BaseModel):
    path: str
    processor_type: str
    config: dict
    # Target path for the output; ignored when overwrite is True.
    save_to: str | None = None
    overwrite: bool = False
# Request body for recursively processing a whole directory.
class ProcessDirectoryRequest(BaseModel):
    path: str
    processor_type: str
    config: dict
    overwrite: bool = True
    # None means unlimited recursion depth.
    max_depth: Optional[int] = None
    # Filename suffix for outputs when not overwriting in place.
    suffix: Optional[str] = None
# Request body carrying the full replacement source code of a processor module.
class UpdateSourceRequest(BaseModel):
    source: str
@router.post("/process")
async def process_file_with_processor(
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessRequest = Body(...)
):
is_dir = await path_is_directory(req.path)
if is_dir and not req.overwrite:
raise HTTPException(400, detail="Directory processing requires overwrite")
save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
task = await task_queue_service.add_task(
"process_file",
{
"path": req.path,
"processor_type": req.processor_type,
"config": req.config,
"save_to": save_to,
"overwrite": req.overwrite,
},
)
return success({"task_id": task.id})
@router.post("/process-directory")
async def process_directory_with_processor(
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessDirectoryRequest = Body(...)
):
if req.max_depth is not None and req.max_depth < 0:
raise HTTPException(400, detail="max_depth must be >= 0")
is_dir = await path_is_directory(req.path)
if not is_dir:
raise HTTPException(400, detail="Path must be a directory")
schema = get_config_schema(req.processor_type)
_processor = get(req.processor_type)
if not schema or not _processor:
raise HTTPException(404, detail="Processor not found")
produces_file = bool(schema.get("produces_file"))
raw_suffix = req.suffix if req.suffix is not None else None
if raw_suffix is not None and raw_suffix.strip() == "":
raw_suffix = None
suffix = raw_suffix
overwrite = req.overwrite
if produces_file:
if not overwrite and not suffix:
raise HTTPException(400, detail="Suffix is required when not overwriting files")
else:
overwrite = False
suffix = None
supported_exts = schema.get("supported_exts") or []
allowed_exts = {
ext.lower().lstrip('.')
for ext in supported_exts
if isinstance(ext, str)
}
def matches_extension(file_rel: str) -> bool:
if not allowed_exts:
return True
if '.' not in file_rel:
return '' in allowed_exts
ext = file_rel.rsplit('.', 1)[-1].lower()
return ext in allowed_exts or f'.{ext}' in allowed_exts
adapter_instance, adapter_model, root, rel = await resolve_adapter_and_rel(req.path)
rel = rel.rstrip('/')
list_dir = getattr(adapter_instance, "list_dir", None)
if not callable(list_dir):
raise HTTPException(501, detail="Adapter does not implement list_dir")
def build_absolute_path(mount_path: str, rel_path: str) -> str:
rel_norm = rel_path.lstrip('/')
mount_norm = mount_path.rstrip('/')
if not mount_norm:
return '/' + rel_norm if rel_norm else '/'
return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm
def apply_suffix(path_str: str, suffix_str: str) -> str:
path_obj = Path(path_str)
name = path_obj.name
if not name:
return path_str
if '.' in name:
base, ext = name.rsplit('.', 1)
new_name = f"{base}{suffix_str}.{ext}"
else:
new_name = f"{name}{suffix_str}"
return str(path_obj.with_name(new_name))
scheduled_tasks: List[str] = []
stack: List[Tuple[str, int]] = [(rel, 0)]
page_size = 200
while stack:
current_rel, depth = stack.pop()
page = 1
while True:
entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
entries = entries or []
if not entries and (total or 0) == 0:
break
for entry in entries:
name = entry.get("name")
if not name:
continue
child_rel = f"{current_rel}/{name}" if current_rel else name
if entry.get("is_dir"):
if req.max_depth is None or depth < req.max_depth:
stack.append((child_rel.rstrip('/'), depth + 1))
continue
if not matches_extension(child_rel):
continue
absolute_path = build_absolute_path(adapter_model.path, child_rel)
save_to = None
if produces_file and not overwrite and suffix:
save_to = apply_suffix(absolute_path, suffix)
task = await task_queue_service.add_task(
"process_file",
{
"path": absolute_path,
"processor_type": req.processor_type,
"config": req.config,
"save_to": save_to,
"overwrite": overwrite,
},
)
scheduled_tasks.append(task.id)
if total is None or page * page_size >= total:
break
page += 1
return success({
"task_ids": scheduled_tasks,
"scheduled": len(scheduled_tasks),
})
@router.get("/source/{processor_type}")
async def get_processor_source(
    processor_type: str,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return the on-disk source code of a registered processor module."""
    module_path = get_module_path(processor_type)
    if not module_path:
        raise HTTPException(404, detail="Processor not found")
    source_file = Path(module_path)
    if not source_file.exists():
        raise HTTPException(404, detail="Processor source not found")
    try:
        # Read off-thread so a large module file cannot block the event loop.
        content = await run_in_threadpool(source_file.read_text, encoding='utf-8')
    except Exception as exc:
        raise HTTPException(500, detail=f"Failed to read source: {exc}")
    return success({"source": content, "module_path": str(source_file)})
@router.put("/source/{processor_type}")
async def update_processor_source(
    processor_type: str,
    req: UpdateSourceRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Overwrite a processor module's source file on disk with the submitted text."""
    module_path = get_module_path(processor_type)
    if not module_path:
        raise HTTPException(404, detail="Processor not found")
    source_file = Path(module_path)
    if not source_file.exists():
        raise HTTPException(404, detail="Processor source not found")
    try:
        # Write off-thread to keep the event loop responsive.
        await run_in_threadpool(source_file.write_text, req.source, encoding='utf-8')
    except Exception as exc:
        raise HTTPException(500, detail=f"Failed to write source: {exc}")
    return success(True)
@router.post("/reload")
async def reload_processor_modules(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Re-import all processor modules; any import errors are joined into one 500."""
    errors = reload_processors()
    if errors:
        raise HTTPException(500, detail="; ".join(errors))
    return success(True)

View File

@@ -1,217 +0,0 @@
from typing import List, Optional
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from api.response import success
from services.auth import User, get_current_active_user
from services.share import share_service
from services.virtual_fs import stream_file, stat_file
from models.database import ShareLink, UserAccount
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
class ShareCreate(BaseModel):
    """Request body for creating a share link."""
    name: str
    paths: List[str]
    # Days until expiry; defaults to one week. None presumably means
    # "never expires" — confirm against share_service.create_share_link.
    expires_in_days: Optional[int] = 7
    # "public" or "password" (see the password checks in the public routes).
    access_type: str = "public"
    password: Optional[str] = None
class ShareInfo(BaseModel):
    """Public view of a share link; never exposes the password hash."""
    id: int
    token: str
    name: str
    paths: List[str]
    # Timestamps are flattened to ISO-8601 strings by from_orm below.
    created_at: str
    expires_at: Optional[str] = None
    access_type: str

    @classmethod
    def from_orm(cls, obj: ShareLink) -> "ShareInfo":
        # NOTE(review): intentionally shadows pydantic's legacy ``from_orm``
        # name; this variant serialises datetimes to ISO strings itself.
        return cls(
            id=obj.id,
            token=obj.token,
            name=obj.name,
            paths=obj.paths,
            created_at=obj.created_at.isoformat(),
            expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
            access_type=obj.access_type,
        )
class ShareInfoWithPassword(ShareInfo):
    """ShareInfo plus the plaintext password, echoed back once at creation time."""
    password: Optional[str] = None
# --- Management Routes ---
@router.post("", response_model=ShareInfoWithPassword)
async def create_share(
    payload: ShareCreate,
    current_user: User = Depends(get_current_active_user),
):
    """Create a new share link owned by the current user."""
    account = await UserAccount.get(id=current_user.id)
    share = await share_service.create_share_link(
        user=account,
        name=payload.name,
        paths=payload.paths,
        expires_in_days=payload.expires_in_days,
        access_type=payload.access_type,
        password=payload.password,
    )
    data = ShareInfo.from_orm(share).model_dump()
    # Echo the plaintext password back exactly once so the creator can
    # distribute it; it is never returned again.
    if payload.access_type == "password" and payload.password:
        data['password'] = payload.password
    return data
@router.get("", response_model=List[ShareInfo])
async def get_my_shares(current_user: User = Depends(get_current_active_user)):
    """List every share link owned by the current user."""
    account = await UserAccount.get(id=current_user.id)
    records = await share_service.get_user_shares(user=account)
    return [ShareInfo.from_orm(record) for record in records]
@router.delete("/expired")
async def delete_expired_shares(
    current_user: User = Depends(get_current_active_user),
):
    """Remove all of the current user's expired share links."""
    account = await UserAccount.get(id=current_user.id)
    removed = await share_service.delete_expired_shares(user=account)
    return success({"deleted_count": removed})
@router.delete("/{share_id}")
async def delete_share(
    share_id: int,
    current_user: User = Depends(get_current_active_user),
):
    """
    Delete a single share link owned by the current user.
    """
    # NOTE(review): unlike the sibling routes this passes the auth ``User``
    # directly instead of fetching the ``UserAccount`` row first — confirm
    # share_service.delete_share_link accepts both.
    await share_service.delete_share_link(user=current_user, share_id=share_id)
    return success(msg="分享已取消")
# --- Public Routes ---
class SharePassword(BaseModel):
    """Request body carrying the password for a protected share."""
    password: str
@public_router.post("/{token}/verify")
async def verify_password(token: str, payload: SharePassword):
    """Check the supplied password against a password-protected share."""
    share = await share_service.get_share_by_token(token)
    if share.access_type != "password":
        raise HTTPException(status_code=400, detail="此分享不需要密码")
    matched = share_service._verify_password(payload.password, share.hashed_password)
    if not matched:
        raise HTTPException(status_code=403, detail="密码错误")
    # Deliberately stateless: instead of issuing a short-lived token here,
    # the client re-sends the password (or a session marker) on later calls.
    return success(msg="验证成功")
@public_router.get("/{token}/ls")
async def list_share_content(token: str, path: str = "/", password: Optional[str] = None):
    """List files and directories inside a share, enforcing its password if any."""
    share = await share_service.get_share_by_token(token)
    if share.access_type == "password":
        if not password:
            raise HTTPException(status_code=401, detail="需要密码")
        if not share_service._verify_password(password, share.hashed_password):
            raise HTTPException(status_code=403, detail="密码错误")
    listing = await share_service.get_shared_item_details(share, path)
    pagination = {
        "total": listing.get("total", 0),
        "page": listing.get("page", 1),
        "page_size": listing.get("page_size", 1),
        "pages": listing.get("pages", 1),
    }
    return success({
        "path": path,
        "entries": listing.get("items", []),
        "pagination": pagination,
    })
@public_router.get("/{token}")
async def get_share_info(token: str):
    """Return public metadata for a share link."""
    record = await share_service.get_share_by_token(token)
    return success(ShareInfo.from_orm(record))
@public_router.get("/{token}/download")
async def download_shared_file(token: str, path: str, request: Request, password: Optional[str] = None):
    """
    Download a single file from a share link, with Range support.
    """
    # Reject traversal attempts before touching the share at all.
    if not path or path == "/" or ".." in path.split('/'):
        raise HTTPException(status_code=400, detail="无效的文件路径")
    share = await share_service.get_share_by_token(token)
    if share.access_type == "password":
        if not password:
            raise HTTPException(status_code=401, detail="需要密码")
        if not share_service._verify_password(password, share.hashed_password):
            raise HTTPException(status_code=403, detail="密码错误")
    base_shared_path = share.paths[0]
    # Determine whether the shared root is a file or a directory. stat_file
    # may signal "is a directory" via an HTTPException instead of a flag,
    # so both channels are checked.
    is_dir = False
    try:
        stat = await stat_file(base_shared_path)
        if stat and stat.get("is_dir"):
            is_dir = True
    except HTTPException as e:
        if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
            is_dir = True
        else:
            # The shared path itself doesn't exist, which is an issue.
            raise HTTPException(status_code=404, detail="分享的源文件不存在")
    if is_dir:
        # Directory share: join the requested sub-path onto the shared root.
        full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
        if not full_virtual_path.startswith(base_shared_path):
            raise HTTPException(status_code=403, detail="无权访问此路径")
    else:
        # File share: the request must name exactly the shared file.
        shared_filename = base_shared_path.split('/')[-1]
        request_filename = path.lstrip('/')
        if shared_filename != request_filename:
            raise HTTPException(status_code=403, detail="无权访问此路径")
        full_virtual_path = base_shared_path
    range_header = request.headers.get("Range")
    response = await stream_file(full_virtual_path, range_header)
    # Force a download with an RFC 5987 encoded filename.
    filename = full_virtual_path.split('/')[-1]
    response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
    return response

View File

@@ -1,141 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated
from models.database import AutomationTask
from schemas.tasks import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
TaskQueueSettingsResponse,
)
from api.response import success
from services.auth import get_current_active_user, User
from services.logging import LogService
from services.task_queue import task_queue_service
from services.config import ConfigCenter
router = APIRouter(
prefix="/api/tasks",
tags=["Tasks"],
dependencies=[Depends(get_current_active_user)],
responses={404: {"description": "Not found"}},
)
@router.get("/queue")
async def get_task_queue_status(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Snapshot every task currently known to the in-process queue."""
    snapshot = [item.dict() for item in task_queue_service.get_all_tasks()]
    return success(snapshot)
@router.get("/queue/settings")
async def get_task_queue_settings(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Report the queue's configured concurrency and live worker count."""
    snapshot = TaskQueueSettingsResponse(
        concurrency=task_queue_service.get_concurrency(),
        active_workers=task_queue_service.get_active_worker_count(),
    )
    return success(snapshot.model_dump())
@router.post("/queue/settings")
async def update_task_queue_settings(
    settings: TaskQueueSettings,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Apply a new concurrency limit, persist it, audit it, and report the result."""
    await task_queue_service.set_concurrency(settings.concurrency)
    # Persist the value the service actually accepted, not the raw input.
    applied = task_queue_service.get_concurrency()
    await ConfigCenter.set("TASK_QUEUE_CONCURRENCY", str(applied))
    await LogService.action(
        "route:tasks",
        "Updated task queue settings",
        details={"concurrency": settings.concurrency},
        user_id=getattr(current_user, "id", None),
    )
    snapshot = TaskQueueSettingsResponse(
        concurrency=task_queue_service.get_concurrency(),
        active_workers=task_queue_service.get_active_worker_count(),
    )
    return success(snapshot.model_dump())
@router.get("/queue/{task_id}")
async def get_task_status(
    task_id: str,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Fetch a single queued task by id; 404 when unknown."""
    found = task_queue_service.get_task(task_id)
    if not found:
        raise HTTPException(status_code=404, detail="Task not found")
    return success(found.dict())
@router.post("/")
async def create_task(
    task_in: AutomationTaskCreate,
    user: User = Depends(get_current_active_user)
):
    """Persist a new automation task and audit its creation."""
    payload = task_in.model_dump()
    task = await AutomationTask.create(**payload)
    await LogService.action(
        "route:tasks",
        f"Created task {task.name}",
        details=payload,
        user_id=user.id if hasattr(user, "id") else None,
    )
    return success(task)
@router.get("/{task_id}")
async def get_task(task_id: int):
    """Fetch one automation task by id; 404 when missing."""
    record = await AutomationTask.get_or_none(id=task_id)
    if not record:
        raise HTTPException(
            status_code=404, detail=f"Task {task_id} not found")
    return success(record)
@router.get("/")
async def list_tasks():
    """Return every automation task."""
    return success(await AutomationTask.all())
@router.put("/{task_id}")
async def update_task(
    current_user: Annotated[User, Depends(get_current_active_user)],
    task_id: int, task_in: AutomationTaskUpdate):
    """Partially update an automation task.

    Only fields explicitly present in the request body are applied
    (``exclude_unset``); unknown ids yield a 404.
    """
    task = await AutomationTask.get_or_none(id=task_id)
    if not task:
        raise HTTPException(
            status_code=404, detail=f"Task {task_id} not found")
    update_data = task_in.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(task, key, value)
    await task.save()
    await LogService.action(
        "route:tasks",
        f"Updated task {task.name}",
        # Log only the fields that were actually applied, matching the
        # exclude_unset semantics above. Previously this logged the full
        # model dump, recording defaults for fields the caller never sent.
        details=update_data,
        user_id=current_user.id,
    )
    return success(task)
@router.delete("/{task_id}")
async def delete_task(
    task_id: int,
    user: User = Depends(get_current_active_user)
):
    """Delete an automation task and audit the removal; 404 when missing."""
    removed = await AutomationTask.filter(id=task_id).delete()
    if not removed:
        raise HTTPException(
            status_code=404, detail=f"Task {task_id} not found")
    await LogService.action(
        "route:tasks",
        f"Deleted task {task_id}",
        details={"task_id": task_id},
        user_id=user.id if hasattr(user, "id") else None,
    )
    return success(msg="Task deleted")

View File

@@ -1,91 +0,0 @@
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from services.auth import get_current_active_user
from models.database import UserAccount
from services.vector_db import (
VectorDBService,
VectorDBConfigManager,
list_providers,
get_provider_entry,
)
from services.vector_db.providers import get_provider_class
from api.response import success
router = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
class VectorDBConfigPayload(BaseModel):
    """Request body for selecting and configuring a vector-database provider."""
    type: str = Field(..., description="向量数据库提供者类型")
    config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
@router.post("/clear-all", summary="清空向量数据库")
async def clear_vector_db(user: UserAccount = Depends(get_current_active_user)):
    """Wipe all data from the configured vector database."""
    try:
        await VectorDBService().clear_all_data()
        return success(msg="向量数据库已清空")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/stats", summary="获取向量数据库统计")
async def get_vector_db_stats(user: UserAccount = Depends(get_current_active_user)):
    """Collect statistics from the configured vector database."""
    try:
        stats = await VectorDBService().get_all_stats()
        return success(data=stats)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/providers", summary="列出可用向量数据库提供者")
async def list_vector_providers(user: UserAccount = Depends(get_current_active_user)):
    """Enumerate the vector-database provider types this build supports."""
    return success(list_providers())
@router.get("/config", summary="获取当前向量数据库配置")
async def get_vector_db_config(user: UserAccount = Depends(get_current_active_user)):
    """Return information about the currently configured vector-database provider."""
    service = VectorDBService()
    data = await service.current_provider()
    return success(data)
@router.post("/config", summary="更新向量数据库配置")
async def update_vector_db_config(payload: VectorDBConfigPayload, user: UserAccount = Depends(get_current_active_user)):
    """Validate, persist, and hot-reload the vector-database configuration."""
    entry = get_provider_entry(payload.type)
    if not entry:
        raise HTTPException(
            status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
    if not entry.get("enabled", True):
        raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
    provider_cls = get_provider_class(payload.type)
    if not provider_cls:
        raise HTTPException(
            status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
    # Probe the connection first so invalid settings are rejected before
    # anything is persisted.
    test_provider = provider_cls(payload.config)
    try:
        await test_provider.initialize()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    finally:
        # Best-effort cleanup of the probe client; close failures are ignored.
        client = getattr(test_provider, "client", None)
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass
    # Only after a successful probe: save, reload the live service, and
    # return the new config alongside fresh stats.
    await VectorDBConfigManager.save_config(payload.type, payload.config)
    service = VectorDBService()
    await service.reload()
    config_data = await service.current_provider()
    stats = await service.get_all_stats()
    return success({"config": config_data, "stats": stats})

View File

@@ -1,376 +0,0 @@
from fastapi import APIRouter, UploadFile, File, HTTPException, Response, Query, Request, Depends
import mimetypes
import re
from typing import Annotated
from services.auth import get_current_active_user, User
from services.virtual_fs import (
list_virtual_dir,
read_file,
write_file,
make_dir,
delete_path,
move_path,
resolve_adapter_and_rel,
stream_file,
generate_temp_link_token,
verify_temp_link_token,
maybe_redirect_download,
)
from services.thumbnail import is_image_filename, get_or_create_thumb, is_raw_filename, is_video_filename
from schemas import MkdirRequest, MoveRequest
from api.response import success
from services.config import ConfigCenter
router = APIRouter(prefix='/api/fs', tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
async def get_file(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Serve a file's contents with basic byte-range support.

    RAW camera files are decoded to JPEG on the fly; for other files the
    adapter may answer with a redirect (e.g. a presigned URL) before the
    content is read fully into memory.
    """
    # Normalise to an absolute virtual path.
    full_path = '/' + full_path if not full_path.startswith('/') else full_path
    if is_raw_filename(full_path):
        # Heavy optional deps imported lazily so non-RAW requests never pay for them.
        import rawpy
        from PIL import Image
        import io
        try:
            raw_data = await read_file(full_path)
            with rawpy.imread(io.BytesIO(raw_data)) as raw:
                rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
            im = Image.fromarray(rgb)
            buf = io.BytesIO()
            im.save(buf, 'JPEG', quality=90)
            content = buf.getvalue()
            return Response(content=content, media_type='image/jpeg')
        except FileNotFoundError:
            raise HTTPException(404, detail="File not found")
        except Exception as e:
            raise HTTPException(500, detail=f"RAW file processing failed: {e}")
    adapter_instance, adapter_model, root, rel = await resolve_adapter_and_rel(full_path)
    # Give the adapter a chance to answer with a redirect instead of proxying bytes.
    redirect_response = await maybe_redirect_download(adapter_instance, adapter_model, root, rel)
    if redirect_response is not None:
        return redirect_response
    try:
        content = await read_file(full_path)
    except FileNotFoundError:
        raise HTTPException(404, detail="File not found")
    if not isinstance(content, (bytes, bytearray)):
        return Response(content=content, media_type="application/octet-stream")
    content_length = len(content)
    content_type = mimetypes.guess_type(
        full_path)[0] or "application/octet-stream"
    range_header = request.headers.get('Range')
    if range_header:
        # Only "bytes=start-end" / "bytes=start-" forms are recognised;
        # suffix ranges ("bytes=-N") fall through to a full 200 response.
        range_match = re.match(r'bytes=(\d+)-(\d*)', range_header)
        if range_match:
            start = int(range_match.group(1))
            end = int(range_match.group(2)) if range_match.group(
                2) else content_length - 1
            # Clamp both ends into the valid byte range; end never precedes start.
            start = max(0, min(start, content_length - 1))
            end = max(start, min(end, content_length - 1))
            chunk = content[start:end + 1]
            chunk_size = len(chunk)
            headers = {
                'Content-Range': f'bytes {start}-{end}/{content_length}',
                'Accept-Ranges': 'bytes',
                'Content-Length': str(chunk_size),
                'Content-Type': content_type,
            }
            return Response(
                content=chunk,
                status_code=206,
                headers=headers
            )
    headers = {
        'Accept-Ranges': 'bytes',
        'Content-Length': str(content_length),
        'Content-Type': content_type,
    }
    if content_type.startswith('video/'):
        # Let players cache video content for an hour.
        headers['Cache-Control'] = 'public, max-age=3600'
    return Response(content=content, headers=headers)
@router.get("/thumb/{full_path:path}")
async def get_thumb(
    full_path: str,
    w: int = Query(256, ge=8, le=1024),
    h: int = Query(256, ge=8, le=1024),
    fit: str = Query("cover"),
):
    """Return a cached (or freshly generated) thumbnail for an image/video.

    NOTE(review): unlike the other fs routes this endpoint declares no auth
    dependency — confirm thumbnails are intentionally public.
    """
    full_path = '/' + full_path if not full_path.startswith('/') else full_path
    if fit not in ("cover", "contain"):
        raise HTTPException(400, detail="fit must be cover|contain")
    adapter, mount, root, rel = await resolve_adapter_and_rel(full_path)
    # An empty rel or trailing slash denotes a directory, not a file.
    if not rel or rel.endswith('/'):
        raise HTTPException(400, detail="Not a file")
    if not (is_image_filename(rel) or is_video_filename(rel)):
        raise HTTPException(404, detail="Not an image or video")
    # type: ignore
    data, mime, key = await get_or_create_thumb(adapter, mount.id, root, rel, w, h, fit)
    # The thumbnail cache key doubles as an ETag so clients can revalidate.
    headers = {
        'Cache-Control': 'public, max-age=3600',
        'ETag': key,
    }
    return Response(content=data, media_type=mime, headers=headers)
@router.get("/stream/{full_path:path}")
async def stream_endpoint(
    full_path: str,
    request: Request,
):
    """Range-aware streaming read for video/large files, delegating Range
    handling to the underlying adapter when it supports it."""
    if not full_path.startswith('/'):
        full_path = '/' + full_path
    try:
        return await stream_file(full_path, request.headers.get('Range'))
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(404, detail="File not found")
    except Exception as e:
        raise HTTPException(500, detail=f"Stream error: {e}")
@router.get("/temp-link/{full_path:path}")
async def get_temp_link(
    full_path: str,
    current_user: Annotated[User, Depends(get_current_active_user)],
    expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久")
):
    """Issue a temporary public-access token (and URL) for a file."""
    if not full_path.startswith('/'):
        full_path = '/' + full_path
    token = await generate_temp_link_token(full_path, expires_in=expires_in)
    relative_url = f"/api/fs/public/{token}"
    file_domain = await ConfigCenter.get("FILE_DOMAIN")
    if file_domain:
        # Prefix with the configured public domain when one is set.
        url = f"{file_domain.rstrip('/')}{relative_url}"
    else:
        url = relative_url
    return success({"token": token, "path": full_path, "url": url})
@router.get("/public/{token}")
async def access_public_file(
token: str,
request: Request,
):
"""通过令牌公开访问文件,支持 Range 请求"""
try:
path = await verify_temp_link_token(token)
except HTTPException as e:
raise e
range_header = request.headers.get('Range')
try:
return await stream_file(path, range_header)
except FileNotFoundError:
raise HTTPException(404, detail="File not found via token")
except Exception as e:
raise HTTPException(500, detail=f"File access error: {e}")
@router.get("/stat/{full_path:path}")
async def get_file_stat(
    full_path: str,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Return metadata for a single virtual path."""
    if not full_path.startswith('/'):
        full_path = '/' + full_path
    from services.virtual_fs import stat_file
    return success(await stat_file(full_path))
@router.post("/file/{full_path:path}")
async def put_file(
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    file: UploadFile = File(...)
):
    """Upload a file in one shot (the whole payload is buffered in memory)."""
    if not full_path.startswith('/'):
        full_path = '/' + full_path
    payload = await file.read()
    await write_file(full_path, payload)
    return success({"written": True, "path": full_path, "size": len(payload)})
@router.post("/mkdir")
async def api_mkdir(
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MkdirRequest
):
    """Create a directory at the given virtual path; the root itself is rejected."""
    target = body.path
    if not target.startswith('/'):
        target = '/' + target
    if not target or target == '/':
        raise HTTPException(400, detail="Invalid path")
    await make_dir(target)
    return success({"created": True, "path": target})
@router.post("/move")
async def api_move(
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
    """Move a path; cross-adapter moves may be queued as a background task."""
    source = body.src if body.src.startswith('/') else '/' + body.src
    target = body.dst if body.dst.startswith('/') else '/' + body.dst
    outcome = await move_path(source, target, overwrite=overwrite, return_debug=True, allow_cross=True)
    deferred = bool(outcome.get("queued"))
    payload = {
        "moved": not deferred,
        "queued": deferred,
        "src": source,
        "dst": target,
        "overwrite": overwrite,
    }
    if deferred:
        # Expose the queued task so the client can poll its progress.
        payload["task_id"] = outcome.get("task_id")
        payload["task_name"] = outcome.get("task_name")
    return success(payload)
@router.post("/rename")
async def api_rename(
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否允许覆盖已存在目标")
):
    """Rename a path in place."""
    from services.virtual_fs import rename_path
    source = body.src if body.src.startswith('/') else '/' + body.src
    target = body.dst if body.dst.startswith('/') else '/' + body.dst
    await rename_path(source, target, overwrite=overwrite, return_debug=False)
    return success({
        "renamed": True,
        "src": source,
        "dst": target,
        "overwrite": overwrite,
    })
@router.post("/copy")
async def api_copy(
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否覆盖已存在目标"),
):
    """Copy a path; cross-adapter copies may be queued as a background task."""
    from services.virtual_fs import copy_path
    source = body.src if body.src.startswith('/') else '/' + body.src
    target = body.dst if body.dst.startswith('/') else '/' + body.dst
    outcome = await copy_path(source, target, overwrite=overwrite, return_debug=True, allow_cross=True)
    deferred = bool(outcome.get("queued"))
    payload = {
        "copied": not deferred,
        "queued": deferred,
        "src": source,
        "dst": target,
        "overwrite": overwrite,
    }
    if deferred:
        # Expose the queued task so the client can poll its progress.
        payload["task_id"] = outcome.get("task_id")
        payload["task_name"] = outcome.get("task_name")
    return success(payload)
@router.post("/upload/{full_path:path}")
async def upload_stream(
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    file: UploadFile = File(...),
    overwrite: bool = Query(True, description="是否覆盖已存在文件"),
    chunk_size: int = Query(1024 * 1024, ge=8 * 1024,
                            le=8 * 1024 * 1024, description="单次读取块大小")
):
    """Stream an upload to the backing adapter in fixed-size chunks."""
    full_path = '/' + full_path if not full_path.startswith('/') else full_path
    if full_path.endswith('/'):
        raise HTTPException(400, detail="Path must be a file")
    from services.virtual_fs import write_file_stream, resolve_adapter_and_rel
    adapter, _m, root, rel = await resolve_adapter_and_rel(full_path)
    # Pre-check existence only when the adapter offers exists(); this is
    # deliberately best-effort, so probe failures fall through to the write.
    exists_func = getattr(adapter, "exists", None)
    if not overwrite and callable(exists_func):
        try:
            if await exists_func(root, rel):
                raise HTTPException(409, detail="Destination exists")
        except HTTPException:
            raise
        except Exception:
            pass
    async def gen():
        # Pull the upload in chunk_size pieces so large files never
        # sit fully in memory.
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            yield chunk
    size = await write_file_stream(full_path, gen(), overwrite=overwrite)
    return success({"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite})
@router.get("/{full_path:path}")
async def browse_fs(
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    page_num: int = Query(1, alias="page", ge=1, description="页码"),
    page_size: int = Query(50, ge=1, le=500, description="每页条数"),
    sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
    sort_order: str = Query("asc", description="排序顺序: asc, desc")
):
    """Paginated, sortable directory listing for any virtual path."""
    if not full_path.startswith('/'):
        full_path = '/' + full_path
    listing = await list_virtual_dir(full_path, page_num, page_size, sort_by, sort_order)
    pagination = {
        "total": listing["total"],
        "page": listing["page"],
        "page_size": listing["page_size"],
        "pages": listing["pages"],
    }
    return success({"path": full_path, "entries": listing["items"], "pagination": pagination})
@router.delete("/{full_path:path}")
async def api_delete(
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str
):
    """Delete a file or directory at the given virtual path."""
    if not full_path.startswith('/'):
        full_path = '/' + full_path
    await delete_path(full_path)
    return success({"deleted": True, "path": full_path})
@router.get("/")
async def root_listing(
    current_user: Annotated[User, Depends(get_current_active_user)],
    page_num: int = Query(1, alias="page", ge=1, description="页码"),
    page_size: int = Query(50, ge=1, le=500, description="每页条数"),
    sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
    sort_order: str = Query("asc", description="排序顺序: asc, desc")
):
    """Paginated listing of the virtual root."""
    listing = await list_virtual_dir("/", page_num, page_size, sort_by, sort_order)
    pagination = {
        "total": listing["total"],
        "page": listing["page"],
        "page_size": listing["page_size"],
        "pages": listing["pages"],
    }
    return success({"path": "/", "entries": listing["items"], "pagination": pagination})

View File

@@ -1,6 +1,6 @@
from tortoise import Tortoise
from services.adapters.registry import runtime_registry
from domain.adapters.registry import runtime_registry
TORTOISE_ORM = {
"connections": {"default": "sqlite://data/db/db.sqlite3"},

View File

@@ -0,0 +1 @@

85
domain/adapters/api.py Normal file
View File

@@ -0,0 +1,85 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.adapters.service import AdapterService
from domain.adapters.types import AdapterCreate
from domain.auth.service import get_current_active_user
from domain.auth.types import User
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
@router.post("")
@audit(
    action=AuditAction.CREATE,
    description="创建存储适配器",
    body_fields=["name", "type", "path", "sub_path", "enabled"],
)
async def create_adapter(
    request: Request,
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Create a storage adapter from the submitted definition."""
    record = await AdapterService.create_adapter(data, current_user)
    return success(record)
@router.get("")
@audit(action=AuditAction.READ, description="获取适配器列表")
async def list_adapters(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """List every configured storage adapter."""
    return success(await AdapterService.list_adapters())
@router.get("/available")
@audit(action=AuditAction.READ, description="获取可用适配器类型")
async def available_adapter_types(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Enumerate the adapter types available on this server."""
    return success(await AdapterService.available_adapter_types())
@router.get("/{adapter_id}")
@audit(action=AuditAction.READ, description="获取适配器详情")
async def get_adapter(
    request: Request,
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Fetch one adapter's details by id."""
    record = await AdapterService.get_adapter(adapter_id)
    return success(record)
@router.put("/{adapter_id}")
@audit(
    action=AuditAction.UPDATE,
    description="更新存储适配器",
    body_fields=["name", "type", "path", "sub_path", "enabled"],
)
async def update_adapter(
    request: Request,
    adapter_id: int,
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Replace an adapter's configuration with the submitted definition."""
    record = await AdapterService.update_adapter(adapter_id, data, current_user)
    return success(record)
@router.delete("/{adapter_id}")
@audit(action=AuditAction.DELETE, description="删除存储适配器")
async def delete_adapter(
    request: Request,
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Delete a storage adapter and report the outcome."""
    outcome = await AdapterService.delete_adapter(adapter_id, current_user)
    return success(outcome)

View File

@@ -0,0 +1,3 @@
from .base import BaseAdapter
__all__ = ["BaseAdapter"]

View File

@@ -10,7 +10,6 @@ from ftplib import FTP, error_perm
import mimetypes
from models import StorageAdapter
from services.logging import LogService
def _join_remote(root: str, rel: str) -> str:
@@ -240,11 +239,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_write)
await LogService.info(
"adapter:ftp",
f"Wrote file to {rel}",
details={"adapter_id": self.record.id, "path": path, "size": len(data)},
)
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
# KISS: 聚合后一次性写入
@@ -276,7 +270,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_mkdir)
await LogService.info("adapter:ftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})
async def delete(self, root: str, rel: str):
path = _join_remote(root, rel)
@@ -340,7 +333,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_delete)
await LogService.info("adapter:ftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})
async def move(self, root: str, src_rel: str, dst_rel: str):
src = _join_remote(root, src_rel)
@@ -367,7 +359,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_move)
await LogService.info("adapter:ftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
@@ -402,10 +393,6 @@ class FTPAdapter:
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
await self.copy(root, child_src, child_dst, overwrite)
await LogService.info(
"adapter:ftp", f"Copied directory {src_rel} to {dst_rel}",
details={"adapter_id": self.record.id, "src": src, "dst": dst}
)
return
# file
@@ -418,7 +405,6 @@ class FTPAdapter:
except FileNotFoundError:
pass
await self.write_file(root, dst_rel, data)
await LogService.info("adapter:ftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def stat_file(self, root: str, rel: str):
path = _join_remote(root, rel)

View File

@@ -10,7 +10,6 @@ import mimetypes
from fastapi import HTTPException
from fastapi.responses import StreamingResponse, Response
from models import StorageAdapter
from services.logging import LogService
def _safe_join(root: str, rel: str) -> Path:
@@ -115,11 +114,6 @@ class LocalAdapter:
await asyncio.to_thread(fp.write_bytes, data)
if not pre_exists:
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
await LogService.info(
"adapter:local",
f"Wrote file to {rel}",
details={"adapter_id": self.record.id, "path": str(fp), "size": len(data)},
)
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
fp = _safe_join(root, rel)
@@ -140,21 +134,11 @@ class LocalAdapter:
await asyncio.to_thread(f.close)
if not pre_exists:
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
await LogService.info(
"adapter:local",
f"Wrote file stream to {rel}",
details={"adapter_id": self.record.id, "path": str(fp), "size": size},
)
return size
async def mkdir(self, root: str, rel: str):
fp = _safe_join(root, rel)
await asyncio.to_thread(os.makedirs, fp, mode=DEFAULT_DIR_MODE, exist_ok=True)
await LogService.info(
"adapter:local",
f"Created directory {rel}",
details={"adapter_id": self.record.id, "path": str(fp)},
)
async def delete(self, root: str, rel: str):
fp = _safe_join(root, rel)
@@ -164,11 +148,6 @@ class LocalAdapter:
await asyncio.to_thread(shutil.rmtree, fp)
else:
await asyncio.to_thread(fp.unlink)
await LogService.info(
"adapter:local",
f"Deleted {rel}",
details={"adapter_id": self.record.id, "path": str(fp)},
)
async def stat_path(self, root: str, rel: str):
"""新增: 返回路径状态调试信息"""
@@ -203,15 +182,6 @@ class LocalAdapter:
except OSError:
shutil.move(str(src), str(dst))
await asyncio.to_thread(_do_move)
await LogService.info(
"adapter:local",
f"Moved {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src": str(src),
"dst": str(dst),
},
)
async def rename(self, root: str, src_rel: str, dst_rel: str):
src = _safe_join(root, src_rel)
@@ -227,15 +197,6 @@ class LocalAdapter:
except OSError:
os.replace(src, dst)
await asyncio.to_thread(_do_rename)
await LogService.info(
"adapter:local",
f"Renamed {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src": str(src),
"dst": str(dst),
},
)
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
src = _safe_join(root, src_rel)
@@ -258,15 +219,6 @@ class LocalAdapter:
else:
shutil.copy2(src, dst)
await asyncio.to_thread(_do)
await LogService.info(
"adapter:local",
f"Copied {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src": str(src),
"dst": str(dst),
},
)
async def stream_file(self, root: str, rel: str, range_header: str | None):
fp = _safe_join(root, rel)

View File

@@ -452,7 +452,7 @@ CONFIG_SCHEMA = [
{"key": "client_secret", "label": "Client Secret",
"type": "password", "required": True},
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
"required": True, "help_text": "可以通过运行 'python -m services.adapters.onedrive' 获取"},
"required": True, "help_text": "可以通过运行 'python -m domain.adapters.providers.onedrive' 获取"},
{"key": "root", "label": "根目录 (Root Path)", "type": "string",
"required": False, "placeholder": "默认为根目录 /"},
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},

View File

@@ -10,7 +10,6 @@ from botocore.exceptions import ClientError
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from models import StorageAdapter
from services.logging import LogService
class S3Adapter:
@@ -127,11 +126,6 @@ class S3Adapter:
key = self._get_s3_key(rel)
async with self._get_client() as s3:
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)
await LogService.info(
"adapter:s3", f"Wrote file to {rel}",
details={"adapter_id": self.record.id,
"bucket": self.bucket_name, "key": key, "size": len(data)}
)
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
key = self._get_s3_key(rel)
@@ -193,10 +187,6 @@ class S3Adapter:
)
raise IOError(f"S3 stream upload failed: {e}") from e
await LogService.info(
"adapter:s3", f"Wrote file stream to {rel}",
details={"adapter_id": self.record.id, "bucket": self.bucket_name, "key": key, "size": total_size}
)
return total_size
async def mkdir(self, root: str, rel: str):
@@ -205,11 +195,6 @@ class S3Adapter:
key += "/"
async with self._get_client() as s3:
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"")
await LogService.info(
"adapter:s3", f"Created directory {rel}",
details={"adapter_id": self.record.id,
"bucket": self.bucket_name, "key": key}
)
async def delete(self, root: str, rel: str):
key = self._get_s3_key(rel)
@@ -237,20 +222,9 @@ class S3Adapter:
else:
await s3.delete_object(Bucket=self.bucket_name, Key=key)
await LogService.info(
"adapter:s3", f"Deleted {rel}",
details={"adapter_id": self.record.id,
"bucket": self.bucket_name, "key": key}
)
async def move(self, root: str, src_rel: str, dst_rel: str):
await self.copy(root, src_rel, dst_rel, overwrite=True)
await self.delete(root, src_rel)
await LogService.info(
"adapter:s3", f"Moved {src_rel} to {dst_rel}",
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
"src_key": self._get_s3_key(src_rel), "dst_key": self._get_s3_key(dst_rel)}
)
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
@@ -270,11 +244,6 @@ class S3Adapter:
copy_source = {"Bucket": self.bucket_name, "Key": src_key}
await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key)
await LogService.info(
"adapter:s3", f"Copied {src_rel} to {dst_rel}",
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
"src_key": src_key, "dst_key": dst_key}
)
async def stat_file(self, root: str, rel: str):
key = self._get_s3_key(rel)
@@ -353,8 +322,7 @@ class S3Adapter:
while chunk := await body.read(65536):
yield chunk
except Exception as e:
LogService.error(
"adapter:s3", f"Error streaming file {key}: {e}")
raise
return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)

View File

@@ -10,7 +10,6 @@ from fastapi.responses import StreamingResponse
import paramiko
from models import StorageAdapter
from services.logging import LogService
def _join_remote(root: str, rel: str) -> str:
@@ -159,7 +158,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_write)
await LogService.info("adapter:sftp", f"Wrote file to {rel}", details={"adapter_id": self.record.id, "path": path, "size": len(data)})
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
buf = bytearray()
@@ -190,7 +188,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_mkdir)
await LogService.info("adapter:sftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})
async def delete(self, root: str, rel: str):
path = _join_remote(root, rel)
@@ -228,7 +225,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_delete)
await LogService.info("adapter:sftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})
async def move(self, root: str, src_rel: str, dst_rel: str):
src = _join_remote(root, src_rel)
@@ -255,7 +251,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_move)
await LogService.info("adapter:sftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
@@ -283,7 +278,6 @@ class SFTPAdapter:
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
await self.copy(root, child_src, child_dst, overwrite)
await LogService.info("adapter:sftp", f"Copied directory {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
return
# file copy
@@ -295,7 +289,6 @@ class SFTPAdapter:
except FileNotFoundError:
pass
await self.write_file(root, dst_rel, data)
await LogService.info("adapter:sftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def stat_file(self, root: str, rel: str):
path = _join_remote(root, rel)

View File

@@ -9,7 +9,6 @@ import mimetypes
import logging
from fastapi import HTTPException
from fastapi.responses import StreamingResponse, Response
from services.logging import LogService
NS = {"d": "DAV:"}
@@ -148,15 +147,6 @@ class WebDAVAdapter:
async with self._client() as client:
resp = await client.put(url, content=data)
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Wrote file to {rel}",
details={
"adapter_id": self.record.id,
"url": url,
"size": len(data),
},
)
async def mkdir(self, root: str, rel: str):
url = self._build_url(rel.rstrip('/') + '/')
@@ -164,11 +154,6 @@ class WebDAVAdapter:
resp = await client.request("MKCOL", url)
if resp.status_code not in (201, 405):
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Created directory {rel}",
details={"adapter_id": self.record.id, "url": url},
)
async def delete(self, root: str, rel: str):
url = self._build_url(rel)
@@ -176,11 +161,6 @@ class WebDAVAdapter:
resp = await client.delete(url)
if resp.status_code not in (204, 200, 404):
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Deleted {rel}",
details={"adapter_id": self.record.id, "url": url},
)
async def move(self, root: str, src_rel: str, dst_rel: str):
src_url = self._build_url(src_rel)
@@ -188,15 +168,6 @@ class WebDAVAdapter:
async with self._client() as client:
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Moved {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src_url": src_url,
"dst_url": dst_url,
},
)
async def rename(self, root: str, src_rel: str, dst_rel: str):
src_url = self._build_url(src_rel)
@@ -204,15 +175,6 @@ class WebDAVAdapter:
async with self._client() as client:
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Renamed {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src_url": src_url,
"dst_url": dst_url,
},
)
async def get_file_size(self, root: str, rel: str) -> int:
"""获取文件大小"""
@@ -518,15 +480,6 @@ class WebDAVAdapter:
if resp.status_code == 404:
raise FileNotFoundError(src_rel)
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Copied {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src_url": src_url,
"dst_url": dst_url,
},
)
ADAPTER_TYPE = "webdav"
CONFIG_SCHEMA = [

View File

@@ -1,12 +1,12 @@
from typing import Dict, Callable
import pkgutil
import inspect
import pkgutil
from importlib import import_module
from typing import Callable, Dict
from .base import BaseAdapter
from models import StorageAdapter
from domain.adapters.providers.base import BaseAdapter
AdapterFactory = Callable[[StorageAdapter], object]
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
TYPE_MAP: Dict[str, AdapterFactory] = {}
CONFIG_SCHEMAS: Dict[str, list] = {}
@@ -20,8 +20,9 @@ def normalize_adapter_type(value: str | None) -> str | None:
def discover_adapters():
"""扫描 services.adapters 包, 自动注册适配器类型、工厂与配置 schema。"""
from .. import adapters as adapters_pkg
"""扫描 domain.adapters.providers 包, 自动注册适配器类型、工厂与配置 schema。"""
from domain.adapters import providers as adapters_pkg
TYPE_MAP.clear()
CONFIG_SCHEMAS.clear()
for modinfo in pkgutil.iter_modules(adapters_pkg.__path__):
@@ -64,7 +65,7 @@ def get_config_schema(adapter_type: str):
class RuntimeRegistry:
def __init__(self):
self._instances: Dict[int, object] = {}
self._instances: Dict[int, BaseAdapter] = {}
async def refresh(self):
discover_adapters()
@@ -86,9 +87,9 @@ class RuntimeRegistry:
try:
self._instances[rec.id] = factory(rec)
except Exception:
continue
continue
def get(self, adapter_id: int):
def get(self, adapter_id: int) -> BaseAdapter | None:
return self._instances.get(adapter_id)
def snapshot(self) -> Dict[int, BaseAdapter]:
@@ -104,7 +105,7 @@ class RuntimeRegistry:
if not rec.enabled:
self.remove(rec.id)
return
normalized_type = normalize_adapter_type(rec.type)
if not normalized_type:
self.remove(rec.id)

111
domain/adapters/service.py Normal file
View File

@@ -0,0 +1,111 @@
from typing import Optional
from fastapi import HTTPException
from domain.adapters.registry import (
get_config_schemas,
normalize_adapter_type,
runtime_registry,
)
from domain.adapters.types import AdapterCreate, AdapterOut
from domain.auth.types import User
from models import StorageAdapter
class AdapterService:
    """CRUD service for storage-adapter records plus runtime (un)registration."""

    @classmethod
    def _validate_and_normalize_config(cls, adapter_type: str, cfg):
        """Validate *cfg* against the schema registered for *adapter_type*.

        Unknown keys are dropped; schema defaults fill absent optional fields.
        Raises HTTPException(400) for an unknown type, a non-dict config, or
        missing required fields.
        """
        schemas = get_config_schemas()
        adapter_type = normalize_adapter_type(adapter_type)
        if not adapter_type:
            raise HTTPException(400, detail="不支持的适配器类型")
        if not isinstance(cfg, dict):
            raise HTTPException(400, detail="config 必须是对象")
        schema = schemas.get(adapter_type)
        if not schema:
            raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
        normalized = {}
        absent = []
        for field_def in schema:
            key = field_def["key"]
            # Treat None/"" as "not provided" so defaults can still apply.
            if key in cfg and cfg[key] not in (None, ""):
                normalized[key] = cfg[key]
            elif "default" in field_def:
                normalized[key] = field_def["default"]
            elif field_def.get("required"):
                absent.append(key)
        if absent:
            raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(absent))
        return normalized

    @classmethod
    async def create_adapter(cls, data: AdapterCreate, current_user: Optional[User]):
        """Create a new adapter record and register it with the runtime."""
        mount_path = AdapterCreate.normalize_mount_path(data.path)
        if await StorageAdapter.get_or_none(path=mount_path):
            raise HTTPException(400, detail="Mount path already exists")
        rec = await StorageAdapter.create(
            name=data.name,
            type=data.type,
            config=cls._validate_and_normalize_config(data.type, data.config or {}),
            enabled=data.enabled,
            path=mount_path,
            sub_path=data.sub_path,
        )
        await runtime_registry.upsert(rec)
        return AdapterOut.model_validate(rec)

    @classmethod
    async def list_adapters(cls):
        """Return every stored adapter as an AdapterOut model."""
        return [AdapterOut.model_validate(rec) for rec in await StorageAdapter.all()]

    @classmethod
    async def available_adapter_types(cls):
        """List registered adapter types together with their config schemas."""
        return [
            {"type": adapter_type, "config_schema": fields}
            for adapter_type, fields in get_config_schemas().items()
        ]

    @classmethod
    async def get_adapter(cls, adapter_id: int):
        """Fetch one adapter by id; 404 when it does not exist."""
        rec = await StorageAdapter.get_or_none(id=adapter_id)
        if rec is None:
            raise HTTPException(404, detail="Not found")
        return AdapterOut.model_validate(rec)

    @classmethod
    async def update_adapter(cls, adapter_id: int, data: AdapterCreate, current_user: Optional[User]):
        """Overwrite an adapter record and refresh its runtime instance."""
        rec = await StorageAdapter.get_or_none(id=adapter_id)
        if rec is None:
            raise HTTPException(404, detail="Not found")
        mount_path = AdapterCreate.normalize_mount_path(data.path)
        conflict = await StorageAdapter.get_or_none(path=mount_path)
        if conflict and conflict.id != adapter_id:
            raise HTTPException(400, detail="Mount path already exists")
        rec.name = data.name
        rec.type = data.type
        rec.config = cls._validate_and_normalize_config(data.type, data.config or {})
        rec.enabled = data.enabled
        rec.path = mount_path
        rec.sub_path = data.sub_path
        await rec.save()
        await runtime_registry.upsert(rec)
        return AdapterOut.model_validate(rec)

    @classmethod
    async def delete_adapter(cls, adapter_id: int, current_user: Optional[User]):
        """Delete an adapter record; 404 when nothing was removed."""
        removed = await StorageAdapter.filter(id=adapter_id).delete()
        if not removed:
            raise HTTPException(404, detail="Not found")
        runtime_registry.remove(adapter_id)
        return {"deleted": True}

View File

@@ -1,5 +1,6 @@
import re
from typing import Dict, Optional
from pydantic import BaseModel, Field, field_validator

34
domain/ai/__init__.py Normal file
View File

@@ -0,0 +1,34 @@
from .api import router_ai, router_vector_db
from .service import (
AIProviderService,
VectorDBConfigManager,
VectorDBService,
DEFAULT_VECTOR_DIMENSION,
ABILITIES,
normalize_capabilities,
)
from .types import (
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
AIProviderCreate,
AIProviderUpdate,
VectorDBConfigPayload,
)
__all__ = [
"router_ai",
"router_vector_db",
"AIProviderService",
"VectorDBService",
"VectorDBConfigManager",
"DEFAULT_VECTOR_DIMENSION",
"ABILITIES",
"normalize_capabilities",
"AIDefaultsUpdate",
"AIModelCreate",
"AIModelUpdate",
"AIProviderCreate",
"AIProviderUpdate",
"VectorDBConfigPayload",
]

305
domain/ai/api.py Normal file
View File

@@ -0,0 +1,305 @@
from typing import Annotated, Dict, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Path, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
from domain.ai.types import (
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
AIProviderCreate,
AIProviderUpdate,
VectorDBConfigPayload,
)
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
from domain.auth.service import get_current_active_user
from domain.auth.types import User
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
@audit(action=AuditAction.READ, description="获取 AI 提供商列表")
@router_ai.get("/providers")
async def list_providers_endpoint(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Return every configured AI provider."""
    # NOTE(review): @audit sits above the router decorator throughout this
    # file, so FastAPI registers the unwrapped function — presumably @audit
    # records metadata consumed elsewhere rather than wrapping; confirm.
    providers = await AIProviderService.list_providers()
    return success({"providers": providers})
@audit(
    action=AuditAction.CREATE,
    description="创建 AI 提供商",
    body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
    redact_fields=["api_key"],
)
@router_ai.post("/providers")
async def create_provider(
    request: Request,
    payload: AIProviderCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Create an AI provider from the validated payload and return it."""
    # NOTE(review): .dict() is Pydantic v1 API (v2 prefers model_dump());
    # kept because the whole file uses .dict() consistently.
    provider = await AIProviderService.create_provider(payload.dict())
    return success(provider)
@audit(action=AuditAction.READ, description="获取 AI 提供商详情")
@router_ai.get("/providers/{provider_id}")
async def get_provider(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return one provider by id, including its model list."""
    provider = await AIProviderService.get_provider(provider_id, with_models=True)
    return success(provider)
@audit(
    action=AuditAction.UPDATE,
    description="更新 AI 提供商",
    body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
    redact_fields=["api_key"],
)
@router_ai.put("/providers/{provider_id}")
async def update_provider(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    payload: AIProviderUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Partially update a provider; payload fields left as None are ignored.

    Raises HTTP 400 when the payload carries no non-None field at all.
    """
    updates = {
        field: value
        for field, value in payload.dict().items()
        if value is not None
    }
    if not updates:
        raise HTTPException(status_code=400, detail="No fields to update")
    return success(await AIProviderService.update_provider(provider_id, updates))
@audit(action=AuditAction.DELETE, description="删除 AI 提供商")
@router_ai.delete("/providers/{provider_id}")
async def delete_provider(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Delete a provider by id and echo the id back."""
    # NOTE(review): the service uses filter().delete(), so a missing id is
    # presumably a silent no-op rather than a 404 — confirm intended semantics.
    await AIProviderService.delete_provider(provider_id)
    return success({"id": provider_id})
@audit(action=AuditAction.UPDATE, description="同步模型列表")
@router_ai.post("/providers/{provider_id}/sync-models")
async def sync_models(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Synchronize the provider's model catalogue from its remote API.

    Network failures map to HTTP 502; configuration problems (e.g. missing
    base_url or unsupported api_format, raised as ValueError) map to 400.
    """
    try:
        result = await AIProviderService.sync_models(provider_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return success(result)
@audit(action=AuditAction.READ, description="获取远程模型列表")
@router_ai.get("/providers/{provider_id}/remote-models")
async def fetch_remote_models(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Fetch the provider's remote model list without persisting it.

    Same error mapping as sync_models: 502 for network errors, 400 for
    configuration errors raised as ValueError.
    """
    try:
        models = await AIProviderService.fetch_remote_models(provider_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return success({"models": models})
@audit(action=AuditAction.READ, description="获取模型列表")
@router_ai.get("/providers/{provider_id}/models")
async def list_models(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return the locally stored models of one provider."""
    models = await AIProviderService.list_models(provider_id)
    return success({"models": models})
@audit(
    action=AuditAction.CREATE,
    description="创建模型",
    body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
)
@router_ai.post("/providers/{provider_id}/models")
async def create_model(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    payload: AIModelCreate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Create a model under the given provider and return it."""
    model = await AIProviderService.create_model(provider_id, payload.dict())
    return success(model)
@audit(
    action=AuditAction.UPDATE,
    description="更新模型",
    body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
)
@router_ai.put("/models/{model_id}")
async def update_model(
    request: Request,
    model_id: Annotated[int, Path(..., gt=0)],
    payload: AIModelUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Partially update a model; payload fields left as None are ignored.

    Raises HTTP 400 when the payload carries no non-None field at all.
    """
    updates = {
        field: value
        for field, value in payload.dict().items()
        if value is not None
    }
    if not updates:
        raise HTTPException(status_code=400, detail="No fields to update")
    return success(await AIProviderService.update_model(model_id, updates))
@audit(action=AuditAction.DELETE, description="删除模型")
@router_ai.delete("/models/{model_id}")
async def delete_model(
    request: Request,
    model_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Delete a model by id and echo the id back."""
    await AIProviderService.delete_model(model_id)
    return success({"id": model_id})
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
if not entry:
return None
value = entry.get("embedding_dimensions")
return int(value) if value is not None else None
@audit(action=AuditAction.READ, description="获取默认模型")
@router_ai.get("/defaults")
async def get_defaults(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return the current default model for each ability."""
    defaults = await AIProviderService.get_default_models()
    return success(defaults)
@audit(
    action=AuditAction.UPDATE,
    description="更新默认模型",
    body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
)
@router_ai.put("/defaults")
async def update_defaults(
    request: Request,
    payload: AIDefaultsUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Update the per-ability default models.

    Invalid mappings (ValueError from the service) map to HTTP 400.  When the
    default embedding model's dimension changes, the vector database is wiped,
    since stored vectors of the old dimension are incompatible.
    """
    # Snapshot the previous defaults so the dimension change can be detected.
    previous = await AIProviderService.get_default_models()
    try:
        updated = await AIProviderService.set_default_models(payload.as_mapping())
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    prev_dim = _get_embedding_dimension(previous.get("embedding"))
    next_dim = _get_embedding_dimension(updated.get("embedding"))
    # Only clear when both dimensions are known and actually differ.
    if prev_dim and next_dim and prev_dim != next_dim:
        try:
            await VectorDBService().clear_all_data()
        except Exception as exc:  # noqa: BLE001
            raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
    return success(updated)
@audit(action=AuditAction.UPDATE, description="清空向量数据库")
@router_vector_db.post("/clear-all", summary="清空向量数据库")
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
    """Drop all data held by the vector database.

    Any provider failure is surfaced as HTTP 500 carrying the original
    error message.
    """
    try:
        service = VectorDBService()
        await service.clear_all_data()
        return success(msg="向量数据库已清空")
    except Exception as e:  # noqa: BLE001
        # BUGFIX: chain the cause (raise ... from) so the traceback keeps the
        # root error; matches the `from exc` style used elsewhere in this file.
        raise HTTPException(status_code=500, detail=str(e)) from e
@audit(action=AuditAction.READ, description="获取向量数据库统计")
@router_vector_db.get("/stats", summary="获取向量数据库统计")
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
    """Return aggregate statistics from the vector-database provider.

    Any provider failure is surfaced as HTTP 500 carrying the original
    error message.
    """
    try:
        service = VectorDBService()
        data = await service.get_all_stats()
        return success(data=data)
    except Exception as e:  # noqa: BLE001
        # BUGFIX: chain the cause (raise ... from) so the traceback keeps the
        # root error; matches the `from exc` style used elsewhere in this file.
        raise HTTPException(status_code=500, detail=str(e)) from e
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
    """List the registered vector-database provider types."""
    return success(list_providers())
@audit(action=AuditAction.READ, description="获取向量数据库配置")
@router_vector_db.get("/config", summary="获取当前向量数据库配置")
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
    """Return the currently active vector-database provider configuration."""
    service = VectorDBService()
    data = await service.current_provider()
    return success(data)
@audit(action=AuditAction.UPDATE, description="更新向量数据库配置", body_fields=["type"])
@router_vector_db.post("/config", summary="更新向量数据库配置")
async def update_vector_db_config(
    request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
):
    """Validate, persist, and activate a new vector-database configuration.

    The provider type must be registered and enabled.  The config is
    smoke-tested by instantiating the provider and calling initialize()
    before anything is saved; only on success is it persisted and the
    singleton service reloaded.  Returns the active config plus stats.
    """
    entry = get_provider_entry(payload.type)
    if not entry:
        raise HTTPException(
            status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
    if not entry.get("enabled", True):
        raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
    provider_cls = get_provider_class(payload.type)
    if not provider_cls:
        raise HTTPException(
            status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
    # Probe the configuration with a throwaway provider instance.
    test_provider = provider_cls(payload.config)
    try:
        await test_provider.initialize()
    except Exception as exc:
        # BUGFIX: chain the cause so the 400 keeps the root error; matches
        # the `raise ... from exc` style used elsewhere in this file.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    finally:
        # Best-effort cleanup of whatever client the probe provider opened;
        # close failures are deliberately ignored.
        client = getattr(test_provider, "client", None)
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass
    await VectorDBConfigManager.save_config(payload.type, payload.config)
    service = VectorDBService()
    await service.reload()
    config_data = await service.current_provider()
    stats = await service.get_all_stats()
    return success({"config": config_data, "stats": stats})
__all__ = ["router_ai", "router_vector_db"]

View File

@@ -4,10 +4,10 @@ import httpx
from typing import List, Sequence, Tuple
from models.database import AIModel, AIProvider
from services.ai_providers import AIProviderService
from domain.ai.service import AIProviderService
provider_service = AIProviderService()
provider_service = AIProviderService
class MissingModelError(RuntimeError):

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import asyncio
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple
@@ -7,10 +9,18 @@ import httpx
from tortoise.exceptions import DoesNotExist
from tortoise.transactions import in_transaction
from domain.config.service import ConfigService
from models.database import AIDefaultModel, AIModel, AIProvider
from .types import ABILITIES, normalize_capabilities
from .vector_providers import (
BaseVectorProvider,
get_provider_class,
get_provider_entry,
list_providers,
)
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
DEFAULT_VECTOR_DIMENSION = 4096
OPENAI_EMBEDDING_DIMS = {
"text-embedding-3-large": 3072,
@@ -19,6 +29,43 @@ OPENAI_EMBEDDING_DIMS = {
}
class VectorDBConfigManager:
TYPE_KEY = "VECTOR_DB_TYPE"
CONFIG_KEY = "VECTOR_DB_CONFIG"
DEFAULT_TYPE = "milvus_lite"
@classmethod
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
provider_type = str(raw_type or cls.DEFAULT_TYPE)
raw_config = await ConfigService.get(cls.CONFIG_KEY)
config_dict: Dict[str, Any] = {}
if isinstance(raw_config, str) and raw_config:
try:
config_dict = json.loads(raw_config)
except json.JSONDecodeError:
config_dict = {}
elif isinstance(raw_config, dict):
config_dict = raw_config
return provider_type, config_dict
@classmethod
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
await ConfigService.set(cls.TYPE_KEY, provider_type)
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
@classmethod
async def get_type(cls) -> str:
provider_type, _ = await cls.load_config()
return provider_type
@classmethod
async def get_config(cls) -> Dict[str, Any]:
_, config = await cls.load_config()
return config
def _normalize_embedding_dim(value: Any) -> Optional[int]:
if value is None:
return None
@@ -47,17 +94,6 @@ def _apply_embedding_dim_to_metadata(
return data
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
if not items:
return []
normalized = []
for cap in items:
key = str(cap).strip().lower()
if key in ABILITIES and key not in normalized:
normalized.append(key)
return normalized
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
lower = model_id.lower()
caps = set()
@@ -139,40 +175,46 @@ def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = Non
class AIProviderService:
async def list_providers(self) -> List[Dict[str, Any]]:
@classmethod
async def list_providers(cls) -> List[Dict[str, Any]]:
providers = await AIProvider.all().order_by("id").prefetch_related("models")
return [provider_to_dict(p, models=list(p.models)) for p in providers]
async def get_provider(self, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
@classmethod
async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
if with_models:
provider = await AIProvider.get(id=provider_id)
models = await provider.models.all()
return provider_to_dict(provider, models=models)
else:
provider = await AIProvider.get(id=provider_id)
return provider_to_dict(provider)
provider = await AIProvider.get(id=provider_id)
return provider_to_dict(provider)
async def create_provider(self, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
data = payload.copy()
data.setdefault("extra_config", {})
provider = await AIProvider.create(**data)
return provider_to_dict(provider)
async def update_provider(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
provider = await AIProvider.get(id=provider_id)
for field, value in payload.items():
setattr(provider, field, value)
await provider.save()
return provider_to_dict(provider)
async def delete_provider(self, provider_id: int) -> None:
@classmethod
async def delete_provider(cls, provider_id: int) -> None:
await AIProvider.filter(id=provider_id).delete()
async def list_models(self, provider_id: int) -> List[Dict[str, Any]]:
@classmethod
async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
return [model_to_dict(m) for m in models]
async def create_model(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
data = payload.copy()
data["provider_id"] = provider_id
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
@@ -182,7 +224,8 @@ class AIProviderService:
await model.fetch_related("provider")
return model_to_dict(model)
async def update_model(self, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
model = await AIModel.get(id=model_id)
data = payload.copy()
if "capabilities" in data:
@@ -199,14 +242,17 @@ class AIProviderService:
await model.fetch_related("provider")
return model_to_dict(model)
async def delete_model(self, model_id: int) -> None:
@classmethod
async def delete_model(cls, model_id: int) -> None:
await AIModel.filter(id=model_id).delete()
async def fetch_remote_models(self, provider_id: int) -> List[Dict[str, Any]]:
@classmethod
async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
provider = await AIProvider.get(id=provider_id)
return await self._get_remote_models(provider)
return await cls._get_remote_models(provider)
async def _get_remote_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
@classmethod
async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
if not provider.base_url:
raise ValueError("Provider base_url is required for syncing models")
@@ -215,12 +261,13 @@ class AIProviderService:
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
if fmt == "openai":
return await self._fetch_openai_models(provider)
return await self._fetch_gemini_models(provider)
return await cls._fetch_openai_models(provider)
return await cls._fetch_gemini_models(provider)
async def sync_models(self, provider_id: int) -> Dict[str, int]:
@classmethod
async def sync_models(cls, provider_id: int) -> Dict[str, int]:
provider = await AIProvider.get(id=provider_id)
remote_models = await self._get_remote_models(provider)
remote_models = await cls._get_remote_models(provider)
created = 0
updated = 0
@@ -247,14 +294,16 @@ class AIProviderService:
return {"created": created, "updated": updated}
async def get_default_models(self) -> Dict[str, Optional[Dict[str, Any]]]:
@classmethod
async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
for item in defaults:
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
return result
async def set_default_models(self, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
@classmethod
async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
async with in_transaction() as connection:
for ability, model_id in normalized.items():
@@ -271,9 +320,10 @@ class AIProviderService:
await AIDefaultModel.create(ability=ability, model_id=model_id)
elif record:
await record.delete(using_db=connection)
return await self.get_default_models()
return await cls.get_default_models()
async def get_default_model(self, ability: str) -> Optional[AIModel]:
@classmethod
async def get_default_model(cls, ability: str) -> Optional[AIModel]:
ability_key = ability.lower()
if ability_key not in ABILITIES:
return None
@@ -285,7 +335,8 @@ class AIProviderService:
await model.fetch_related("provider")
return model
async def _fetch_openai_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
@classmethod
async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
base_url = provider.base_url.rstrip("/")
url = f"{base_url}/models"
headers = {}
@@ -315,7 +366,8 @@ class AIProviderService:
})
return entries
async def _fetch_gemini_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
@classmethod
async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
base_url = provider.base_url.rstrip("/")
suffix = "/models"
if provider.api_key:
@@ -345,3 +397,105 @@ class AIProviderService:
"metadata": item,
})
return entries
class VectorDBService:
    """Singleton facade over the configured vector-database provider.

    The active provider is created lazily from the persisted configuration
    and swapped transparently when the configuration changes (see reload()).
    """

    # Shared singleton instance, created on first __new__.
    _instance: "VectorDBService" | None = None

    def __new__(cls, *args, **kwargs):
        # Singleton: every VectorDBService() call yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every instantiation of the singleton; the hasattr
        # guard makes the state setup idempotent.
        if not hasattr(self, "_provider"):
            self._provider: Optional[BaseVectorProvider] = None
            self._provider_type: Optional[str] = None
            self._provider_config: Dict[str, Any] | None = None
            self._lock = asyncio.Lock()

    async def _ensure_provider(self) -> BaseVectorProvider:
        """Return the active provider, loading it from config on first use."""
        if self._provider is None:
            await self.reload()
        assert self._provider is not None
        return self._provider

    async def reload(self) -> BaseVectorProvider:
        """(Re)build the provider from stored configuration.

        Returns the current provider unchanged when neither the provider type
        nor its config differ. Raises RuntimeError for unknown or disabled
        provider types, or when the registry has no class for the type.
        """
        async with self._lock:
            provider_type, provider_config = await VectorDBConfigManager.load_config()
            normalized_config = dict(provider_config or {})
            if (
                self._provider
                and self._provider_type == provider_type
                and self._provider_config == normalized_config
            ):
                return self._provider
            entry = get_provider_entry(provider_type)
            if not entry:
                raise RuntimeError(f"Unknown vector database provider: {provider_type}")
            if not entry.get("enabled", True):
                raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")
            provider_cls = get_provider_class(provider_type)
            if not provider_cls:
                raise RuntimeError(f"Provider class not found for '{provider_type}'")
            provider = provider_cls(provider_config)
            await provider.initialize()
            self._provider = provider
            self._provider_type = provider_type
            self._provider_config = normalized_config
            return provider

    # NOTE: the provider calls below are not awaited — the provider API is
    # synchronous; these async wrappers only make provider *selection* async.

    async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
        """Create the collection if it does not exist yet."""
        provider = await self._ensure_provider()
        provider.ensure_collection(collection_name, vector, dim)

    async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        """Insert or update a single vector record."""
        provider = await self._ensure_provider()
        provider.upsert_vector(collection_name, data)

    async def delete_vector(self, collection_name: str, path: str) -> None:
        """Delete the record identified by *path*."""
        provider = await self._ensure_provider()
        provider.delete_vector(collection_name, path)

    async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
        """Similarity search by embedding; returns the provider's result set."""
        provider = await self._ensure_provider()
        return provider.search_vectors(collection_name, query_embedding, top_k)

    async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
        """Search records by their stored source path."""
        provider = await self._ensure_provider()
        return provider.search_by_path(collection_name, query_path, top_k)

    async def get_all_stats(self) -> Dict[str, Any]:
        """Provider-wide statistics."""
        provider = await self._ensure_provider()
        return provider.get_all_stats()

    async def clear_all_data(self) -> None:
        """Drop all stored vector data."""
        provider = await self._ensure_provider()
        provider.clear_all_data()

    async def current_provider(self) -> Dict[str, Any]:
        """Describe the configured provider without instantiating it."""
        provider_type, provider_config = await VectorDBConfigManager.load_config()
        entry = get_provider_entry(provider_type) or {}
        return {
            "type": provider_type,
            "config": provider_config,
            "label": entry.get("label"),
            "enabled": entry.get("enabled", True),
        }
# Explicit public surface of this module.
__all__ = [
    "AIProviderService",
    "VectorDBService",
    "VectorDBConfigManager",
    "DEFAULT_VECTOR_DIMENSION",
    "list_providers",
    "get_provider_entry",
    "get_provider_class",
    "normalize_capabilities",
    "ABILITIES",
]

View File

@@ -1,8 +1,19 @@
from typing import List, Optional
from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field, field_validator
from services.ai_providers import ABILITIES, normalize_capabilities
# Recognised AI model capability identifiers.
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]


def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
    """Filter *items* down to known ability names.

    Each entry is stripped and lower-cased; unknown names are dropped and
    duplicates are removed while preserving first-seen order. Falsy input
    yields an empty list.
    """
    if not items:
        return []
    accepted: List[str] = []
    for raw in items:
        name = str(raw).strip().lower()
        if name in ABILITIES and name not in accepted:
            accepted.append(name)
    return accepted
class AIProviderBase(BaseModel):
@@ -16,6 +27,7 @@ class AIProviderBase(BaseModel):
extra_config: Optional[dict] = None
@field_validator("api_format")
@classmethod
def normalize_format(cls, value: str) -> str:
fmt = value.lower()
if fmt not in {"openai", "gemini"}:
@@ -37,6 +49,7 @@ class AIProviderUpdate(BaseModel):
extra_config: Optional[dict] = None
@field_validator("api_format")
@classmethod
def normalize_format(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
@@ -56,6 +69,7 @@ class AIModelBase(BaseModel):
metadata: Optional[dict] = None
@field_validator("capabilities")
@classmethod
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
if items is None:
return None
@@ -79,6 +93,7 @@ class AIModelUpdate(BaseModel):
metadata: Optional[dict] = None
@field_validator("capabilities")
@classmethod
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
if items is None:
return None
@@ -97,5 +112,10 @@ class AIDefaultsUpdate(BaseModel):
voice: Optional[int] = None
tools: Optional[int] = None
def as_mapping(self) -> dict:
def as_mapping(self) -> Dict[str, Optional[int]]:
return {ability: getattr(self, ability) for ability in ABILITIES}
class VectorDBConfigPayload(BaseModel):
type: str = Field(..., description="向量数据库提供者类型")
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")

View File

@@ -54,3 +54,14 @@ def get_provider_class(provider_type: str) -> Type[BaseVectorProvider] | None:
if not entry:
return None
return entry.get("class") # type: ignore[return-value]
__all__ = [
"BaseVectorProvider",
"MilvusLiteProvider",
"MilvusServerProvider",
"QdrantProvider",
"list_providers",
"get_provider_entry",
"get_provider_class",
]

View File

@@ -155,7 +155,7 @@ class MilvusLiteProvider(BaseVectorProvider):
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
if query_path:
escaped = query_path.replace('"', '\\"')
filter_expr = f'source_path like "%{escaped}%"'
filter_expr = f'source_path like \"%{escaped}%\"'
else:
filter_expr = "source_path like '%%'"
results = self._get_client().query(
@@ -232,7 +232,7 @@ class MilvusLiteProvider(BaseVectorProvider):
for index_name in index_names:
try:
detail = client.describe_index(name, index_name) or {}
detail = client.describe_index(name) or {}
except Exception:
detail = {}
indexes.append(

View File

@@ -162,7 +162,7 @@ class MilvusServerProvider(BaseVectorProvider):
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
if query_path:
escaped = query_path.replace('"', '\\"')
filter_expr = f'source_path like "%{escaped}%"'
filter_expr = f'source_path like \"%{escaped}%\"'
else:
filter_expr = "source_path like '%%'"
results = self._get_client().query(
@@ -239,7 +239,7 @@ class MilvusServerProvider(BaseVectorProvider):
for index_name in index_names:
try:
detail = client.describe_index(name, index_name) or {}
detail = client.describe_index(name) or {}
except Exception:
detail = {}
indexes.append(

View File

@@ -42,7 +42,6 @@ class QdrantProvider(BaseVectorProvider):
api_key = (self.config.get("api_key") or None) or None
try:
client = QdrantClient(url=url, api_key=api_key)
# 简单连通性校验
client.get_collections()
self.client = client
except Exception as exc: # pragma: no cover - 依赖外部服务
@@ -70,7 +69,6 @@ class QdrantProvider(BaseVectorProvider):
message = str(exc).lower()
if "already exists" in message or "index exists" in message:
continue
# 旧版本 qdrant 可能返回带状态码的异常,这里容忍重复创建
raise
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:

5
domain/audit/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from domain.audit.decorator import audit
from domain.audit.types import AuditAction
from domain.audit.api import router
__all__ = ["audit", "AuditAction", "router"]

68
domain/audit/api.py Normal file
View File

@@ -0,0 +1,68 @@
from datetime import datetime, timezone
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from api import response
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import get_current_active_user
from domain.auth.types import User
# All audit endpoints require an authenticated, non-disabled user.
CurrentUser = Annotated[User, Depends(get_current_active_user)]

router = APIRouter(prefix="/api/audit", tags=["Audit"])
def _parse_iso(value: Optional[str], field: str):
if not value:
return None
try:
normalized = value.replace("Z", "+00:00")
dt = datetime.fromisoformat(normalized)
if dt.tzinfo:
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
return dt
except ValueError as exc: # noqa: BLE001
raise HTTPException(status_code=400, detail=f"invalid {field}") from exc
@router.get("/logs")
async def list_audit_logs(
    current_user: CurrentUser,
    page_num: int = Query(1, ge=1, alias="page", description="页码"),
    page_size: int = Query(20, ge=1, le=200, description="每页条数"),
    action: AuditAction | None = Query(None, description="操作类型"),
    success: bool | None = Query(None, description="是否成功"),
    username: str | None = Query(None, description="用户名模糊匹配"),
    path: str | None = Query(None, description="路径模糊匹配"),
    start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
    end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
):
    """Paginated audit-log listing; every filter is optional.

    Time bounds are ISO-8601 strings (invalid values yield 400 via
    _parse_iso). The page is returned newest-first inside a standard
    paginated response envelope.
    """
    start_dt = _parse_iso(start_time, "start_time")
    end_dt = _parse_iso(end_time, "end_time")
    items, total = await AuditService.list_logs(
        page=page_num,
        page_size=page_size,
        # The service filters on the enum's string value.
        action=str(action) if action else None,
        success=success,
        username=username,
        path=path,
        start_time=start_dt,
        end_time=end_dt,
    )
    return response.success(response.page(items, total, page_num, page_size))
@router.delete("/logs")
async def clear_audit_logs(
    current_user: CurrentUser,
    start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
    end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
):
    """Bulk-delete audit logs inside a time window.

    At least one bound must be supplied (400 otherwise); responds with the
    number of rows removed.
    """
    start_dt = _parse_iso(start_time, "start_time")
    end_dt = _parse_iso(end_time, "end_time")
    if start_dt is None and end_dt is None:
        raise HTTPException(status_code=400, detail="start_time 或 end_time 至少提供一个")
    deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
    return response.success({"deleted_count": deleted_count})

182
domain/audit/decorator.py Normal file
View File

@@ -0,0 +1,182 @@
import inspect
import time
from functools import wraps
from typing import Any, Dict, Mapping, Optional
import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import ALGORITHM
from domain.config.service import ConfigService
from models.database import UserAccount
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
    """Return the first fastapi Request among the bound argument values, or None."""
    return next(
        (candidate for candidate in bound_args.values() if isinstance(candidate, Request)),
        None,
    )
async def _resolve_user(request: Request | None, user_obj: Any | None) -> tuple[Optional[int], Optional[str]]:
    """Best-effort resolution of (user_id, username) for the audit record.

    First tries to decode the request's Bearer token and look the user up in
    the database; when that yields nothing, falls back to inspecting the
    current-user object captured from the endpoint's arguments (attribute
    access first, dict keys overriding for plain dicts). Either element may
    be None.
    """
    user_id: int | None = None
    username: str | None = None
    if request:
        auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1]
            try:
                payload = jwt.decode(token, await ConfigService.get_secret_key("SECRET_KEY"), algorithms=[ALGORITHM])
                username = payload.get("sub") or payload.get("username")
                if username:
                    user = await UserAccount.get_or_none(username=username)
                    user_id = user.id if user else None
            # Auditing is best-effort: a bad token or a lookup failure must
            # never break the request. (The previous tuple
            # `(InvalidTokenError, Exception)` was redundant — Exception
            # already covers InvalidTokenError.)
            except Exception:  # noqa: BLE001
                pass
    if user_id is None and username is None and user_obj is not None:
        user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
        username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
        if isinstance(user_obj, dict):
            user_id = user_obj.get("id", user_obj.get("user_id", user_id))
            username = user_obj.get("username", user_obj.get("name", username))
    return user_id, username
def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
if not body_fields:
return None
body: Dict[str, Any] = {}
redacts = set(redact_fields or [])
for value in bound_args.values():
data: Optional[Dict[str, Any]] = None
if hasattr(value, "model_dump"):
try:
data = value.model_dump()
except Exception:
data = None
elif hasattr(value, "dict"):
try:
data = value.dict()
except Exception:
data = None
elif isinstance(value, dict):
data = value
elif hasattr(value, "__dict__"):
data = dict(value.__dict__)
if not isinstance(data, dict):
continue
for field in body_fields:
if field in data and field not in body:
body[field] = data[field]
if not body:
return None
for field in redacts:
if field in body:
body[field] = "<redacted>"
return body
def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
    """Snapshot the request's query and path parameters for the audit record.

    Returns None when there is no request or when both parameter sets are
    empty; empty sections are omitted from the result.
    """
    if not request:
        return None
    snapshot: Dict[str, Any] = {}
    query = dict(request.query_params)
    if query:
        snapshot["query"] = query
    path_params = dict(request.path_params or {})
    if path_params:
        snapshot["path"] = path_params
    return snapshot or None
def _status_code_from_response(response: Any) -> int:
if hasattr(response, "status_code"):
try:
return int(getattr(response, "status_code"))
except Exception:
pass
return 200
def audit(
    *,
    action: AuditAction,
    description: str | None = None,
    body_fields: list[str] | None = None,
    redact_fields: list[str] | None = None,
    user_kw: str = "current_user",
):
    """Decorator that writes an AuditLog entry for every call of the wrapped
    endpoint.

    Captures the caller identity (token or the *user_kw* argument), query/path
    parameters, and the whitelisted *body_fields* (masking *redact_fields*).
    On failure the exception is re-raised after logging; logging itself is
    best-effort and never fails the request. Works with sync and async
    endpoints alike.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            bound = inspect.signature(func).bind_partial(*args, **kwargs)
            bound.apply_defaults()
            request = _extract_request(bound.arguments)
            start = time.perf_counter()
            user_id, username = await _resolve_user(request, bound.arguments.get(user_kw))
            request_params = _build_request_params(request)
            request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)

            async def _record(status_code: int, success: bool, error: str | None) -> None:
                # Shared by the success and failure paths (previously the
                # identical log call was duplicated inline in both branches).
                duration_ms = round((time.perf_counter() - start) * 1000, 2)
                try:
                    await AuditService.log(
                        action=action,
                        description=description,
                        user_id=user_id,
                        username=username,
                        client_ip=request.client.host if request and request.client else None,
                        method=request.method if request else "",
                        path=request.url.path if request else func.__name__,
                        status_code=status_code,
                        duration_ms=duration_ms,
                        success=success,
                        request_params=request_params,
                        request_body=request_body,
                        error=error,
                    )
                except Exception:
                    # Auditing must never break the endpoint itself.
                    pass

            try:
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
            except Exception as exc:  # noqa: BLE001
                await _record(getattr(exc, "status_code", 500), False, str(exc))
                raise
            await _record(_status_code_from_response(result), True, None)
            return result

        return wrapper

    return decorator

124
domain/audit/service.py Normal file
View File

@@ -0,0 +1,124 @@
from typing import Any, Dict, Optional
from models.database import AuditLog
from domain.audit.types import AuditAction
class AuditService:
    """Persistence and querying of audit entries backed by the AuditLog model."""

    @classmethod
    async def log(
        cls,
        *,
        action: AuditAction | str,
        description: Optional[str],
        user_id: Optional[int],
        username: Optional[str],
        client_ip: Optional[str],
        method: str,
        path: str,
        status_code: int,
        duration_ms: Optional[float],
        success: bool,
        request_params: Optional[Dict[str, Any]] = None,
        request_body: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        """Persist a single audit entry; *action* is stored as its string value."""
        await AuditLog.create(
            action=str(action),
            description=description,
            user_id=user_id,
            username=username,
            client_ip=client_ip,
            method=method,
            path=path,
            status_code=status_code,
            duration_ms=duration_ms,
            success=success,
            request_params=request_params,
            request_body=request_body,
            error=error,
        )

    @classmethod
    def _serialize(cls, log: AuditLog) -> Dict[str, Any]:
        """Convert an AuditLog row to a JSON-friendly dict (ISO timestamp)."""
        return {
            "id": log.id,
            "created_at": log.created_at.isoformat() if log.created_at else None,
            "action": log.action,
            "description": log.description,
            "user_id": log.user_id,
            "username": log.username,
            "client_ip": log.client_ip,
            "method": log.method,
            "path": log.path,
            "status_code": log.status_code,
            "duration_ms": log.duration_ms,
            "success": log.success,
            "request_params": log.request_params,
            "request_body": log.request_body,
            "error": log.error,
        }

    @classmethod
    def _apply_filters(
        cls,
        *,
        action: str | None = None,
        success: bool | None = None,
        username: str | None = None,
        path: str | None = None,
        start_time=None,
        end_time=None,
    ):
        """Build a queryset with the given optional filters applied.

        *username* and *path* match as case-insensitive substrings; the time
        bounds are inclusive.
        """
        qs = AuditLog.all()
        if action:
            qs = qs.filter(action=action)
        if success is not None:
            qs = qs.filter(success=success)
        if username:
            qs = qs.filter(username__icontains=username)
        if path:
            qs = qs.filter(path__icontains=path)
        if start_time:
            qs = qs.filter(created_at__gte=start_time)
        if end_time:
            qs = qs.filter(created_at__lte=end_time)
        return qs

    @classmethod
    async def list_logs(
        cls,
        *,
        page: int,
        page_size: int,
        action: str | None = None,
        success: bool | None = None,
        username: str | None = None,
        path: str | None = None,
        start_time=None,
        end_time=None,
    ):
        """Return (serialized items, total count) for one page, newest first."""
        qs = cls._apply_filters(
            action=action,
            success=success,
            username=username,
            path=path,
            start_time=start_time,
            end_time=end_time,
        )
        total = await qs.count()
        offset = (page - 1) * page_size
        items = await qs.order_by("-created_at").offset(offset).limit(page_size)
        return [cls._serialize(log) for log in items], total

    @classmethod
    async def clear_logs(
        cls,
        *,
        start_time=None,
        end_time=None,
    ) -> int:
        """Delete entries within the (optional) time window; returns the rows deleted."""
        qs = cls._apply_filters(start_time=start_time, end_time=end_time)
        deleted_count = await qs.delete()
        return deleted_count

16
domain/audit/types.py Normal file
View File

@@ -0,0 +1,16 @@
from enum import StrEnum
class AuditAction(StrEnum):
    """Category of an audited operation; the str value is persisted in AuditLog.action."""

    LOGIN = "login"
    LOGOUT = "logout"
    REGISTER = "register"
    READ = "read"
    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"
    RESET_PASSWORD = "reset_password"
    SHARE = "share"
    DOWNLOAD = "download"
    UPLOAD = "upload"
    OTHER = "other"

90
domain/auth/api.py Normal file
View File

@@ -0,0 +1,90 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import AuthService, get_current_active_user
from domain.auth.types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
Token,
UpdateMeRequest,
User,
)
# Authentication endpoints, mounted under /api/auth.
router = APIRouter(prefix="/api/auth", tags=["auth"])
@router.post("/register", summary="注册第一个管理员用户")
@audit(
    action=AuditAction.REGISTER,
    description="注册管理员",
    body_fields=["username", "email", "full_name"],
    redact_fields=["password"],
)
async def register(request: Request, data: RegisterRequest):
    """Bootstrap registration: creates the first (admin) account only.

    AuthService.register_user rejects the call once any user exists.
    """
    user = await AuthService.register_user(data)
    return success({"username": user.username}, msg="初始用户注册成功")
@router.post("/login")
@audit(action=AuditAction.LOGIN, description="用户登录", body_fields=["username"], redact_fields=["password"])
async def login_for_access_token(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
    """Exchange username/email plus password for a bearer access token."""
    return await AuthService.login(form_data)
@router.get("/me", summary="获取当前登录用户信息")
@audit(action=AuditAction.READ, description="获取当前用户信息")
async def get_me(
    request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
):
    """Return the authenticated user's profile."""
    profile = AuthService.get_profile(current_user)
    return success(profile)
@router.put("/me", summary="更新当前登录用户信息")
@audit(
    action=AuditAction.UPDATE,
    description="更新当前用户信息",
    body_fields=["email", "full_name"],
    redact_fields=["old_password", "new_password"],
)
async def update_me(
    request: Request,
    payload: UpdateMeRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Update the current user's email / name / password; returns the new profile."""
    profile = await AuthService.update_me(payload, current_user)
    return success(profile)
@router.post("/password-reset/request", summary="请求密码重置邮件")
@audit(action=AuditAction.RESET_PASSWORD, description="请求密码重置邮件", body_fields=["email"])
async def password_reset_request_endpoint(request: Request, payload: PasswordResetRequest):
    """Start the reset flow; the response is identical whether or not the email exists."""
    await AuthService.request_password_reset(payload)
    return success(msg="如果邮箱存在,将发送重置邮件")
@router.get("/password-reset/verify", summary="校验密码重置令牌")
@audit(action=AuditAction.RESET_PASSWORD, description="校验密码重置令牌", redact_fields=["token"])
async def password_reset_verify(request: Request, token: str):
    """Validate a reset token and reveal which account it belongs to."""
    user = await AuthService.verify_password_reset_token(token)
    return success({"username": user.username, "email": user.email})
@router.post("/password-reset/confirm", summary="使用令牌重置密码")
@audit(
    action=AuditAction.RESET_PASSWORD,
    description="重置密码",
    body_fields=["token"],
    redact_fields=["token", "password"],
)
async def password_reset_confirm(request: Request, payload: PasswordResetConfirm):
    """Finish the reset flow: set the new password carried in the payload."""
    await AuthService.reset_password_with_token(payload)
    return success(msg="密码已重置")

356
domain/auth/service.py Normal file
View File

@@ -0,0 +1,356 @@
import asyncio
import hashlib
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Annotated
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from domain.auth.types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
Token,
TokenData,
UpdateMeRequest,
User,
UserInDB,
)
from models.database import UserAccount
from domain.config.service import ConfigService
ALGORITHM = "HS256"  # JWT signing algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365  # access tokens live for one year
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10  # reset links expire quickly
def _now() -> datetime:
return datetime.now(timezone.utc)
@dataclass
class PasswordResetEntry:
    """In-memory record of one outstanding password-reset token."""

    user_id: int
    email: str
    username: str
    expires_at: datetime  # UTC expiry, checked on use
    used: bool = False  # set once consumed so the token cannot be replayed
class PasswordResetStore:
    """In-memory, process-local store of password-reset tokens.

    All class methods serialise on a single asyncio lock; used or expired
    entries are purged opportunistically via _cleanup(). Tokens do not
    survive a process restart.
    """

    _tokens: dict[str, PasswordResetEntry] = {}
    _lock = asyncio.Lock()

    @classmethod
    def _cleanup(cls):
        """Drop used or expired entries (callers must hold the lock)."""
        now = _now()
        for token, record in list(cls._tokens.items()):
            if record.used or record.expires_at < now:
                cls._tokens.pop(token, None)

    @classmethod
    async def create(cls, user: UserAccount) -> str:
        """Issue a fresh token for *user*, revoking any previous ones."""
        async with cls._lock:
            cls._cleanup()
            # At most one outstanding token per user.
            for key, record in list(cls._tokens.items()):
                if record.user_id == user.id:
                    cls._tokens.pop(key, None)
            token = secrets.token_urlsafe(32)
            expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
            cls._tokens[token] = PasswordResetEntry(
                user_id=user.id,
                email=user.email or "",
                username=user.username,
                expires_at=expires_at,
            )
            return token

    @classmethod
    async def get(cls, token: str) -> PasswordResetEntry | None:
        """Return the live entry for *token*, or None if unknown, used, or expired."""
        async with cls._lock:
            cls._cleanup()
            record = cls._tokens.get(token)
            if not record or record.used:
                return None
            return record

    @classmethod
    async def mark_used(cls, token: str) -> None:
        """Flag *token* as consumed so it cannot be replayed."""
        async with cls._lock:
            record = cls._tokens.get(token)
            if record:
                record.used = True
            cls._cleanup()

    @classmethod
    async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
        """Remove every token belonging to *user_id*, optionally sparing one."""
        async with cls._lock:
            for key, record in list(cls._tokens.items()):
                if record.user_id == user_id and key != except_token:
                    cls._tokens.pop(key, None)
            cls._cleanup()
class AuthService:
    """Authentication core: password hashing, JWT issuance/validation,
    profile management, and the password-reset flow."""

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    # NOTE(review): the router is mounted at /api/auth — tokenUrl "auth/login"
    # looks inconsistent with that prefix (it only affects the OpenAPI docs
    # login helper); confirm against the frontend.
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
    algorithm = ALGORITHM
    access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
    password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES

    @classmethod
    async def get_secret_key(cls) -> str:
        """Fetch the JWT signing secret from persisted config."""
        # NOTE(review): passes a second positional argument (None) — confirm
        # ConfigService.get_secret_key accepts a default parameter.
        return await ConfigService.get_secret_key("SECRET_KEY", None)

    @classmethod
    def _normalize_email(cls, email: str | None) -> str:
        """Lower-cased, stripped email; empty string for None."""
        return (email or "").strip().lower()

    @classmethod
    def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
        """Check a plaintext password against its stored bcrypt hash."""
        return cls.pwd_context.verify(plain_password, hashed_password)

    @classmethod
    def get_password_hash(cls, password: str) -> str:
        """Hash a plaintext password with bcrypt."""
        return cls.pwd_context.hash(password)

    @classmethod
    async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
        """Look a user up by username first, falling back to email."""
        user = await UserAccount.get_or_none(username=username_or_email)
        if not user:
            user = await UserAccount.get_or_none(email=username_or_email)
        if user:
            return UserInDB(
                id=user.id,
                username=user.username,
                email=user.email,
                full_name=user.full_name,
                disabled=user.disabled,
                hashed_password=user.hashed_password,
            )
        return None

    @classmethod
    async def authenticate_user_db(cls, username_or_email: str, password: str) -> UserInDB | None:
        """Return the user when the credentials are valid, else None."""
        user = await cls.get_user_db(username_or_email)
        if not user:
            return None
        if not cls.verify_password(password, user.hashed_password):
            return None
        return user

    @classmethod
    async def has_users(cls) -> bool:
        """True once at least one account exists (i.e. the system is initialised)."""
        user_count = await UserAccount.all().count()
        return user_count > 0

    @classmethod
    async def register_user(cls, payload: RegisterRequest):
        """Create the very first account; refused with 403 after initialisation."""
        if await cls.has_users():
            raise HTTPException(status_code=403, detail="系统已初始化,不允许注册新用户")
        exists = await UserAccount.get_or_none(username=payload.username)
        if exists:
            raise HTTPException(status_code=400, detail="用户名已存在")
        hashed = cls.get_password_hash(payload.password)
        user = await UserAccount.create(
            username=payload.username,
            email=payload.email,
            full_name=payload.full_name,
            hashed_password=hashed,
            disabled=False,
        )
        return user

    @classmethod
    async def create_access_token(cls, data: dict, expires_delta: timedelta | None = None):
        """Encode *data* as a signed JWT; defaults to a 15-minute expiry."""
        to_encode = data.copy()
        # "sub" is the canonical JWT subject claim; mirror "username" into it.
        if "sub" not in to_encode and "username" in to_encode:
            to_encode["sub"] = to_encode["username"]
        expire = _now() + (expires_delta or timedelta(minutes=15))
        to_encode.update({"exp": expire})
        secret_key = await cls.get_secret_key()
        encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=cls.algorithm)
        return encoded_jwt

    @classmethod
    async def login(cls, form: OAuth2PasswordRequestForm) -> Token:
        """Validate credentials and issue a bearer token; 401 on failure."""
        user = await cls.authenticate_user_db(form.username, form.password)
        if not user:
            raise HTTPException(
                status_code=401,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
        access_token = await cls.create_access_token(
            data={"sub": user.username}, expires_delta=access_token_expires
        )
        return Token(access_token=access_token, token_type="bearer")

    @classmethod
    def _build_profile(cls, user: User | UserInDB | UserAccount) -> dict:
        """Profile dict with an avatar URL derived from the MD5 of the email."""
        email = cls._normalize_email(getattr(user, "email", None))
        md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
        gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
        return {
            "id": user.id,
            "username": user.username,
            "email": getattr(user, "email", None),
            "full_name": getattr(user, "full_name", None),
            "gravatar_url": gravatar_url,
        }

    @classmethod
    def get_profile(cls, user: User | UserInDB | UserAccount) -> dict:
        """Public wrapper around _build_profile."""
        return cls._build_profile(user)

    @classmethod
    async def update_me(cls, payload: UpdateMeRequest, current_user: User) -> dict:
        """Update email / full name / password of the current user.

        Email uniqueness is enforced; changing the password requires the
        correct old password. Returns the refreshed profile.
        """
        db_user = await UserAccount.get_or_none(id=current_user.id)
        if not db_user:
            raise HTTPException(status_code=404, detail="用户不存在")
        if payload.email is not None:
            exists = (
                await UserAccount.filter(email=payload.email)
                .exclude(id=db_user.id)
                .exists()
            )
            if exists:
                raise HTTPException(status_code=400, detail="邮箱已被占用")
            db_user.email = payload.email
        if payload.full_name is not None:
            db_user.full_name = payload.full_name
        if payload.new_password:
            if not payload.old_password:
                raise HTTPException(status_code=400, detail="请提供原密码")
            if not cls.verify_password(payload.old_password, db_user.hashed_password):
                raise HTTPException(status_code=400, detail="原密码错误")
            db_user.hashed_password = cls.get_password_hash(payload.new_password)
        await db_user.save()
        return cls._build_profile(db_user)

    @classmethod
    async def request_password_reset(cls, payload: PasswordResetRequest) -> bool:
        """Issue and email a reset token.

        Returns False silently for unknown emails (avoids account
        enumeration). Raises 500 when the email cannot be sent, after
        revoking the freshly issued token.
        """
        normalized = cls._normalize_email(payload.email)
        if not normalized:
            return False
        user = await UserAccount.get_or_none(email=normalized)
        if not user or not user.email:
            return False
        token = await PasswordResetStore.create(user)
        try:
            await cls._send_password_reset_email(user, token)
        except Exception as exc:  # noqa: BLE001
            await PasswordResetStore.mark_used(token)
            await PasswordResetStore.invalidate_user(user.id)
            raise HTTPException(status_code=500, detail="邮件发送失败") from exc
        return True

    @classmethod
    async def verify_password_reset_token(cls, token: str) -> UserAccount:
        """Resolve a reset token to its user; 400 when invalid or expired."""
        record = await PasswordResetStore.get(token)
        if not record:
            raise HTTPException(status_code=400, detail="重置链接无效")
        user = await UserAccount.get_or_none(id=record.user_id)
        if not user:
            raise HTTPException(status_code=400, detail="重置链接无效")
        if record.expires_at < _now():
            await PasswordResetStore.mark_used(token)
            raise HTTPException(status_code=400, detail="重置链接已过期")
        return user

    @classmethod
    async def reset_password_with_token(cls, payload: PasswordResetConfirm) -> None:
        """Set a new password via a valid reset token, then revoke all of the
        user's outstanding tokens."""
        record = await PasswordResetStore.get(payload.token)
        if not record:
            raise HTTPException(status_code=400, detail="重置链接无效")
        if record.expires_at < _now():
            await PasswordResetStore.mark_used(payload.token)
            raise HTTPException(status_code=400, detail="重置链接已过期")
        user = await UserAccount.get_or_none(id=record.user_id)
        if not user:
            raise HTTPException(status_code=400, detail="重置链接无效")
        user.hashed_password = cls.get_password_hash(payload.password)
        await user.save(update_fields=["hashed_password"])
        await PasswordResetStore.mark_used(payload.token)
        await PasswordResetStore.invalidate_user(user.id)

    @classmethod
    async def get_current_user(cls, token: str):
        """Decode a bearer token and load its user; 401 on any failure."""
        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
        try:
            secret_key = await cls.get_secret_key()
            payload = jwt.decode(token, secret_key, algorithms=[cls.algorithm])
            username = payload.get("sub")
            if username is None:
                raise credentials_exception
            token_data = TokenData(username=username)
        except InvalidTokenError:
            raise credentials_exception
        user = await cls.get_user_db(token_data.username)
        if user is None:
            raise credentials_exception
        return user

    @classmethod
    async def get_current_active_user(cls, current_user: User):
        """Reject disabled accounts with 400."""
        if current_user.disabled:
            raise HTTPException(status_code=400, detail="Inactive user")
        return current_user

    @classmethod
    async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
        """Queue the password-reset email, with a link based on APP_DOMAIN."""
        # Local import — presumably avoids a circular import with the email
        # domain; confirm before moving to module level.
        from domain.email.service import EmailService

        app_domain = await ConfigService.get("APP_DOMAIN", None)
        base_url = (app_domain or "http://localhost:5173").rstrip("/")
        reset_link = f"{base_url}/reset-password?token={token}"
        await EmailService.enqueue_email(
            recipients=[user.email],
            subject="Foxel 密码重置",
            template="password_reset",
            context={
                "username": user.username,
                "reset_link": reset_link,
                "expire_minutes": cls.password_reset_token_expire_minutes,
            },
        )
async def _current_user_dep(token: Annotated[str, Depends(AuthService.oauth2_scheme)]):
    """FastAPI dependency: resolve the bearer token to a user."""
    return await AuthService.get_current_user(token)


async def _current_active_user_dep(
    current_user: Annotated[User, Depends(_current_user_dep)],
):
    """FastAPI dependency: additionally reject disabled accounts."""
    return await AuthService.get_current_active_user(current_user)


# Convenience aliases for dependency injection and external callers.
get_current_user = _current_user_dep
get_current_active_user = _current_active_user_dep
authenticate_user_db = AuthService.authenticate_user_db
create_access_token = AuthService.create_access_token
register_user = AuthService.register_user
request_password_reset = AuthService.request_password_reset
verify_password_reset_token = AuthService.verify_password_reset_token
reset_password_with_token = AuthService.reset_password_with_token
has_users = AuthService.has_users
verify_password = AuthService.verify_password
get_password_hash = AuthService.get_password_hash

45
domain/auth/types.py Normal file
View File

@@ -0,0 +1,45 @@
from pydantic import BaseModel
class Token(BaseModel):
    """OAuth2 access-token response body."""
    access_token: str
    token_type: str


class TokenData(BaseModel):
    """Claims extracted from a decoded JWT (the ``sub`` username)."""
    username: str | None = None


class User(BaseModel):
    """Public user representation returned by the API."""
    id: int
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None


class UserInDB(User):
    """User plus the stored password hash; never returned to clients."""
    hashed_password: str


class RegisterRequest(BaseModel):
    """Payload for creating a new account."""
    username: str
    password: str
    email: str | None = None
    full_name: str | None = None


class UpdateMeRequest(BaseModel):
    """Partial self-service profile update; password change requires both fields."""
    email: str | None = None
    full_name: str | None = None
    old_password: str | None = None
    new_password: str | None = None


class PasswordResetRequest(BaseModel):
    """Request a reset link for the account with this email."""
    email: str


class PasswordResetConfirm(BaseModel):
    """Consume a reset token and set a new password."""
    token: str
    password: str

View File

@@ -0,0 +1 @@

30
domain/backup/api.py Normal file
View File

@@ -0,0 +1,30 @@
import datetime
from fastapi import APIRouter, Depends, File, Request, UploadFile
from fastapi.responses import JSONResponse
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.backup.service import BackupService
# All backup routes require an authenticated, active user.
router = APIRouter(
    prefix="/api/backup",
    tags=["Backup & Restore"],
    dependencies=[Depends(get_current_active_user)],
)


@router.get("/export", summary="导出全站数据")
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
async def export_backup(request: Request):
    """Download the full site backup as a timestamped JSON attachment."""
    data = await BackupService.export_data()
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
    return JSONResponse(content=data.model_dump(), headers=headers)


@router.post("/import", summary="导入数据")
@audit(action=AuditAction.UPLOAD, description="导入备份")
async def import_backup(request: Request, file: UploadFile = File(...)):
    """Replace current data with the contents of an uploaded backup JSON file."""
    await BackupService.import_from_bytes(file.filename, await file.read())
    return {"message": "数据导入成功。"}

91
domain/backup/service.py Normal file
View File

@@ -0,0 +1,91 @@
import json
from fastapi import HTTPException
from tortoise.transactions import in_transaction
from domain.backup.types import BackupData
from domain.config.service import VERSION
from models.database import (
AutomationTask,
Configuration,
ShareLink,
StorageAdapter,
UserAccount,
)
class BackupService:
    """Full-site export/import of the core database tables as JSON."""

    @classmethod
    async def export_data(cls) -> BackupData:
        """Snapshot all backed-up tables inside one transaction.

        Share-link datetimes are converted to ISO strings because the
        payload must be JSON serializable.
        """
        async with in_transaction():
            adapters = await StorageAdapter.all().values()
            users = await UserAccount.all().values()
            tasks = await AutomationTask.all().values()
            shares = await ShareLink.all().values()
            configs = await Configuration.all().values()
            for share in shares:
                share["created_at"] = (
                    share["created_at"].isoformat() if share.get("created_at") else None
                )
                share["expires_at"] = (
                    share["expires_at"].isoformat() if share.get("expires_at") else None
                )
            return BackupData(
                version=VERSION,
                storage_adapters=list(adapters),
                user_accounts=list(users),
                automation_tasks=list(tasks),
                share_links=list(shares),
                configurations=list(configs),
            )

    @classmethod
    async def import_from_bytes(cls, filename: str, content: bytes) -> None:
        """Parse an uploaded ``.json`` backup and import it.

        Raises HTTP 400 for a missing/non-JSON filename, an unparsable
        body, or a payload that fails schema validation.
        """
        # UploadFile.filename may be None; guard before calling endswith().
        if not filename or not filename.endswith(".json"):
            raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
        try:
            raw_data = json.loads(content)
        except (ValueError, UnicodeDecodeError) as exc:
            # json.JSONDecodeError subclasses ValueError.
            raise HTTPException(status_code=400, detail="无法解析JSON文件") from exc
        try:
            # pydantic's ValidationError subclasses ValueError; a non-dict
            # top-level payload raises TypeError on ** expansion.  Without
            # this guard a malformed backup surfaced as HTTP 500.
            payload = BackupData(**raw_data)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="无法解析JSON文件") from exc
        await cls.import_data(payload)

    @classmethod
    async def import_data(cls, payload: BackupData) -> None:
        """Atomically wipe and recreate all backed-up tables from *payload*."""
        async with in_transaction() as conn:
            # Delete in dependency-safe order before bulk re-inserting.
            await ShareLink.all().using_db(conn).delete()
            await AutomationTask.all().using_db(conn).delete()
            await StorageAdapter.all().using_db(conn).delete()
            await UserAccount.all().using_db(conn).delete()
            await Configuration.all().using_db(conn).delete()
            if payload.configurations:
                await Configuration.bulk_create(
                    [Configuration(**config) for config in payload.configurations],
                    using_db=conn,
                )
            if payload.user_accounts:
                await UserAccount.bulk_create(
                    [UserAccount(**user) for user in payload.user_accounts],
                    using_db=conn,
                )
            if payload.storage_adapters:
                await StorageAdapter.bulk_create(
                    [StorageAdapter(**adapter) for adapter in payload.storage_adapters],
                    using_db=conn,
                )
            if payload.automation_tasks:
                await AutomationTask.bulk_create(
                    [AutomationTask(**task) for task in payload.automation_tasks],
                    using_db=conn,
                )
            if payload.share_links:
                await ShareLink.bulk_create(
                    [ShareLink(**share) for share in payload.share_links],
                    using_db=conn,
                )

12
domain/backup/types.py Normal file
View File

@@ -0,0 +1,12 @@
from typing import Any
from pydantic import BaseModel, Field
class BackupData(BaseModel):
    """Serializable snapshot of all tables included in a site backup."""
    # Application version that produced the backup (informational only).
    version: str | None = None
    storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
    user_accounts: list[dict[str, Any]] = Field(default_factory=list)
    automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
    share_links: list[dict[str, Any]] = Field(default_factory=list)
    configurations: list[dict[str, Any]] = Field(default_factory=list)

59
domain/config/api.py Normal file
View File

@@ -0,0 +1,59 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Form, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.config.types import ConfigItem
router = APIRouter(prefix="/api/config", tags=["config"])


@router.get("/")
@audit(action=AuditAction.READ, description="获取配置")
async def get_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str,
):
    """Return the value of a single configuration key."""
    value = await ConfigService.get(key)
    return success(ConfigItem(key=key, value=value).model_dump())


@router.post("/")
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
async def set_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str = Form(...),
    value: str = Form(...),
):
    """Create or update a configuration key from form fields."""
    await ConfigService.set(key, value)
    return success(ConfigItem(key=key, value=value).model_dump())


@router.get("/all")
@audit(action=AuditAction.READ, description="获取全部配置")
async def get_all_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return every configuration key/value pair."""
    configs = await ConfigService.get_all()
    return success(configs)


@router.get("/status")
@audit(action=AuditAction.READ, description="获取系统状态")
async def get_system_status(request: Request):
    """Public system status (version, branding, initialization flag)."""
    status_data = await ConfigService.get_system_status()
    return success(status_data.model_dump())


@router.get("/latest-version")
@audit(action=AuditAction.READ, description="获取最新版本")
async def get_latest_version(request: Request):
    """Latest upstream release info (cached server-side)."""
    info = await ConfigService.get_latest_version()
    return success(info.model_dump())

111
domain/config/service.py Normal file
View File

@@ -0,0 +1,111 @@
import os
import time
from typing import Any, Dict, Optional
import httpx
from dotenv import load_dotenv
from domain.config.types import LatestVersionInfo, SystemStatus
from models.database import Configuration, UserAccount
load_dotenv(dotenv_path=".env")
VERSION = "v1.3.8"
class ConfigService:
    """Key/value configuration with DB > environment precedence and caching."""

    # Process-wide cache of resolved configuration values.
    _cache: Dict[str, Any] = {}
    # Cached GitHub "latest release" lookup: {"timestamp": float, "data": LatestVersionInfo | None}.
    _latest_version_cache: Dict[str, Any] = {"timestamp": 0.0, "data": None}

    @classmethod
    async def get(cls, key: str, default: Optional[Any] = None) -> Any:
        """Return a config value, checking cache, then DB, then environment."""
        if key in cls._cache:
            return cls._cache[key]
        try:
            config = await Configuration.get_or_none(key=key)
            if config:
                cls._cache[key] = config.value
                return config.value
        except Exception:
            # Best effort: fall through to the environment if the DB is unavailable.
            pass
        env_value = os.getenv(key)
        if env_value is not None:
            cls._cache[key] = env_value
            return env_value
        return default

    @classmethod
    async def get_secret_key(cls, key: str, default: Optional[Any] = None) -> bytes:
        """Like :meth:`get`, but always return bytes; raise if the key is missing."""
        value = await cls.get(key, default)
        if isinstance(value, bytes):
            return value
        if isinstance(value, str):
            return value.encode("utf-8")
        if value is None:
            raise ValueError(f"Secret key '{key}' not found in config or environment.")
        return str(value).encode("utf-8")

    @classmethod
    async def set(cls, key: str, value: Any):
        """Upsert a configuration row and refresh the cache."""
        obj, _ = await Configuration.get_or_create(key=key, defaults={"value": value})
        obj.value = value
        await obj.save()
        cls._cache[key] = value

    @classmethod
    async def get_all(cls) -> Dict[str, Any]:
        """Return every stored key/value pair, warming the cache; {} on DB failure."""
        try:
            configs = await Configuration.all()
            result = {}
            for config in configs:
                result[config.key] = config.value
                cls._cache[config.key] = config.value
            return result
        except Exception:
            return {}

    @classmethod
    def clear_cache(cls):
        """Drop all cached values (e.g. after a bulk import)."""
        cls._cache.clear()

    @classmethod
    async def get_system_status(cls) -> SystemStatus:
        """Assemble branding/version/initialization info for the status endpoint."""
        logo = await cls.get("APP_LOGO", "/logo.svg")
        favicon = await cls.get("APP_FAVICON", logo)
        user_count = await UserAccount.all().count()
        return SystemStatus(
            version=VERSION,
            title=await cls.get("APP_NAME", "Foxel"),
            logo=logo,
            favicon=favicon,
            # The site counts as initialized once at least one user exists.
            is_initialized=user_count > 0,
            app_domain=await cls.get("APP_DOMAIN"),
            file_domain=await cls.get("FILE_DOMAIN"),
        )

    @classmethod
    async def get_latest_version(cls) -> LatestVersionInfo:
        """Latest GitHub release, cached for one hour; serve stale data on failure."""
        current_time = time.time()
        cache = cls._latest_version_cache
        if current_time - cache["timestamp"] < 3600 and cache["data"]:
            return cache["data"]
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(
                    "https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
                    follow_redirects=True,
                )
                resp.raise_for_status()
                data = resp.json()
                version_info = LatestVersionInfo(
                    latest_version=data.get("tag_name"),
                    body=data.get("body"),
                )
                cache["timestamp"] = current_time
                cache["data"] = version_info
                return version_info
        except (httpx.HTTPError, ValueError):
            # httpx.HTTPError also covers HTTPStatusError from raise_for_status()
            # (e.g. a GitHub rate-limit 403), which the previous RequestError-only
            # catch let propagate; ValueError covers a non-JSON response body.
            if cache["data"]:
                return cache["data"]
            return LatestVersionInfo()

23
domain/config/types.py Normal file
View File

@@ -0,0 +1,23 @@
from typing import Any, Optional
from pydantic import BaseModel
class ConfigItem(BaseModel):
    """Single configuration key/value pair as exposed by the API."""
    key: str
    value: Optional[Any] = None


class SystemStatus(BaseModel):
    """Public system status: version, branding, and initialization state."""
    version: str
    title: str
    logo: str
    favicon: str
    is_initialized: bool
    app_domain: Optional[str] = None
    file_domain: Optional[str] = None


class LatestVersionInfo(BaseModel):
    """Latest upstream release tag and release notes; empty when unknown."""
    latest_version: Optional[str] = None
    body: Optional[str] = None

View File

@@ -1,20 +1,24 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from services.auth import User, get_current_active_user
from services.email import EmailService, EmailTemplateRenderer
from schemas.email import EmailTestRequest, EmailTemplateUpdate, EmailTemplatePreviewPayload
from api.response import success
from services.logging import LogService
router = APIRouter(
prefix="/api/email",
tags=["email"],
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.email.service import EmailService, EmailTemplateRenderer
from domain.email.types import (
EmailTemplatePreviewPayload,
EmailTemplateUpdate,
EmailTestRequest,
)
router = APIRouter(prefix="/api/email", tags=["email"])
@router.post("/test")
@audit(action=AuditAction.CREATE, description="发送测试邮件")
async def trigger_test_email(
request: Request,
payload: EmailTestRequest,
current_user: User = Depends(get_current_active_user),
):
@@ -27,17 +31,13 @@ async def trigger_test_email(
)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
await LogService.action(
"route:email",
"Triggered email test",
details={"task_id": task.id, "template": payload.template, "to": str(payload.to)},
user_id=getattr(current_user, "id", None),
)
return success({"task_id": task.id})
@router.get("/templates")
@audit(action=AuditAction.READ, description="获取邮件模板列表")
async def list_email_templates(
request: Request,
current_user: User = Depends(get_current_active_user),
):
templates = await EmailTemplateRenderer.list_templates()
@@ -45,7 +45,9 @@ async def list_email_templates(
@router.get("/templates/{name}")
@audit(action=AuditAction.READ, description="查看邮件模板")
async def get_email_template(
request: Request,
name: str,
current_user: User = Depends(get_current_active_user),
):
@@ -59,7 +61,9 @@ async def get_email_template(
@router.post("/templates/{name}")
@audit(action=AuditAction.UPDATE, description="更新邮件模板")
async def update_email_template(
request: Request,
name: str,
payload: EmailTemplateUpdate,
current_user: User = Depends(get_current_active_user),
@@ -68,17 +72,13 @@ async def update_email_template(
await EmailTemplateRenderer.save(name, payload.content)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
await LogService.action(
"route:email",
"Updated email template",
details={"template": name},
user_id=getattr(current_user, "id", None),
)
return success({"name": name})
@router.post("/templates/{name}/preview")
@audit(action=AuditAction.READ, description="预览邮件模板")
async def preview_email_template(
request: Request,
name: str,
payload: EmailTemplatePreviewPayload,
current_user: User = Depends(get_current_active_user),

View File

@@ -1,42 +1,14 @@
import asyncio
import json
import re
import smtplib
from email.message import EmailMessage
from email.utils import formataddr
from enum import Enum
from pathlib import Path
from string import Template
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, EmailStr, Field, ValidationError
from services.config import ConfigCenter
from services.logging import LogService
class EmailSecurity(str, Enum):
NONE = "none"
SSL = "ssl"
STARTTLS = "starttls"
class EmailConfig(BaseModel):
host: str
port: int = Field(..., gt=0)
username: Optional[str] = None
password: Optional[str] = None
sender_email: EmailStr
sender_name: Optional[str] = None
security: EmailSecurity = EmailSecurity.NONE
timeout: float = Field(default=30.0, gt=0.0)
class EmailSendPayload(BaseModel):
recipients: List[EmailStr] = Field(..., min_length=1)
subject: str = Field(..., min_length=1)
template: str = Field(..., min_length=1)
context: Dict[str, Any] = Field(default_factory=dict)
from domain.config.service import ConfigService
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload
class EmailTemplateRenderer:
@@ -52,9 +24,7 @@ class EmailTemplateRenderer:
async def list_templates(cls) -> list[str]:
cls.ROOT.mkdir(parents=True, exist_ok=True)
return sorted(
path.stem
for path in cls.ROOT.glob("*.html")
if path.is_file()
path.stem for path in cls.ROOT.glob("*.html") if path.is_file()
)
@classmethod
@@ -82,22 +52,8 @@ class EmailService:
@classmethod
async def _load_config(cls) -> EmailConfig:
raw_config = await ConfigCenter.get(cls.CONFIG_KEY)
if raw_config is None:
raise ValueError("Email configuration not found")
if isinstance(raw_config, str):
raw_config = raw_config.strip()
data: Any = json.loads(raw_config) if raw_config else {}
elif isinstance(raw_config, dict):
data = raw_config
else:
raise ValueError("Invalid email configuration format")
try:
return EmailConfig(**data)
except ValidationError as exc:
raise ValueError(f"Invalid email configuration: {exc}") from exc
raw_config = await ConfigService.get(cls.CONFIG_KEY)
return EmailConfig.parse_config(raw_config)
@staticmethod
def _html_to_text(html: str) -> str:
@@ -108,7 +64,9 @@ class EmailService:
async def _deliver(cls, config: EmailConfig, payload: EmailSendPayload, html_body: str):
message = EmailMessage()
message["Subject"] = payload.subject
message["From"] = formataddr((config.sender_name or str(config.sender_email), str(config.sender_email)))
message["From"] = formataddr(
(config.sender_name or str(config.sender_email), str(config.sender_email))
)
message["To"] = ", ".join([str(addr) for addr in payload.recipients])
plain_body = cls._html_to_text(html_body)
@@ -120,7 +78,9 @@ class EmailService:
@staticmethod
def _deliver_sync(config: EmailConfig, message: EmailMessage):
if config.security == EmailSecurity.SSL:
smtp: smtplib.SMTP = smtplib.SMTP_SSL(config.host, config.port, timeout=config.timeout)
smtp: smtplib.SMTP = smtplib.SMTP_SSL(
config.host, config.port, timeout=config.timeout
)
else:
smtp = smtplib.SMTP(config.host, config.port, timeout=config.timeout)
@@ -144,7 +104,7 @@ class EmailService:
template: str,
context: Optional[Dict[str, Any]] = None,
):
from services.task_queue import TaskProgress, task_queue_service
from domain.tasks.task_queue import TaskProgress, task_queue_service
payload = EmailSendPayload(
recipients=recipients,
@@ -162,16 +122,11 @@ class EmailService:
task.id,
TaskProgress(stage="queued", percent=0.0, detail="Waiting to send"),
)
await LogService.action(
"email_service",
"Email task enqueued",
details={"task_id": task.id, "subject": subject, "template": template},
)
return task
@classmethod
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
from services.task_queue import TaskProgress, task_queue_service
from domain.tasks.task_queue import TaskProgress, task_queue_service
payload = EmailSendPayload(**data)
@@ -194,8 +149,3 @@ class EmailService:
task_id,
TaskProgress(stage="completed", percent=100.0, detail="Email sent"),
)
await LogService.info(
"email_service",
"Email sent",
details={"task_id": task_id, "subject": payload.subject},
)

63
domain/email/types.py Normal file
View File

@@ -0,0 +1,63 @@
import json
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, EmailStr, Field, ValidationError
class EmailSecurity(str, Enum):
    """SMTP transport security mode."""
    NONE = "none"
    SSL = "ssl"
    STARTTLS = "starttls"


class EmailConfig(BaseModel):
    """SMTP connection and sender settings."""
    host: str
    port: int = Field(..., gt=0)
    username: Optional[str] = None
    password: Optional[str] = None
    sender_email: EmailStr
    sender_name: Optional[str] = None
    security: EmailSecurity = EmailSecurity.NONE
    timeout: float = Field(default=30.0, gt=0.0)

    @classmethod
    def parse_config(cls, raw_config: Any) -> "EmailConfig":
        """Accept a str (JSON) or dict configuration and parse it into an EmailConfig.

        Raises ValueError for a missing, malformed, or invalid configuration.
        """
        if raw_config is None:
            raise ValueError("Email configuration not found")
        if isinstance(raw_config, str):
            raw_config = raw_config.strip()
            # An empty string is treated as an empty config (fails validation below).
            data: Any = json.loads(raw_config) if raw_config else {}
        elif isinstance(raw_config, dict):
            data = raw_config
        else:
            raise ValueError("Invalid email configuration format")
        try:
            return cls(**data)
        except ValidationError as exc:
            raise ValueError(f"Invalid email configuration: {exc}") from exc


class EmailSendPayload(BaseModel):
    """Queued email job: recipients plus template name and render context."""
    recipients: List[EmailStr] = Field(..., min_length=1)
    subject: str = Field(..., min_length=1)
    template: str = Field(..., min_length=1)
    context: Dict[str, Any] = Field(default_factory=dict)


class EmailTestRequest(BaseModel):
    """Request body for sending a one-off test email."""
    to: EmailStr
    subject: str = Field(..., min_length=1)
    template: str = Field(default="test", min_length=1)
    context: Dict[str, Any] = Field(default_factory=dict)


class EmailTemplateUpdate(BaseModel):
    """New HTML content for a named email template."""
    content: str


class EmailTemplatePreviewPayload(BaseModel):
    """Context values used to render a template preview."""
    context: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -0,0 +1,42 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.service import OfflineDownloadService
from domain.offline_downloads.types import OfflineDownloadCreate
# Shared dependency alias: routes require an authenticated, active user.
CurrentUser = Annotated[User, Depends(get_current_active_user)]

router = APIRouter(
    prefix="/api/offline-downloads",
    tags=["OfflineDownloads"],
)


@router.post("/")
@audit(
    action=AuditAction.CREATE,
    description="创建离线下载任务",
    body_fields=["url", "dest_dir", "filename"],
)
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
    """Queue a new offline HTTP download; returns the task id for polling."""
    data = await OfflineDownloadService.create_download(payload, current_user)
    return success(data)


@router.get("/")
@audit(action=AuditAction.READ, description="获取离线下载列表")
async def list_offline_downloads(request: Request, current_user: CurrentUser):
    """List all offline-download tasks known to the task queue."""
    data = OfflineDownloadService.list_downloads()
    return success(data)


@router.get("/{task_id}")
@audit(action=AuditAction.READ, description="获取离线下载详情")
async def get_offline_download(task_id: str, request: Request, current_user: CurrentUser):
    """Return one offline-download task by id (404 when unknown)."""
    data = OfflineDownloadService.get_download(task_id)
    return success(data)

View File

@@ -0,0 +1,252 @@
import os
import time
from pathlib import Path
from typing import Annotated, AsyncIterator
import aiofiles
import aiohttp
from fastapi import Depends, HTTPException
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service
class OfflineDownloadService:
    """Queue-backed HTTP offline downloads that land in the virtual FS.

    Files are first streamed to a local temp directory, then re-streamed
    into the virtual filesystem, so storage adapters never receive a
    partially downloaded HTTP body.
    """

    # FastAPI dependency alias for routes that need the active user.
    current_user_dep = Annotated[User, Depends(get_current_active_user)]
    # Scratch space for in-flight downloads (one sub-directory per task id).
    temp_root = Path("data/tmp/offline_downloads")

    @classmethod
    async def create_download(cls, payload: OfflineDownloadCreate, current_user: User) -> dict:
        """Validate the destination directory and enqueue a download task.

        Returns ``{"task_id": ...}`` for polling via the task-queue API.
        """
        await cls._ensure_destination(payload.dest_dir)
        task = await task_queue_service.add_task(
            "offline_http_download",
            {
                "url": str(payload.url),
                "dest_dir": payload.dest_dir,
                "filename": payload.filename,
            },
        )
        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="queued",
                percent=0.0,
                bytes_total=None,
                bytes_done=0,
                detail="Waiting to start",
            ),
        )
        return {"task_id": task.id}

    @classmethod
    def list_downloads(cls) -> list[dict]:
        """All offline-download tasks currently known to the queue."""
        tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
        return [t.dict() for t in tasks]

    @classmethod
    def get_download(cls, task_id: str) -> dict:
        """One task by id; 404 when missing or not an offline download."""
        task = task_queue_service.get_task(task_id)
        if not task or task.name != "offline_http_download":
            raise HTTPException(status_code=404, detail="Task not found")
        return task.dict()

    @classmethod
    async def run_http_download(cls, task: Task):
        """Task-queue worker: download ``url`` and save it into ``dest_dir``.

        Reports progress in two stages: "downloading" (HTTP body to a temp
        file) and "transferring" (temp file into the virtual FS).  Returns
        the final virtual path.

        NOTE(review): the temp file is only cleaned up on success; a failed
        download leaves it under ``temp_root`` — confirm whether a periodic
        cleanup exists elsewhere.
        """
        params = task.task_info
        url = params.get("url")
        dest_dir = params.get("dest_dir")
        filename = params.get("filename")
        if not url or not dest_dir or not filename:
            raise ValueError("Missing required parameters for offline download")
        cls.temp_root.mkdir(parents=True, exist_ok=True)
        temp_dir = cls.temp_root / task.id
        temp_dir.mkdir(parents=True, exist_ok=True)
        temp_file = temp_dir / "payload"
        bytes_total: int | None = None
        bytes_done = 0
        last_update = time.monotonic()
        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="downloading",
                percent=0.0,
                bytes_total=None,
                bytes_done=0,
                detail="HTTP downloading",
            ),
        )

        async def report_download(delta: int, total: int | None):
            # Accumulate byte counters; throttle queue updates to one per 0.5s.
            nonlocal bytes_done, bytes_total, last_update
            if total is not None:
                bytes_total = total
            bytes_done += delta
            now = time.monotonic()
            # delta == 0 is the final flush and always goes through.
            if delta and now - last_update < 0.5:
                return
            last_update = now
            percent = None
            total_for_display = bytes_total if bytes_total is not None else None
            if bytes_total:
                percent = min(100.0, round(bytes_done / bytes_total * 100, 2))
            await task_queue_service.update_progress(
                task.id,
                TaskProgress(
                    stage="downloading",
                    percent=percent,
                    bytes_total=total_for_display,
                    bytes_done=bytes_done,
                    detail="HTTP downloading",
                ),
            )

        # No total timeout: large downloads may take arbitrarily long.
        timeout = aiohttp.ClientTimeout(total=None, connect=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as resp:
                if resp.status != 200:
                    raise ValueError(f"HTTP {resp.status} for {url}")
                content_length = resp.headers.get("Content-Length")
                total_size = int(content_length) if content_length else None
                bytes_done = 0
                async with aiofiles.open(temp_file, "wb") as f:
                    # 512 KiB chunks balance memory use and progress granularity.
                    async for chunk in resp.content.iter_chunked(512 * 1024):
                        if not chunk:
                            continue
                        await f.write(chunk)
                        await report_download(len(chunk), total_size)
        # Final zero-delta report flushes the last throttled update.
        await report_download(0, total_size)
        file_size = os.path.getsize(temp_file)
        bytes_done_transfer = 0

        async def report_transfer(delta: int):
            # Progress callback for the temp-file -> virtual-FS transfer stage.
            nonlocal bytes_done_transfer
            bytes_done_transfer += delta
            percent = min(100.0, round(bytes_done_transfer / file_size * 100, 2)) if file_size else None
            await task_queue_service.update_progress(
                task.id,
                TaskProgress(
                    stage="transferring",
                    percent=percent,
                    bytes_total=file_size or None,
                    bytes_done=bytes_done_transfer,
                    detail="Saving to storage",
                ),
            )

        async def chunk_iter() -> AsyncIterator[bytes]:
            # Re-read the temp file in chunks, reporting progress per chunk.
            async for chunk in cls._iter_file(temp_file, 512 * 1024, report_transfer):
                yield chunk

        # Resolve a collision-free destination name before streaming.
        final_path, resolved_name = await cls._allocate_destination(dest_dir, filename)
        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="transferring",
                percent=0.0,
                bytes_total=file_size or None,
                bytes_done=0,
                detail="Saving to storage",
            ),
        )
        await VirtualFSService.write_file_stream(final_path, chunk_iter())
        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="completed",
                percent=100.0,
                bytes_total=file_size or None,
                bytes_done=file_size,
                detail="Completed",
            ),
        )
        await task_queue_service.update_meta(task.id, {"final_path": final_path, "filename": resolved_name})
        try:
            # Best-effort cleanup; rmdir only succeeds because the dir is now empty.
            os.remove(temp_file)
            temp_dir.rmdir()
        except Exception:
            pass
        return final_path

    @classmethod
    async def _ensure_destination(cls, dest_dir: str) -> None:
        """Raise HTTP 400 unless *dest_dir* is an existing virtual directory."""
        try:
            is_dir = await VirtualFSService.path_is_directory(dest_dir)
        except HTTPException:
            is_dir = False
        if not is_dir:
            raise HTTPException(400, detail="Destination directory not found")

    @staticmethod
    def _normalize_path(path: str) -> str:
        """Return *path* with a leading slash and no trailing slash ("/" stays "/")."""
        if not path:
            return "/"
        if not path.startswith("/"):
            path = "/" + path
        if len(path) > 1 and path.endswith("/"):
            path = path.rstrip("/")
        return path or "/"

    @staticmethod
    async def _path_exists(full_path: str) -> bool:
        """True when the virtual path can be stat'ed; False on not-found."""
        try:
            await VirtualFSService.stat_file(full_path)
            return True
        except FileNotFoundError:
            return False
        except HTTPException as exc:  # noqa: PERF203
            if exc.status_code == 404:
                return False
            raise

    @classmethod
    async def _allocate_destination(cls, dest_dir: str, filename: str) -> tuple[str, str]:
        """Find a non-colliding destination, appending " (n)" before the suffix.

        Returns ``(full_path, resolved_filename)``.
        """
        dest_dir = cls._normalize_path(dest_dir)
        stem, suffix = cls._split_filename(filename)
        candidate = filename
        base = "" if dest_dir == "/" else dest_dir
        attempt = 0
        while await cls._path_exists(f"{base}/{candidate}" if base else f"/{candidate}"):
            attempt += 1
            if stem:
                candidate = f"{stem} ({attempt}){suffix}"
            else:
                candidate = f"file ({attempt}){suffix}" if suffix else f"file ({attempt})"
        full_path = f"{base}/{candidate}" if base else f"/{candidate}"
        return full_path, candidate

    @staticmethod
    def _split_filename(filename: str) -> tuple[str, str]:
        """Split into (stem, ".ext"); dotfiles like ".env" keep the dot in the stem."""
        if not filename:
            return "", ""
        if filename.startswith(".") and filename.count(".") == 1:
            return filename, ""
        if "." not in filename:
            return filename, ""
        stem, ext = filename.rsplit(".", 1)
        return stem, f".{ext}"

    @staticmethod
    async def _iter_file(path: Path, chunk_size: int, report_cb) -> AsyncIterator[bytes]:
        """Yield *path* in chunks, awaiting ``report_cb(len(chunk))`` before each yield."""
        async with aiofiles.open(path, "rb") as f:
            while True:
                chunk = await f.read(chunk_size)
                if not chunk:
                    break
                await report_cb(len(chunk))
                yield chunk

View File

@@ -0,0 +1 @@

66
domain/plugins/api.py Normal file
View File

@@ -0,0 +1,66 @@
from typing import List
from fastapi import APIRouter, Body, Request
from domain.audit import AuditAction, audit
from domain.plugins.service import PluginService
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
router = APIRouter(prefix="/api/plugins", tags=["plugins"])


@router.post("", response_model=PluginOut)
@audit(
    action=AuditAction.CREATE,
    description="创建插件",
    body_fields=["url", "enabled"],
)
async def create_plugin(request: Request, payload: PluginCreate):
    """Register a new plugin by URL."""
    return await PluginService.create(payload)


@router.get("", response_model=List[PluginOut])
@audit(action=AuditAction.READ, description="获取插件列表")
async def list_plugins(request: Request):
    """List all registered plugins, newest first."""
    return await PluginService.list_plugins()


@router.delete("/{plugin_id}")
@audit(action=AuditAction.DELETE, description="删除插件")
async def delete_plugin(request: Request, plugin_id: int):
    """Remove a plugin (404 when unknown)."""
    await PluginService.delete(plugin_id)
    return {"code": 0, "msg": "ok"}


@router.put("/{plugin_id}", response_model=PluginOut)
@audit(
    action=AuditAction.UPDATE,
    description="更新插件",
    body_fields=["url", "enabled"],
)
async def update_plugin(request: Request, plugin_id: int, payload: PluginCreate):
    """Update a plugin's URL and enabled flag."""
    return await PluginService.update(plugin_id, payload)


@router.post("/{plugin_id}/metadata", response_model=PluginOut)
@audit(
    action=AuditAction.UPDATE,
    description="更新插件 manifest",
    body_fields=[
        "key",
        "name",
        "version",
        "supported_exts",
        "default_bounds",
        "default_maximized",
        "icon",
        "description",
        "author",
        "website",
        "github",
    ],
)
async def update_manifest(
    request: Request, plugin_id: int, manifest: PluginManifestUpdate = Body(...)
):
    """Patch the plugin's manifest metadata (only non-null fields are applied)."""
    return await PluginService.update_manifest(plugin_id, manifest)

48
domain/plugins/service.py Normal file
View File

@@ -0,0 +1,48 @@
from fastapi import HTTPException
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from models.database import Plugin
class PluginService:
    """CRUD helpers around the ``Plugin`` ORM model."""

    @classmethod
    async def create(cls, payload: PluginCreate) -> PluginOut:
        """Persist a new plugin row and return its API representation."""
        record = await Plugin.create(**payload.model_dump())
        return PluginOut.model_validate(record)

    @classmethod
    async def list_plugins(cls) -> list[PluginOut]:
        """Return all plugins, newest first."""
        records = await Plugin.all().order_by("-id")
        return [PluginOut.model_validate(record) for record in records]

    @classmethod
    async def _get_or_404(cls, plugin_id: int) -> Plugin:
        """Fetch one plugin by id, raising HTTP 404 when absent."""
        record = await Plugin.get_or_none(id=plugin_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Plugin not found")
        return record

    @classmethod
    async def delete(cls, plugin_id: int) -> None:
        """Delete a plugin row (404 when unknown)."""
        record = await cls._get_or_404(plugin_id)
        await record.delete()

    @classmethod
    async def update(cls, plugin_id: int, payload: PluginCreate) -> PluginOut:
        """Replace a plugin's URL and enabled flag."""
        record = await cls._get_or_404(plugin_id)
        record.url, record.enabled = payload.url, payload.enabled
        await record.save()
        return PluginOut.model_validate(record)

    @classmethod
    async def update_manifest(
        cls, plugin_id: int, manifest: PluginManifestUpdate
    ) -> PluginOut:
        """Apply the manifest's non-null fields onto the plugin row.

        The row is only saved when at least one field changed.
        """
        record = await cls._get_or_404(plugin_id)
        changes = manifest.model_dump(exclude_none=True)
        if changes:
            for field_name, field_value in changes.items():
                setattr(record, field_name, field_value)
            await record.save()
        return PluginOut.model_validate(record)

52
domain/plugins/types.py Normal file
View File

@@ -0,0 +1,52 @@
from typing import Any, Dict, List, Optional
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
class PluginCreate(BaseModel):
    """Payload for registering or updating a plugin."""
    url: str = Field(min_length=1)
    enabled: bool = True


class PluginManifestUpdate(BaseModel):
    """Partial manifest patch; accepts both snake_case and camelCase aliases."""
    model_config = ConfigDict(populate_by_name=True, extra="ignore")

    key: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None
    supported_exts: Optional[List[str]] = Field(
        default=None,
        validation_alias=AliasChoices("supported_exts", "supportedExts"),
    )
    default_bounds: Optional[Dict[str, Any]] = Field(
        default=None,
        validation_alias=AliasChoices("default_bounds", "defaultBounds"),
    )
    default_maximized: Optional[bool] = Field(
        default=None,
        validation_alias=AliasChoices("default_maximized", "defaultMaximized"),
    )
    icon: Optional[str] = None
    description: Optional[str] = None
    author: Optional[str] = None
    website: Optional[str] = None
    github: Optional[str] = None


class PluginOut(BaseModel):
    """API representation of a plugin row (built from the ORM model)."""
    id: int
    url: str
    enabled: bool
    key: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None
    supported_exts: Optional[List[str]] = None
    default_bounds: Optional[Dict[str, Any]] = None
    default_maximized: Optional[bool] = None
    icon: Optional[str] = None
    description: Optional[str] = None
    author: Optional[str] = None
    website: Optional[str] = None
    github: Optional[str] = None

    model_config = ConfigDict(from_attributes=True)

89
domain/processors/api.py Normal file
View File

@@ -0,0 +1,89 @@
from typing import Annotated
from fastapi import APIRouter, Body, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.processors.service import ProcessorService
from domain.processors.types import (
ProcessDirectoryRequest,
ProcessRequest,
UpdateSourceRequest,
)
# All processor endpoints are mounted under /api/processors.
router = APIRouter(prefix="/api/processors", tags=["processors"])
@router.get("")
@audit(action=AuditAction.READ, description="获取处理器列表")
async def list_processors(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = ProcessorService.list_processors()
return success(data)
@router.post("/process")
@audit(
action=AuditAction.CREATE,
description="处理单个文件",
body_fields=["path", "processor_type", "save_to", "overwrite"],
)
async def process_file_with_processor(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessRequest = Body(...),
):
data = await ProcessorService.process_file(req)
return success(data)
@router.post("/process-directory")
@audit(
action=AuditAction.CREATE,
description="批量处理目录",
body_fields=["path", "processor_type", "overwrite", "max_depth", "suffix"],
)
async def process_directory_with_processor(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessDirectoryRequest = Body(...),
):
data = await ProcessorService.process_directory(req)
return success(data)
@router.get("/source/{processor_type}")
@audit(action=AuditAction.READ, description="获取处理器源码")
async def get_processor_source(
request: Request,
processor_type: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = await ProcessorService.get_source(processor_type)
return success(data)
@router.put("/source/{processor_type}")
@audit(action=AuditAction.UPDATE, description="更新处理器源码")
async def update_processor_source(
request: Request,
processor_type: str,
req: UpdateSourceRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = await ProcessorService.update_source(processor_type, req)
return success(data)
@router.post("/reload")
@audit(action=AuditAction.UPDATE, description="重载处理器模块")
async def reload_processor_modules(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = ProcessorService.reload()
return success(data)

View File

@@ -0,0 +1 @@
# Built-in processors package.

View File

@@ -1,9 +1,11 @@
from .base import BaseProcessor
from typing import Dict, Any
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from fastapi.responses import Response
from services.logging import LogService
from ..base import BaseProcessor
class ImageWatermarkProcessor:
name = "图片水印"
@@ -26,10 +28,11 @@ class ImageWatermarkProcessor:
]
produces_file = True
async def process(self, input_bytes: bytes,path: str, config: Dict[str, Any]) -> Response:
async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
text = config.get("text", "")
position = config.get("position", "bottom-right")
font_size = int(config.get("font_size", 24))
img = Image.open(BytesIO(input_bytes)).convert("RGBA")
watermark = Image.new("RGBA", img.size)
draw = ImageDraw.Draw(watermark)
@@ -37,29 +40,29 @@ class ImageWatermarkProcessor:
font = ImageFont.truetype("arial.ttf", font_size)
except Exception:
font = ImageFont.load_default()
w, h = img.size
try:
text_w, text_h = font.getsize(text)
except AttributeError:
bbox = draw.textbbox((0, 0), text, font=font)
text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
if position == "bottom-right":
xy = (w - text_w - 10, h - text_h - 10)
elif position == "top-left":
xy = (10, 10)
else:
xy = (w // 2 - text_w // 2, h // 2 - text_h // 2)
draw.text(xy, text, font=font, fill=(255, 255, 255, 128))
out = Image.alpha_composite(img, watermark)
buf = BytesIO()
out.convert("RGB").save(buf, format="JPEG")
await LogService.info(
"processor:image_watermark",
f"Watermarked image {path}",
details={"path": path, "config": config},
)
return Response(content=buf.getvalue(), media_type="image/jpeg")
PROCESSOR_TYPE = "image_watermark"
PROCESSOR_NAME = ImageWatermarkProcessor.name
SUPPORTED_EXTS = ImageWatermarkProcessor.supported_exts

View File

@@ -1,15 +1,15 @@
from typing import Dict, Any, List, Tuple
from fastapi.responses import Response
import base64
import mimetypes
import os
from io import BytesIO
from typing import Dict, Any, List, Tuple
from services.ai import describe_image_base64, get_text_embedding, provider_service
from services.vector_db import VectorDBService, DEFAULT_VECTOR_DIMENSION
from services.logging import LogService
from fastapi.responses import Response
from PIL import Image
from ..base import BaseProcessor
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
from domain.ai.service import VectorDBService, DEFAULT_VECTOR_DIMENSION
CHUNK_SIZE = 800
@@ -116,11 +116,6 @@ class VectorIndexProcessor:
if action == "destroy":
await vector_db.delete_vector(collection_name, path)
await LogService.info(
"processor:vector_index",
f"Destroyed {index_type} index for {path}",
details={"path": path, "action": "destroy", "index_type": index_type},
)
return Response(content=f"文件 {path}{index_type} 索引已销毁", media_type="text/plain")
mime_type = _guess_mime(path)
@@ -136,11 +131,6 @@ class VectorIndexProcessor:
"type": "filename",
"name": os.path.basename(path),
})
await LogService.info(
"processor:vector_index",
f"Created simple index for {path}",
details={"path": path, "action": "create", "index_type": "simple"},
)
return Response(content=f"文件 {path} 的普通索引已创建", media_type="text/plain")
file_ext = path.split('.')[-1].lower()
@@ -177,11 +167,6 @@ class VectorIndexProcessor:
details["description"] = description
if compression:
details["image_compression"] = compression
await LogService.info(
"processor:vector_index",
f"Indexed image {path}",
details=details,
)
return Response(content=f"图片已索引,描述:{description}", media_type="text/plain")
if file_ext in ["txt", "md"]:
@@ -204,11 +189,6 @@ class VectorIndexProcessor:
"end_offset": len(text),
})
details["chunks"] = 1
await LogService.info(
"processor:vector_index",
f"Indexed text file {path}",
details=details,
)
return Response(content="文本文件已索引", media_type="text/plain")
chunk_count = 0
@@ -230,11 +210,6 @@ class VectorIndexProcessor:
details["chunks"] = chunk_count
sample = chunks[0][1]
details["sample"] = sample[:120]
await LogService.info(
"processor:vector_index",
f"Indexed text file {path}",
details=details,
)
return Response(content="文本文件已索引", media_type="text/plain")
# Other file types don't support vector indexing yet; fall back to a filename index.
@@ -248,11 +223,6 @@ class VectorIndexProcessor:
"name": os.path.basename(path),
"embedding": [0.0] * vector_dim,
})
await LogService.info(
"processor:vector_index",
f"File type fallback to simple index for {path}",
details={"path": path, "action": "create", "index_type": "simple", "original_type": file_ext},
)
return Response(content="暂不支持该类型的向量索引,已创建文件名索引", media_type="text/plain")

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, Optional
from .base import BaseProcessor
from domain.processors.base import BaseProcessor
ProcessorFactory = Callable[[], BaseProcessor]
TYPE_MAP: Dict[str, ProcessorFactory] = {}
@@ -15,10 +15,9 @@ LAST_DISCOVERY_ERRORS: list[str] = []
def discover_processors(force_reload: bool = False) -> list[str]:
"""Discover available processor modules and cache their metadata."""
import services.processors # 延迟导入以避免循环
"""扫描并缓存可用的处理器模块。"""
from domain.processors import builtin as processors_pkg
processors_pkg = services.processors
TYPE_MAP.clear()
CONFIG_SCHEMAS.clear()
MODULE_MAP.clear()
@@ -51,8 +50,10 @@ def discover_processors(force_reload: bool = False) -> list[str]:
if factory is None:
for attr in module.__dict__.values():
if inspect.isclass(attr) and attr.__name__.endswith("Processor"):
def _mk(cls=attr):
return lambda: cls()
factory = _mk()
break
@@ -114,7 +115,7 @@ def get_config_schema(processor_type: str):
return CONFIG_SCHEMAS.get(processor_type)
def get(processor_type: str) -> BaseProcessor:
def get(processor_type: str) -> BaseProcessor | None:
factory = TYPE_MAP.get(processor_type)
if factory:
return factory()

View File

@@ -0,0 +1,217 @@
from pathlib import Path
from typing import List, Tuple
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from domain.processors.registry import (
get,
get_config_schema,
get_config_schemas,
get_module_path,
reload_processors,
)
from domain.processors.types import (
ProcessDirectoryRequest,
ProcessRequest,
UpdateSourceRequest,
)
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import task_queue_service
class ProcessorService:
    """High-level facade over the processor registry and the task queue.

    Exposes processor discovery metadata, schedules single-file and recursive
    directory processing jobs, and allows reading/patching a processor's
    source module on disk.
    """

    @classmethod
    def get_processor(cls, processor_type: str):
        """Return a fresh processor instance for *processor_type*, or None."""
        return get(processor_type)

    @classmethod
    def list_processors(cls):
        """Return one metadata dict per discovered processor."""
        schemas = get_config_schemas()
        out = []
        # Iterate values directly: the mapping key duplicates meta["type"].
        for meta in schemas.values():
            out.append({
                "type": meta["type"],
                "name": meta["name"],
                "supported_exts": meta.get("supported_exts", []),
                "config_schema": meta["config_schema"],
                "produces_file": meta.get("produces_file", False),
                "module_path": meta.get("module_path"),
            })
        return out

    @classmethod
    async def process_file(cls, req: ProcessRequest):
        """Queue a single file-processing task and return {"task_id": ...}.

        Raises 400 when *req.path* is a directory and overwrite is False.
        """
        is_dir = await VirtualFSService.path_is_directory(req.path)
        if is_dir and not req.overwrite:
            raise HTTPException(400, detail="Directory processing requires overwrite")
        # Overwrite writes back to the source path; otherwise honour save_to.
        save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
        task = await task_queue_service.add_task(
            "process_file",
            {
                "path": req.path,
                "processor_type": req.processor_type,
                "config": req.config,
                "save_to": save_to,
                "overwrite": req.overwrite,
            },
        )
        return {"task_id": task.id}

    @classmethod
    async def process_directory(cls, req: ProcessDirectoryRequest):
        """Walk a directory (depth-limited) and queue a task per matching file.

        Returns {"task_ids": [...], "scheduled": count}.
        Raises 400 on invalid arguments, 404 for unknown processors, and 501
        when the backing adapter cannot enumerate directories.
        """
        if req.max_depth is not None and req.max_depth < 0:
            raise HTTPException(400, detail="max_depth must be >= 0")
        is_dir = await VirtualFSService.path_is_directory(req.path)
        if not is_dir:
            raise HTTPException(400, detail="Path must be a directory")
        schema = get_config_schema(req.processor_type)
        _processor = get(req.processor_type)
        if not schema or not _processor:
            raise HTTPException(404, detail="Processor not found")
        produces_file = bool(schema.get("produces_file"))
        # Normalize: an absent or whitespace-only suffix means "no suffix".
        suffix = req.suffix if req.suffix is not None and req.suffix.strip() != "" else None
        overwrite = req.overwrite
        if produces_file:
            if not overwrite and not suffix:
                raise HTTPException(400, detail="Suffix is required when not overwriting files")
        else:
            # Non-producing processors never write output files.
            overwrite = False
            suffix = None
        supported_exts = schema.get("supported_exts") or []
        allowed_exts = {
            ext.lower().lstrip('.')
            for ext in supported_exts
            if isinstance(ext, str)
        }

        def matches_extension(file_rel: str) -> bool:
            # An empty allow-set means the processor accepts any extension.
            if not allowed_exts:
                return True
            if '.' not in file_rel:
                return '' in allowed_exts
            ext = file_rel.rsplit('.', 1)[-1].lower()
            return ext in allowed_exts or f'.{ext}' in allowed_exts

        adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(req.path)
        rel = rel.rstrip('/')
        list_dir = getattr(adapter_instance, "list_dir", None)
        if not callable(list_dir):
            raise HTTPException(501, detail="Adapter does not implement list_dir")

        def build_absolute_path(mount_path: str, rel_path: str) -> str:
            # Join a mount point and adapter-relative path into a VFS path.
            rel_norm = rel_path.lstrip('/')
            mount_norm = mount_path.rstrip('/')
            if not mount_norm:
                return '/' + rel_norm if rel_norm else '/'
            return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm

        def apply_suffix(path_str: str, suffix_str: str) -> str:
            # Insert the suffix before the final extension, if there is one.
            path_obj = Path(path_str)
            name = path_obj.name
            if not name:
                return path_str
            if '.' in name:
                base, ext = name.rsplit('.', 1)
                new_name = f"{base}{suffix_str}.{ext}"
            else:
                new_name = f"{name}{suffix_str}"
            return str(path_obj.with_name(new_name))

        scheduled_tasks: List[str] = []
        # Iterative DFS; each entry carries its depth for max_depth pruning.
        stack: List[Tuple[str, int]] = [(rel, 0)]
        page_size = 200
        while stack:
            current_rel, depth = stack.pop()
            page = 1
            while True:
                entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
                entries = entries or []
                if not entries and (total or 0) == 0:
                    break
                for entry in entries:
                    name = entry.get("name")
                    if not name:
                        continue
                    child_rel = f"{current_rel}/{name}" if current_rel else name
                    if entry.get("is_dir"):
                        if req.max_depth is None or depth < req.max_depth:
                            stack.append((child_rel.rstrip('/'), depth + 1))
                        continue
                    if not matches_extension(child_rel):
                        continue
                    absolute_path = build_absolute_path(adapter_model.path, child_rel)
                    save_to = None
                    if produces_file and not overwrite and suffix:
                        save_to = apply_suffix(absolute_path, suffix)
                    task = await task_queue_service.add_task(
                        "process_file",
                        {
                            "path": absolute_path,
                            "processor_type": req.processor_type,
                            "config": req.config,
                            "save_to": save_to,
                            "overwrite": overwrite,
                        },
                    )
                    scheduled_tasks.append(task.id)
                # Adapters that report no total are treated as single-page.
                if total is None or page * page_size >= total:
                    break
                page += 1
        return {
            "task_ids": scheduled_tasks,
            "scheduled": len(scheduled_tasks),
        }

    @classmethod
    async def get_source(cls, processor_type: str):
        """Read a processor's module source from disk."""
        module_path = get_module_path(processor_type)
        if not module_path:
            raise HTTPException(404, detail="Processor not found")
        path_obj = Path(module_path)
        if not path_obj.exists():
            raise HTTPException(404, detail="Processor source not found")
        try:
            content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
        except Exception as exc:
            # Chain the original cause so the stack trace stays useful.
            raise HTTPException(500, detail=f"Failed to read source: {exc}") from exc
        return {"source": content, "module_path": str(path_obj)}

    @classmethod
    async def update_source(cls, processor_type: str, req: UpdateSourceRequest):
        """Overwrite a processor's module source on disk; returns True."""
        module_path = get_module_path(processor_type)
        if not module_path:
            raise HTTPException(404, detail="Processor not found")
        path_obj = Path(module_path)
        if not path_obj.exists():
            raise HTTPException(404, detail="Processor source not found")
        try:
            await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
        except Exception as exc:
            raise HTTPException(500, detail=f"Failed to write source: {exc}") from exc
        return True

    @classmethod
    def reload(cls):
        """Re-discover processor modules; raise 500 listing any load errors."""
        errors = reload_processors()
        if errors:
            raise HTTPException(500, detail="; ".join(errors))
        return True
# Backwards-compatible module-level aliases for the old functional API.
get_processor = ProcessorService.get_processor
list_processors = ProcessorService.list_processors
reload_processor_modules = ProcessorService.reload

View File

@@ -0,0 +1,24 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel
class ProcessRequest(BaseModel):
    """Request to run one processor over a single file."""

    path: str
    processor_type: str
    config: Dict[str, Any]
    # Destination path for the result; ignored when overwrite is True.
    save_to: Optional[str] = None
    overwrite: bool = False
class ProcessDirectoryRequest(BaseModel):
    """Request to run one processor over every matching file in a directory."""

    path: str
    processor_type: str
    config: Dict[str, Any]
    overwrite: bool = True
    # None means unlimited recursion depth.
    max_depth: Optional[int] = None
    # Appended to output filenames when not overwriting.
    suffix: Optional[str] = None
class UpdateSourceRequest(BaseModel):
    """New source code to write into a processor's module file."""

    source: str

129
domain/share/api.py Normal file
View File

@@ -0,0 +1,129 @@
from typing import Annotated, List, Optional
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.share.service import ShareService
from domain.share.types import (
ShareCreate,
ShareInfo,
ShareInfoWithPassword,
SharePassword,
)
from models.database import UserAccount
# Unauthenticated endpoints consumed by share-link recipients.
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
# Authenticated endpoints for managing one's own shares.
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
@router.post("", response_model=ShareInfoWithPassword)
@audit(
action=AuditAction.SHARE,
description="创建分享链接",
body_fields=["name", "paths", "expires_in_days", "access_type"],
)
async def create_share(
request: Request,
payload: ShareCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
user_account = await UserAccount.get(id=current_user.id)
share = await ShareService.create_share_link(
user=user_account,
name=payload.name,
paths=payload.paths,
expires_in_days=payload.expires_in_days,
access_type=payload.access_type,
password=payload.password,
)
share_info = ShareInfo.from_orm(share).model_dump()
if payload.access_type == "password" and payload.password:
share_info["password"] = payload.password
return share_info
@router.get("", response_model=List[ShareInfo])
@audit(action=AuditAction.READ, description="获取我的分享列表")
async def get_my_shares(
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
):
user_account = await UserAccount.get(id=current_user.id)
shares = await ShareService.get_user_shares(user=user_account)
return [ShareInfo.from_orm(s) for s in shares]
@router.delete("/expired")
@audit(action=AuditAction.DELETE, description="删除已过期分享")
async def delete_expired_shares(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
user_account = await UserAccount.get(id=current_user.id)
deleted_count = await ShareService.delete_expired_shares(user=user_account)
return success({"deleted_count": deleted_count})
@router.delete("/{share_id}")
@audit(action=AuditAction.DELETE, description="删除分享链接")
async def delete_share(
share_id: int,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
user_account = await UserAccount.get(id=current_user.id)
await ShareService.delete_share_link(user=user_account, share_id=share_id)
return success(msg="分享已取消")
@public_router.post("/{token}/verify")
@audit(
    action=AuditAction.SHARE,
    description="校验分享密码",
    body_fields=["password"],
    redact_fields=["password"],
)
async def verify_password(request: Request, token: str, payload: SharePassword):
    """Check a visitor-supplied password against a password-protected share."""
    await ShareService.verify_share_password(token, payload.password)
    return success(msg="验证成功")
@public_router.get("/{token}/ls")
@audit(action=AuditAction.SHARE, description="浏览分享内容")
async def list_share_content(
request: Request, token: str, path: str = "/", password: Optional[str] = None
):
share = await ShareService.ensure_share_access(token, password)
content = await ShareService.get_shared_item_details(share, path)
return success(
{
"path": path,
"entries": content.get("items", []),
"pagination": {
"total": content.get("total", 0),
"page": content.get("page", 1),
"page_size": content.get("page_size", 1),
"pages": content.get("pages", 1),
},
}
)
@public_router.get("/{token}")
@audit(action=AuditAction.SHARE, description="获取分享信息")
async def get_share_info(request: Request, token: str):
share = await ShareService.get_share_by_token(token)
return success(ShareInfo.from_orm(share))
@public_router.get("/{token}/download")
@audit(action=AuditAction.DOWNLOAD, description="下载分享文件")
async def download_shared_file(
token: str,
path: str,
request: Request,
password: Optional[str] = None,
):
return await ShareService.stream_shared_file(token, path, request.headers.get("Range"), password)

187
domain/share/service.py Normal file
View File

@@ -0,0 +1,187 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from urllib.parse import quote
import bcrypt
from fastapi import HTTPException, status
from fastapi.responses import Response
from domain.virtual_fs.service import VirtualFSService
from models.database import ShareLink, UserAccount
class ShareService:
    """Share-link lifecycle: creation, password checks, browsing and download."""

    @classmethod
    def _hash_password(cls, password: str) -> str:
        """Hash a share password with bcrypt (salt is embedded in the digest)."""
        return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")

    @classmethod
    def _verify_password(cls, plain_password: str, hashed_password: str) -> bool:
        """Check a candidate password against its stored bcrypt hash."""
        return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))

    @classmethod
    def _calc_expires_at(cls, expires_in_days: Optional[int]) -> Optional[datetime]:
        """Return an aware UTC expiry, or None for absent/non-positive durations."""
        if expires_in_days is None or expires_in_days <= 0:
            return None
        return datetime.now(timezone.utc) + timedelta(days=expires_in_days)

    @classmethod
    def _ensure_password_if_needed(cls, share: ShareLink, password: Optional[str]) -> None:
        """Raise 401/403 unless *password* unlocks a password-protected share."""
        if share.access_type != "password":
            return
        if not password:
            raise HTTPException(status_code=401, detail="需要密码")
        if not share.hashed_password:
            # A protected share without a stored hash can never be unlocked.
            raise HTTPException(status_code=403, detail="密码错误")
        if not cls._verify_password(password, share.hashed_password):
            raise HTTPException(status_code=403, detail="密码错误")

    @classmethod
    async def create_share_link(
        cls,
        user: UserAccount,
        name: str,
        paths: List[str],
        expires_in_days: Optional[int] = 7,
        access_type: str = "public",
        password: Optional[str] = None,
    ) -> ShareLink:
        """Create and persist a share link; 400 on empty paths or missing password."""
        if not paths:
            raise HTTPException(status_code=400, detail="分享路径不能为空")
        if access_type == "password" and not password:
            raise HTTPException(status_code=400, detail="密码不能为空")
        token = secrets.token_urlsafe(16)
        expires_at = cls._calc_expires_at(expires_in_days)
        hashed_password = None
        if access_type == "password" and password:
            hashed_password = cls._hash_password(password)
        share = await ShareLink.create(
            token=token,
            name=name,
            paths=paths,
            user=user,
            expires_at=expires_at,
            access_type=access_type,
            hashed_password=hashed_password,
        )
        return share

    @classmethod
    async def get_share_by_token(cls, token: str) -> ShareLink:
        """Load a share by token; 404 if unknown, 410 if expired."""
        share = await ShareLink.get_or_none(token=token).prefetch_related("user")
        if not share:
            raise HTTPException(status_code=404, detail="分享链接不存在")
        # NOTE(review): assumes expires_at is stored timezone-aware — confirm,
        # otherwise this comparison raises on naive datetimes.
        if share.expires_at and share.expires_at < datetime.now(timezone.utc):
            raise HTTPException(status_code=410, detail="分享链接已过期")
        return share

    @classmethod
    async def verify_share_password(cls, token: str, password: str) -> ShareLink:
        """Validate a password for a password-protected share; 400 otherwise."""
        share = await cls.get_share_by_token(token)
        if share.access_type != "password":
            raise HTTPException(status_code=400, detail="此分享不需要密码")
        cls._ensure_password_if_needed(share, password)
        return share

    @classmethod
    async def ensure_share_access(cls, token: str, password: Optional[str]) -> ShareLink:
        """Resolve a share and enforce its password policy in one step."""
        share = await cls.get_share_by_token(token)
        cls._ensure_password_if_needed(share, password)
        return share

    @classmethod
    async def get_user_shares(cls, user: UserAccount) -> List[ShareLink]:
        """Return the user's shares, newest first."""
        return await ShareLink.filter(user=user).order_by("-created_at")

    @classmethod
    async def delete_share_link(cls, user: UserAccount, share_id: int) -> None:
        """Delete one of the user's shares; 404 if it isn't theirs."""
        share = await ShareLink.get_or_none(id=share_id, user_id=user.id)
        if not share:
            raise HTTPException(status_code=404, detail="分享链接不存在")
        await share.delete()

    @classmethod
    async def delete_expired_shares(cls, user: UserAccount) -> int:
        """Delete all of the user's expired shares; returns the removed count."""
        now = datetime.now(timezone.utc)
        deleted_count = await ShareLink.filter(user=user, expires_at__lte=now).delete()
        return deleted_count

    @classmethod
    async def get_shared_item_details(cls, share: ShareLink, sub_path: str = ""):
        """Resolve *sub_path* within the share's root and return listing data.

        For a file share, returns a single stat entry wrapped in the listing
        shape. Access checks (password) must already have been performed.
        """
        if not share.paths:
            raise HTTPException(status_code=404, detail="分享内容为空")
        base_shared_path = share.paths[0]
        if sub_path and sub_path != "/":
            # Security fix: reject parent-directory segments (mirrors the check
            # in stream_shared_file) so ".." cannot escape the shared root —
            # the startswith prefix check alone does not prevent traversal.
            if ".." in sub_path.split("/"):
                raise HTTPException(status_code=403, detail="无权访问此路径")
            full_path = f"{base_shared_path.rstrip('/')}/{sub_path.lstrip('/')}".rstrip("/")
            if not full_path.startswith(base_shared_path):
                raise HTTPException(status_code=403, detail="无权访问此路径")
            try:
                return await VirtualFSService.list_virtual_dir(full_path)
            except FileNotFoundError:
                raise HTTPException(status_code=404, detail="目录未找到")
        try:
            stat = await VirtualFSService.stat_file(base_shared_path)
            if stat.get("is_dir"):
                return await VirtualFSService.list_virtual_dir(base_shared_path)
            stat["name"] = base_shared_path.split("/")[-1]
            return {"items": [stat], "total": 1, "page": 1, "page_size": 1, "pages": 1}
        except HTTPException as e:
            # Some adapters signal "this is a directory" through stat errors.
            if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
                return await VirtualFSService.list_virtual_dir(base_shared_path)
            raise e

    @classmethod
    async def stream_shared_file(
        cls,
        token: str,
        path: str,
        range_header: str | None,
        password: Optional[str] = None,
    ) -> Response:
        """Stream one file from a share as an attachment download.

        Rejects traversal attempts, enforces the share password, and forwards
        the HTTP Range header to the virtual filesystem for partial content.
        """
        if not path or path == "/" or ".." in path.split("/"):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无效的文件路径")
        share = await cls.ensure_share_access(token, password)
        if not share.paths:
            raise HTTPException(status_code=404, detail="分享的源文件不存在")
        base_shared_path = share.paths[0]
        is_dir = False
        try:
            stat = await VirtualFSService.stat_file(base_shared_path)
            if stat and stat.get("is_dir"):
                is_dir = True
        except HTTPException as e:
            if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
                is_dir = True
            elif e.status_code == 404:
                raise HTTPException(status_code=404, detail="分享的源文件不存在")
            else:
                raise
        if is_dir:
            full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
            if not full_virtual_path.startswith(base_shared_path):
                raise HTTPException(status_code=403, detail="无权访问此路径")
        else:
            # File share: only the shared file itself may be requested.
            shared_filename = base_shared_path.split("/")[-1]
            request_filename = path.lstrip("/")
            if shared_filename != request_filename:
                raise HTTPException(status_code=403, detail="无权访问此路径")
            full_virtual_path = base_shared_path
        response = await VirtualFSService.stream_file(full_virtual_path, range_header)
        filename = full_virtual_path.split("/")[-1]
        # RFC 5987 encoding so non-ASCII filenames survive the header.
        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
        return response

43
domain/share/types.py Normal file
View File

@@ -0,0 +1,43 @@
from typing import List, Optional
from pydantic import BaseModel
from models.database import ShareLink
class ShareCreate(BaseModel):
    """Payload for creating a share link."""

    name: str
    paths: List[str]
    # None or <= 0 means the share never expires.
    expires_in_days: Optional[int] = 7
    # "public" or "password".
    access_type: str = "public"
    password: Optional[str] = None
class SharePassword(BaseModel):
    """Visitor-supplied password for unlocking a protected share."""

    password: str
class ShareInfo(BaseModel):
    """API view of a share link; timestamps are ISO-8601 strings."""

    id: int
    token: str
    name: str
    paths: List[str]
    created_at: str
    expires_at: Optional[str] = None
    access_type: str

    @classmethod
    def from_orm(cls, obj: ShareLink):
        """Build from a ShareLink row (custom — not pydantic v1's from_orm)."""
        return cls(
            id=obj.id,
            token=obj.token,
            name=obj.name,
            paths=obj.paths,
            created_at=obj.created_at.isoformat(),
            expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
            access_type=obj.access_type,
        )
class ShareInfoWithPassword(ShareInfo):
    """ShareInfo plus the plaintext password, echoed once at creation time."""

    password: Optional[str] = None

112
domain/tasks/api.py Normal file
View File

@@ -0,0 +1,112 @@
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.tasks.service import TaskService
from domain.tasks.types import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
)
# Dependency annotation for the authenticated user, reused by every route below.
CurrentUser = TaskService.current_user_dep

# All task endpoints require an active user; unknown ids produce 404.
router = APIRouter(
    prefix="/api/tasks",
    tags=["Tasks"],
    dependencies=[Depends(get_current_active_user)],
    responses={404: {"description": "Not found"}},
)
@router.get("/queue")
@audit(action=AuditAction.READ, description="获取任务队列状态")
async def get_task_queue_status(request: Request, current_user: CurrentUser):
payload = TaskService.get_queue_tasks()
return success(payload)
@router.get("/queue/settings")
@audit(action=AuditAction.READ, description="获取任务队列设置")
async def get_task_queue_settings(request: Request, current_user: CurrentUser):
payload = TaskService.get_queue_settings()
return success(payload.model_dump())
@router.post("/queue/settings")
@audit(
action=AuditAction.UPDATE,
description="更新任务队列设置",
body_fields=["concurrency"],
)
async def update_task_queue_settings(request: Request, settings: TaskQueueSettings, current_user: CurrentUser):
payload = await TaskService.update_queue_settings(settings, getattr(current_user, "id", None))
return success(payload.model_dump())
@router.get("/queue/{task_id}")
@audit(action=AuditAction.READ, description="获取队列任务状态")
async def get_task_status(task_id: str, request: Request, current_user: CurrentUser):
payload = TaskService.get_queue_task(task_id)
return success(payload)
@router.post("/")
@audit(
action=AuditAction.CREATE,
description="创建自动化任务",
body_fields=[
"name",
"event",
"path_pattern",
"filename_regex",
"processor_type",
"processor_config",
"enabled",
],
user_kw="user",
)
async def create_task(request: Request, task_in: AutomationTaskCreate, user: CurrentUser):
task = await TaskService.create_task(task_in, user)
return success(task)
@router.get("/{task_id}")
@audit(action=AuditAction.READ, description="获取自动化任务详情")
async def get_task(task_id: int, request: Request, current_user: CurrentUser):
task = await TaskService.get_task(task_id)
return success(task)
@router.get("/")
@audit(action=AuditAction.READ, description="获取自动化任务列表")
async def list_tasks(request: Request, current_user: CurrentUser):
tasks = await TaskService.list_tasks()
return success(tasks)
@router.put("/{task_id}")
@audit(
action=AuditAction.UPDATE,
description="更新自动化任务",
body_fields=[
"name",
"event",
"path_pattern",
"filename_regex",
"processor_type",
"processor_config",
"enabled",
],
)
async def update_task(request: Request, current_user: CurrentUser, task_id: int, task_in: AutomationTaskUpdate):
task = await TaskService.update_task(task_id, task_in, current_user)
return success(task)
@router.delete("/{task_id}")
@audit(action=AuditAction.DELETE, description="删除自动化任务", user_kw="user")
async def delete_task(task_id: int, request: Request, user: CurrentUser):
await TaskService.delete_task(task_id, user)
return success(msg="Task deleted")

109
domain/tasks/service.py Normal file
View File

@@ -0,0 +1,109 @@
import re
from typing import Annotated, Any, Dict, Optional
from fastapi import Depends, HTTPException
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.tasks.types import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
TaskQueueSettingsResponse,
)
from models.database import AutomationTask
from domain.tasks.task_queue import task_queue_service
class TaskService:
current_user_dep = Annotated[User, Depends(get_current_active_user)]
@classmethod
def get_queue_tasks(cls) -> list[dict[str, Any]]:
tasks = task_queue_service.get_all_tasks()
return [task.dict() for task in tasks]
@classmethod
def get_queue_settings(cls) -> TaskQueueSettingsResponse:
return TaskQueueSettingsResponse(
concurrency=task_queue_service.get_concurrency(),
active_workers=task_queue_service.get_active_worker_count(),
)
@classmethod
async def update_queue_settings(cls, settings: TaskQueueSettings, user_id: Optional[int]) -> TaskQueueSettingsResponse:
await task_queue_service.set_concurrency(settings.concurrency)
await ConfigService.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
return cls.get_queue_settings()
@classmethod
def get_queue_task(cls, task_id: str) -> dict[str, Any]:
task = task_queue_service.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task.dict()
@classmethod
async def create_task(cls, payload: AutomationTaskCreate, user: Optional[User]) -> AutomationTask:
task = await AutomationTask.create(**payload.model_dump())
return task
@classmethod
async def get_task(cls, task_id: int) -> AutomationTask:
task = await AutomationTask.get_or_none(id=task_id)
if not task:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
return task
@classmethod
async def list_tasks(cls) -> list[AutomationTask]:
tasks = await AutomationTask.all()
return tasks
@classmethod
async def update_task(cls, task_id: int, payload: AutomationTaskUpdate, current_user: User) -> AutomationTask:
task = await AutomationTask.get_or_none(id=task_id)
if not task:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
update_data = payload.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(task, key, value)
await task.save()
return task
@classmethod
async def delete_task(cls, task_id: int, user: Optional[User]) -> None:
deleted_count = await AutomationTask.filter(id=task_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
@classmethod
async def trigger_tasks(cls, event: str, path: str):
tasks = await AutomationTask.filter(event=event, enabled=True)
for task in tasks:
if cls.match(task, path):
await cls.execute(task, path)
@classmethod
def match(cls, task: AutomationTask, path: str) -> bool:
if task.path_pattern and not path.startswith(task.path_pattern):
return False
if task.filename_regex:
filename = path.split("/")[-1]
if not re.match(task.filename_regex, filename):
return False
return True
@classmethod
async def execute(cls, task: AutomationTask, path: str):
    """Schedule the processor run for *task* against *path*.

    The actual work happens asynchronously in a task-queue worker; this
    method only enqueues the job.
    """
    payload = {"task_id": task.id, "path": path}
    await task_queue_service.add_task(task.processor_type, payload)
task_service = TaskService  # Module-level alias for callers importing `task_service` directly.

View File

@@ -2,7 +2,6 @@ import asyncio
from typing import Dict, Any
from pydantic import BaseModel, Field
import uuid
from services.logging import LogService
from enum import Enum
@@ -47,7 +46,6 @@ class TaskQueueService:
task = Task(name=name, task_info=task_info)
self._tasks[task.id] = task
await self._queue.put(task)
await LogService.info("task_queue", f"Task {name} ({task.id}) enqueued", {"task_id": task.id, "name": name})
return task
def get_task(self, task_id: str) -> Task | None:
@@ -72,15 +70,15 @@ class TaskQueueService:
task.meta = (task.meta or {}) | meta
async def _execute_task(self, task: Task):
from services.virtual_fs import process_file
task.status = TaskStatus.RUNNING
await LogService.info("task_queue", f"Task {task.name} ({task.id}) started", {"task_id": task.id, "name": task.name})
try:
# Local import to avoid circular dependency during module load.
from domain.virtual_fs.service import VirtualFSService
if task.name == "process_file":
params = task.task_info
result = await process_file(
result = await VirtualFSService.process_file(
path=params["path"],
processor_type=params["processor_type"],
config=params["config"],
@@ -90,8 +88,7 @@ class TaskQueueService:
task.result = result
elif task.name == "automation_task" or self._is_processor_task(task.name):
from models.database import AutomationTask
from services.processors.registry import get as get_processor
from services.virtual_fs import read_file, write_file
from domain.processors.service import get_processor
params = task.task_info
auto_task = await AutomationTask.get(id=params["task_id"])
@@ -103,54 +100,45 @@ class TaskQueueService:
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
if processor_type != auto_task.processor_type:
await LogService.warning(
"task_queue",
"Processor type mismatch; falling back to stored type",
{"task_id": auto_task.id, "expected": auto_task.processor_type, "got": processor_type},
)
processor_type = auto_task.processor_type
processor = get_processor(processor_type)
if not processor:
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
file_content = await read_file(path)
file_content = await VirtualFSService.read_file(path)
result = await processor.process(file_content, path, auto_task.processor_config)
save_to = auto_task.processor_config.get("save_to")
if save_to and getattr(processor, "produces_file", False):
await write_file(save_to, result)
await VirtualFSService.write_file(save_to, result)
task.result = "Automation task completed"
elif task.name == "offline_http_download":
from services.offline_download import run_http_download
from domain.offline_downloads.service import OfflineDownloadService
result_path = await run_http_download(task)
result_path = await OfflineDownloadService.run_http_download(task)
task.result = {"path": result_path}
elif task.name == "cross_mount_transfer":
from services.virtual_fs import run_cross_mount_transfer_task
result = await run_cross_mount_transfer_task(task)
result = await VirtualFSService.run_cross_mount_transfer_task(task)
task.result = result
elif task.name == "send_email":
from services.email import EmailService
from domain.email.service import EmailService
await EmailService.send_from_task(task.id, task.task_info)
task.result = "Email sent"
else:
raise ValueError(f"Unknown task name: {task.name}")
task.status = TaskStatus.SUCCESS
await LogService.info("task_queue", f"Task {task.name} ({task.id}) succeeded", {"task_id": task.id, "name": task.name})
except Exception as e:
task.status = TaskStatus.FAILED
task.error = str(e)
await LogService.error("task_queue", f"Task {task.name} ({task.id}) failed: {e}", {"task_id": task.id, "name": task.name})
def _cleanup_workers(self):
self._worker_tasks = [task for task in self._worker_tasks if not task.done()]
def _is_processor_task(self, task_name: str) -> bool:
try:
from services.processors.registry import get as get_processor
from domain.processors.service import get_processor
return get_processor(task_name) is not None
except Exception:
@@ -165,15 +153,12 @@ class TaskQueueService:
worker_id = self._worker_seq
worker_task = asyncio.create_task(self._worker_loop(worker_id))
self._worker_tasks.append(worker_task)
await LogService.info("task_queue", "Task workers adjusted", {"active_workers": len(self._worker_tasks), "target": self._concurrency})
elif current > self._concurrency:
for _ in range(current - self._concurrency):
await self._queue.put(_SENTINEL)
await LogService.info("task_queue", "Task workers scaling down", {"active_workers": len(self._worker_tasks), "target": self._concurrency})
async def _worker_loop(self, worker_id: int):
current_task = asyncio.current_task()
await LogService.info("task_queue", f"Worker {worker_id} started")
try:
while True:
job = await self._queue.get()
@@ -183,23 +168,18 @@ class TaskQueueService:
try:
await self._execute_task(job)
except Exception as e:
await LogService.error(
"task_queue",
f"Error executing task {job.id}: {e}",
{"task_id": job.id, "name": job.name},
)
pass
finally:
self._queue.task_done()
finally:
if current_task in self._worker_tasks:
self._worker_tasks.remove(current_task) # type: ignore[arg-type]
await LogService.info("task_queue", f"Worker {worker_id} stopped")
async def start_worker(self, concurrency: int | None = None):
if concurrency is None:
from services.config import ConfigCenter
from domain.config.service import ConfigService
stored_value = await ConfigCenter.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
try:
concurrency = int(stored_value)
except (TypeError, ValueError):
@@ -219,7 +199,6 @@ class TaskQueueService:
if self._worker_tasks:
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
self._worker_tasks.clear()
await LogService.info("task_queue", "Task workers have been stopped.")
def get_concurrency(self) -> int:
return self._concurrency

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
class AutomationTaskBase(BaseModel):

189
domain/virtual_fs/api.py Normal file
View File

@@ -0,0 +1,189 @@
from typing import Annotated
from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.types import MkdirRequest, MoveRequest
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
async def get_file(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Serve a file's bytes, honouring an optional HTTP Range header."""
    return await VirtualFSService.serve_file(full_path, request.headers.get("Range"))
@router.get("/thumb/{full_path:path}")
@audit(action=AuditAction.READ, description="获取缩略图")
async def get_thumb(
    full_path: str,
    request: Request,
    w: int = Query(256, ge=8, le=1024),
    h: int = Query(256, ge=8, le=1024),
    fit: str = Query("cover"),
):
    """Return a thumbnail of the file, resized to w×h with the given fit mode.

    NOTE(review): no auth dependency here, unlike the sibling endpoints —
    confirm thumbnails are intentionally public.
    """
    return await VirtualFSService.get_thumbnail(full_path, w, h, fit)
@router.get("/stream/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="流式读取文件")
async def stream_endpoint(
    full_path: str,
    request: Request,
):
    """Stream file content, supporting HTTP Range requests."""
    return await VirtualFSService.stream_response(full_path, request.headers.get("Range"))
@router.get("/temp-link/{full_path:path}")
@audit(action=AuditAction.SHARE, description="创建临时链接")
async def get_temp_link(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久"),
):
    """Create a tokenised temporary link for the file (see /public/{token})."""
    data = await VirtualFSService.create_temp_link(full_path, expires_in)
    return success(data)
@router.get("/public/{token}")
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
async def access_public_file(
    token: str,
    request: Request,
):
    """Resolve a temp-link token and stream the target file (no auth)."""
    return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
@router.get("/stat/{full_path:path}")
@audit(action=AuditAction.READ, description="查看文件信息")
async def get_file_stat(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return metadata (stat) for a file or directory."""
    stat = await VirtualFSService.stat(full_path)
    return success(stat)
@router.post("/file/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="上传文件")
async def put_file(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    file: UploadFile = File(...),
):
    """Upload a file by buffering the whole body in memory.

    For large files prefer the streaming endpoint (/upload/...), which
    reads in chunks instead of one `await file.read()`.
    """
    data = await file.read()
    result = await VirtualFSService.write_uploaded_file(full_path, data)
    return success(result)
@router.post("/mkdir")
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
async def api_mkdir(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MkdirRequest,
):
    """Create a directory at body.path."""
    result = await VirtualFSService.mkdir(body.path)
    return success(result)
@router.post("/move")
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
async def api_move(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
    """Move body.src to body.dst, optionally overwriting the target."""
    result = await VirtualFSService.move(body.src, body.dst, overwrite)
    return success(result)
@router.post("/rename")
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
async def api_rename(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
    """Rename body.src to body.dst, optionally overwriting the target."""
    result = await VirtualFSService.rename(body.src, body.dst, overwrite)
    return success(result)
@router.post("/copy")
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
async def api_copy(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否覆盖已存在目标"),
):
    """Copy body.src to body.dst, optionally overwriting the target."""
    result = await VirtualFSService.copy(body.src, body.dst, overwrite)
    return success(result)
@router.post("/upload/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
async def upload_stream(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    file: UploadFile = File(...),
    overwrite: bool = Query(True, description="是否覆盖已存在文件"),
    chunk_size: int = Query(1024 * 1024, ge=8 * 1024, le=8 * 1024 * 1024, description="单次读取块大小"),
):
    """Upload a file in chunks without buffering the full body in memory."""
    result = await VirtualFSService.upload_stream_from_upload_file(full_path, file, chunk_size, overwrite)
    return success(result)
@router.get("/{full_path:path}")
@audit(action=AuditAction.READ, description="浏览目录")
async def browse_fs(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    page_num: int = Query(1, alias="page", ge=1, description="页码"),
    page_size: int = Query(50, ge=1, le=500, description="每页条数"),
    sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
    sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
    """Paginated, sorted directory listing (catch-all route; keep it last)."""
    data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
    return success(data)
@router.delete("/{full_path:path}")
@audit(action=AuditAction.DELETE, description="删除路径")
async def api_delete(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
):
    """Delete a file or directory at full_path."""
    result = await VirtualFSService.delete(full_path)
    return success(result)
@router.get("/")
@audit(action=AuditAction.READ, description="浏览根目录")
async def root_listing(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    page_num: int = Query(1, alias="page", ge=1, description="页码"),
    page_size: int = Query(50, ge=1, le=500, description="每页条数"),
    sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
    sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
    """Paginated, sorted listing of the virtual filesystem root."""
    data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
    return success(data)

View File

@@ -10,14 +10,8 @@ from typing import Dict, Iterable, List, Optional, Tuple
from fastapi import APIRouter, Request, Response
from fastapi import HTTPException
from services.config import ConfigCenter
from services.virtual_fs import (
delete_path,
list_virtual_dir,
stat_file,
stream_file,
write_file_stream,
)
from domain.config.service import ConfigService
from domain.virtual_fs.service import VirtualFSService
router = APIRouter(prefix="/s3", tags=["s3"])
@@ -69,18 +63,18 @@ def _s3_error(code: str, message: str, resource: str = "", status: int = 400) ->
async def _ensure_enabled() -> Optional[Response]:
flag = await ConfigCenter.get("S3_MAPPING_ENABLED", "1")
flag = await ConfigService.get("S3_MAPPING_ENABLED", "1")
if str(flag).strip().lower() in FALSEY:
return _s3_error("ServiceUnavailable", "S3 mapping disabled", status=503)
return None
async def _get_settings() -> Tuple[Optional[S3Settings], Optional[Response]]:
bucket = (await ConfigCenter.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
region = (await ConfigCenter.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
base_path = (await ConfigCenter.get("S3_MAPPING_BASE_PATH", "/")) or "/"
access_key = (await ConfigCenter.get("S3_MAPPING_ACCESS_KEY")) or ""
secret_key = (await ConfigCenter.get("S3_MAPPING_SECRET_KEY")) or ""
bucket = (await ConfigService.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
region = (await ConfigService.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
base_path = (await ConfigService.get("S3_MAPPING_BASE_PATH", "/")) or "/"
access_key = (await ConfigService.get("S3_MAPPING_ACCESS_KEY")) or ""
secret_key = (await ConfigService.get("S3_MAPPING_SECRET_KEY")) or ""
if not access_key or not secret_key:
return None, _s3_error(
"InvalidAccessKeyId",
@@ -221,7 +215,7 @@ async def _list_dir_all(path: str) -> List[Dict]:
page_size = 1000
while True:
try:
res = await list_virtual_dir(path, page_num=page_num, page_size=page_size)
res = await VirtualFSService.list_virtual_dir(path, page_num=page_num, page_size=page_size)
except HTTPException as exc: # directory missing
if exc.status_code in (400, 404):
return []
@@ -376,7 +370,7 @@ async def list_objects(request: Request, bucket: str):
prefixes: List[str] = []
if prefix and not prefix.endswith("/"):
try:
info = await stat_file(_virtual_path(settings, prefix))
info = await VirtualFSService.stat_file(_virtual_path(settings, prefix))
if not info.get("is_dir"):
files = [(prefix, info)]
except HTTPException as exc:
@@ -473,7 +467,7 @@ def _object_headers(meta: Dict, key: str) -> Dict[str, str]:
async def _stat_object(settings: S3Settings, key: str) -> Tuple[Optional[Dict], Optional[Response]]:
try:
info = await stat_file(_virtual_path(settings, key))
info = await VirtualFSService.stat_file(_virtual_path(settings, key))
if info.get("is_dir"):
return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
return info, None
@@ -498,7 +492,7 @@ async def object_get_head(request: Request, bucket: str, object_path: str):
base_headers.update(_object_headers(meta, key))
if request.method == "HEAD":
return Response(status_code=200, headers=base_headers)
resp = await stream_file(_virtual_path(settings, key), request.headers.get("range"))
resp = await VirtualFSService.stream_file(_virtual_path(settings, key), request.headers.get("range"))
safe_merge_keys = {"ETag", "Last-Modified", "x-amz-version-id", "Accept-Ranges"}
for hk, hv in base_headers.items():
if hk in safe_merge_keys:
@@ -514,7 +508,7 @@ async def put_object(request: Request, bucket: str, object_path: str):
return error
assert settings
key = object_path.lstrip("/")
await write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
meta, err = await _stat_object(settings, key)
if err:
return err
@@ -535,7 +529,7 @@ async def delete_object(request: Request, bucket: str, object_path: str):
assert settings
key = object_path.lstrip("/")
try:
await delete_path(_virtual_path(settings, key))
await VirtualFSService.delete_path(_virtual_path(settings, key))
except HTTPException as exc:
if exc.status_code not in (400, 404):
raise

View File

@@ -0,0 +1,26 @@
from fastapi import APIRouter, Depends, Query
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.search_service import VirtualFSSearchService
router = APIRouter(prefix="/api/fs/search", tags=["search"])
@router.get("")
async def search_files(
    q: str = Query(..., description="搜索查询"),
    top_k: int = Query(10, description="返回结果数量"),
    mode: str = Query("vector", description="搜索模式: 'vector''filename'"),
    page: int = Query(1, description="分页页码,仅在文件名搜索模式下生效"),
    page_size: int = Query(10, description="分页大小,仅在文件名搜索模式下生效"),
    user: User = Depends(get_current_active_user),
):
    """Search the virtual filesystem.

    Blank queries short-circuit to an empty result; pagination and top_k
    are clamped to sane bounds before dispatching to the search service.
    """
    if not q.strip():
        return {"items": [], "query": q}
    top_k = max(1, top_k)
    page = max(1, page)
    page_size = min(max(page_size, 1), 100)
    return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)

View File

@@ -1,13 +1,8 @@
from typing import Any, Dict, List, Tuple
from fastapi import APIRouter, Depends, Query
from schemas.fs import SearchResultItem
from services.auth import get_current_active_user, User
from services.ai import get_text_embedding
from services.vector_db import VectorDBService
router = APIRouter(prefix="/api/search", tags=["search"])
from domain.virtual_fs.types import SearchResultItem
from domain.ai.inference import get_text_embedding
from domain.ai.service import VectorDBService
def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
@@ -101,37 +96,23 @@ async def _filename_search(query: str, page: int, page_size: int) -> Tuple[List[
return page_items, has_more
@router.get("")
async def search_files(
q: str = Query(..., description="搜索查询"),
top_k: int = Query(10, description="返回结果数量"),
mode: str = Query("vector", description="搜索模式: 'vector''filename'"),
page: int = Query(1, description="分页页码,仅在文件名搜索模式下生效"),
page_size: int = Query(10, description="分页大小,仅在文件名搜索模式下生效"),
user: User = Depends(get_current_active_user),
):
if not q.strip():
return {"items": [], "query": q}
top_k = max(top_k, 1)
page = max(page, 1)
page_size = max(min(page_size, 100), 1)
if mode == "vector":
items = (await _vector_search(q, top_k))[:top_k]
elif mode == "filename":
items, has_more = await _filename_search(q, page, page_size)
return {
"items": items,
"query": q,
"mode": mode,
"pagination": {
"page": page,
"page_size": page_size,
"has_more": has_more,
},
}
else:
items = (await _vector_search(q, top_k))[:top_k]
return {"items": items, "query": q, "mode": mode}
class VirtualFSSearchService:
    """Facade dispatching search requests to the vector or filename backend."""

    @staticmethod
    async def search(query: str, top_k: int, mode: str, page: int, page_size: int):
        """Run a search in the requested *mode*.

        ``filename`` mode returns a paginated payload; any other mode
        (including the default ``vector``) returns at most *top_k*
        vector hits — the two vector branches in the original were
        identical, so they are collapsed into one fallthrough here.
        """
        if mode == "filename":
            items, has_more = await _filename_search(query, page, page_size)
            return {
                "items": items,
                "query": query,
                "mode": mode,
                "pagination": {
                    "page": page,
                    "page_size": page_size,
                    "has_more": has_more,
                },
            }
        items = (await _vector_search(query, top_k))[:top_k]
        return {"items": items, "query": query, "mode": mode}

1360
domain/virtual_fs/service.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import List, Optional
from pydantic import BaseModel
class VfsEntry(BaseModel):
name: str

View File

@@ -9,25 +9,17 @@ from typing import Optional
from fastapi import APIRouter, Request, Response, HTTPException, Depends
import xml.etree.ElementTree as ET
from services.auth import authenticate_user_db, User, UserInDB
from services.virtual_fs import (
list_virtual_dir,
stat_file,
write_file_stream,
make_dir,
delete_path,
move_path,
copy_path,
stream_file,
)
from services.config import ConfigCenter
from domain.auth.service import AuthService
from domain.auth.types import User, UserInDB
from domain.virtual_fs.service import VirtualFSService
from domain.config.service import ConfigService
_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
async def _ensure_webdav_enabled() -> None:
enabled = await ConfigCenter.get(_WEBDAV_ENABLED_KEY, "1")
enabled = await ConfigService.get(_WEBDAV_ENABLED_KEY, "1")
if str(enabled).strip().lower() in ("0", "false", "off", "no"):
raise HTTPException(503, detail="WebDAV mapping disabled")
@@ -70,7 +62,7 @@ async def _get_basic_user(request: Request) -> User:
username, _, password = decoded.partition(":")
except Exception:
raise HTTPException(401, detail="Invalid Basic auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
user_or_false: Optional[UserInDB] = await authenticate_user_db(username, password)
user_or_false: Optional[UserInDB] = await AuthService.authenticate_user_db(username, password)
if not user_or_false:
raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
u: UserInDB = user_or_false
@@ -170,7 +162,7 @@ async def propfind(
# 先获取当前路径信息
try:
st = await stat_file(full_path)
st = await VirtualFSService.stat_file(full_path)
is_dir = bool(st.get("is_dir"))
name = st.get("name") or full_path.rsplit("/", 1)[-1] or "/"
size = None if is_dir else int(st.get("size", 0))
@@ -182,7 +174,7 @@ async def propfind(
if depth in ("1", "infinity"):
try:
listing = await list_virtual_dir(full_path, page_num=1, page_size=1000)
listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
for ent in listing["items"]:
is_dir = bool(ent.get("is_dir"))
name = ent.get("name")
@@ -210,7 +202,7 @@ async def dav_get(
):
full_path = _normalize_fs_path(path)
range_header = request.headers.get("Range")
return await stream_file(full_path, range_header)
return await VirtualFSService.stream_file(full_path, range_header)
@router.head("/{path:path}")
@@ -221,7 +213,7 @@ async def dav_head(
):
full_path = _normalize_fs_path(path)
try:
st = await stat_file(full_path)
st = await VirtualFSService.stat_file(full_path)
except FileNotFoundError:
raise HTTPException(404, detail="Not found")
is_dir = bool(st.get("is_dir"))
@@ -251,7 +243,7 @@ async def dav_put(
async for chunk in request.stream():
if chunk:
yield chunk
size = await write_file_stream(full_path, body_iter(), overwrite=True)
size = await VirtualFSService.write_file_stream(full_path, body_iter(), overwrite=True)
return Response(status_code=201, headers=_dav_headers({"Content-Length": "0"}))
@@ -262,7 +254,7 @@ async def dav_delete(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await delete_path(full_path)
await VirtualFSService.delete_path(full_path)
return Response(status_code=204, headers=_dav_headers())
@@ -273,7 +265,7 @@ async def dav_mkcol(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await make_dir(full_path)
await VirtualFSService.make_dir(full_path)
return Response(status_code=201, headers=_dav_headers())
@@ -295,7 +287,7 @@ async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_
dest_header = request.headers.get("Destination")
dst = _parse_destination(dest_header or "")
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
await move_path(full_src, dst, overwrite=overwrite)
await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
return Response(status_code=204, headers=_dav_headers())
@@ -305,5 +297,5 @@ async def dav_copy(path: str, request: Request, user: User = Depends(_get_basic_
dest_header = request.headers.get("Destination")
dst = _parse_destination(dest_header or "")
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
await copy_path(full_src, dst, overwrite=overwrite)
await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())

12
main.py
View File

@@ -1,15 +1,14 @@
import os
from services.config import VERSION, ConfigCenter
from services.adapters.registry import runtime_registry
from domain.config.service import ConfigService, VERSION
from domain.adapters.registry import runtime_registry
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from db.session import close_db, init_db
from api.routers import include_routers
from fastapi import FastAPI
from services.middleware.logging_middleware import LoggingMiddleware
from services.middleware.exception_handler import global_exception_handler
from middleware.exception_handler import global_exception_handler
from dotenv import load_dotenv
from services.task_queue import task_queue_service
from domain.tasks.task_queue import task_queue_service
load_dotenv()
@@ -19,7 +18,7 @@ async def lifespan(app: FastAPI):
os.makedirs("data/db", exist_ok=True)
await init_db()
await runtime_registry.refresh()
await ConfigCenter.set("APP_VERSION", VERSION)
await ConfigService.set("APP_VERSION", VERSION)
await task_queue_service.start_worker()
try:
yield
@@ -35,7 +34,6 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
include_routers(app)
app.add_middleware(LoggingMiddleware)
app.add_exception_handler(Exception, global_exception_handler)
return app

1
middleware/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Middleware package for FastAPI app."""

View File

@@ -0,0 +1,11 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: translate any unhandled exception into a JSON 500."""
    payload = {"error": "Internal Server Error", "detail": str(exc)}
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=payload,
    )

View File

@@ -128,17 +128,25 @@ class AutomationTask(Model):
table = "automation_tasks"
class Log(Model):
class AuditLog(Model):
id = fields.IntField(pk=True)
timestamp = fields.DatetimeField(auto_now_add=True)
level = fields.CharField(max_length=50)
source = fields.CharField(max_length=100)
message = fields.TextField()
details = fields.JSONField(null=True)
created_at = fields.DatetimeField(auto_now_add=True)
action = fields.CharField(max_length=50)
description = fields.TextField(null=True)
user_id = fields.IntField(null=True)
username = fields.CharField(max_length=100, null=True)
client_ip = fields.CharField(max_length=64, null=True)
method = fields.CharField(max_length=10)
path = fields.CharField(max_length=1024)
status_code = fields.IntField()
duration_ms = fields.FloatField(null=True)
success = fields.BooleanField(default=True)
request_params = fields.JSONField(null=True)
request_body = fields.JSONField(null=True)
error = fields.TextField(null=True)
class Meta:
table = "logs"
table = "audit_logs"
class ShareLink(Model):

View File

@@ -1,12 +0,0 @@
from schemas.plugins import PluginCreate,PluginOut
from .adapters import AdapterCreate, AdapterOut
from .fs import MkdirRequest, MoveRequest
__all__ = [
"PluginOut"
"PluginCreate"
"AdapterCreate",
"AdapterOut",
"MkdirRequest",
"MoveRequest",
]

View File

@@ -1,18 +0,0 @@
from typing import Any, Dict
from pydantic import BaseModel, EmailStr, Field
class EmailTestRequest(BaseModel):
    """Payload for sending a test email through the configured mailer."""
    to: EmailStr  # Recipient address.
    subject: str = Field(..., min_length=1)
    template: str = Field(default="test", min_length=1)  # Template key to render.
    context: Dict[str, Any] = Field(default_factory=dict)  # Template variables.
class EmailTemplateUpdate(BaseModel):
    """New raw content for an email template."""
    content: str
class EmailTemplatePreviewPayload(BaseModel):
    """Context variables used when rendering a template preview."""
    context: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -1,27 +0,0 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class PluginCreate(BaseModel):
    """Request body for registering a plugin by its source URL."""
    url: str = Field(min_length=1)
    enabled: bool = True
class PluginOut(BaseModel):
    """Serialized plugin record returned by the API.

    Metadata fields are optional — presumably they are only populated
    once the plugin manifest has been fetched (TODO confirm).
    """
    id: int
    url: str
    enabled: bool
    key: Optional[str]
    name: Optional[str]
    version: Optional[str]
    supported_exts: Optional[List[str]]
    default_bounds: Optional[Dict[str, Any]]
    default_maximized: Optional[bool]
    icon: Optional[str]
    description: Optional[str]
    author: Optional[str]
    website: Optional[str]
    github: Optional[str]

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True

View File

@@ -1,321 +0,0 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Annotated
import secrets
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from pydantic import BaseModel
from models.database import UserAccount
from services.config import ConfigCenter
from services.logging import LogService
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10
def _now() -> datetime:
return datetime.now(timezone.utc)
@dataclass
class PasswordResetEntry:
    """In-memory record for one outstanding password-reset token."""
    user_id: int  # Account the token was issued for.
    email: str  # Snapshot of the user's email at issue time ("" if unset).
    username: str
    expires_at: datetime  # Aware UTC expiry; compared against _now().
    used: bool = False  # Set once the token is consumed.
class PasswordResetStore:
    """Process-local store of single-use password-reset tokens.

    Tokens live only in memory (lost on restart) and are guarded by a
    class-level asyncio lock; used or expired entries are pruned lazily.
    """

    _tokens: dict[str, PasswordResetEntry] = {}  # token -> entry
    _lock = asyncio.Lock()

    @classmethod
    def _cleanup(cls):
        # Drop used or expired entries. Callers must hold _lock.
        now = _now()
        for token, record in list(cls._tokens.items()):
            if record.used or record.expires_at < now:
                cls._tokens.pop(token, None)

    @classmethod
    async def create(cls, user: UserAccount) -> str:
        """Issue a fresh token for *user*, revoking any previous ones."""
        async with cls._lock:
            cls._cleanup()
            # One outstanding token per user: remove older entries first.
            for key, record in list(cls._tokens.items()):
                if record.user_id == user.id:
                    cls._tokens.pop(key, None)
            token = secrets.token_urlsafe(32)
            expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
            cls._tokens[token] = PasswordResetEntry(
                user_id=user.id,
                email=user.email or "",
                username=user.username,
                expires_at=expires_at,
            )
            return token

    @classmethod
    async def get(cls, token: str) -> PasswordResetEntry | None:
        """Return the live entry for *token*; None if unknown, used, or expired."""
        async with cls._lock:
            cls._cleanup()
            record = cls._tokens.get(token)
            if not record or record.used:
                return None
            return record

    @classmethod
    async def mark_used(cls, token: str) -> None:
        """Flag *token* as consumed so it cannot be redeemed again."""
        async with cls._lock:
            record = cls._tokens.get(token)
            if record:
                record.used = True
            cls._cleanup()

    @classmethod
    async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
        """Drop all of a user's tokens, optionally sparing *except_token*."""
        async with cls._lock:
            for key, record in list(cls._tokens.items()):
                if record.user_id == user_id and key != except_token:
                    cls._tokens.pop(key, None)
            cls._cleanup()
async def get_secret_key():
    """Resolve the JWT signing secret from the config center."""
    return await ConfigCenter.get_secret_key("SECRET_KEY", None)
class Token(BaseModel):
    """OAuth2 bearer token response body."""
    access_token: str
    token_type: str
class TokenData(BaseModel):
    """Decoded JWT claim subset carried between auth helpers."""
    username: str | None = None
class User(BaseModel):
    """Public user representation (no credential material)."""
    id:int
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None
class UserInDB(User):
    """User plus the stored password hash; never return this to clients."""
    hashed_password: str
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")  # bcrypt hashing/verification.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")  # Extracts the bearer token for FastAPI deps.
def verify_password(plain_password, hashed_password):
    """True when *plain_password* matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
    """Hash *password* with the configured bcrypt context."""
    return pwd_context.hash(password)
def get_user(db, username: str):
    """Build a UserInDB from *db*[username]; implicitly None when absent."""
    if username not in db:
        return None
    return UserInDB(**db[username])
async def get_user_db(username_or_email: str):
    """Resolve an account by username first, then by email.

    Returns a UserInDB (including the password hash), or None implicitly
    when no matching account exists.
    """
    account = await UserAccount.get_or_none(username=username_or_email)
    if account is None:
        account = await UserAccount.get_or_none(email=username_or_email)
    if account is None:
        return None
    return UserInDB(
        id=account.id,
        username=account.username,
        email=account.email,
        full_name=account.full_name,
        disabled=account.disabled,
        hashed_password=account.hashed_password,
    )
def authenticate_user(fake_db, username: str, password: str):
    """Check credentials against a dict-backed store.

    Returns the UserInDB on success, False otherwise (legacy contract).
    """
    candidate = get_user(fake_db, username)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
async def authenticate_user_db(username_or_email: str, password: str):
    """Validate credentials against the database.

    Returns the user record on success, False on unknown user or bad password.
    """
    candidate = await get_user_db(username_or_email)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
async def register_user(username: str, password: str, email: str | None = None, full_name: str | None = None):
    """Create the initial user account.

    Raises:
        HTTPException: 403 once any user exists (registration is first-run only),
            400 when the username is already taken.
    """
    if await has_users():
        raise HTTPException(status_code=403, detail="系统已初始化,不允许注册新用户")
    if await UserAccount.get_or_none(username=username):
        raise HTTPException(status_code=400, detail="用户名已存在")
    return await UserAccount.create(
        username=username,
        email=email,
        full_name=full_name,
        hashed_password=get_password_hash(password),
        disabled=False,
    )
async def has_users() -> bool:
    """Return True when at least one user account exists in the database."""
    total = await UserAccount.all().count()
    return total > 0
async def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Sign a JWT carrying *data* plus an expiry claim.

    A ``sub`` claim is derived from ``username`` when absent; the token lives
    for *expires_delta* (default 15 minutes).
    """
    payload = dict(data)
    if "sub" not in payload and "username" in payload:
        payload["sub"] = payload["username"]
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    payload["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(payload, await get_secret_key(), algorithm=ALGORITHM)
def _normalize_email(email: str | None) -> str:
return (email or "").strip().lower()
async def _send_password_reset_email(user: UserAccount, token: str) -> None:
    """Queue the password-reset email containing the tokenized reset link.

    The link is built from the APP_DOMAIN config value, falling back to the
    local dev frontend URL when unset.
    """
    from services.email import EmailService

    domain = await ConfigCenter.get("APP_DOMAIN", None)
    base = (domain or "http://localhost:5173").rstrip("/")
    await EmailService.enqueue_email(
        recipients=[user.email],
        subject="Foxel 密码重置",
        template="password_reset",
        context={
            "username": user.username,
            "reset_link": f"{base}/reset-password?token={token}",
            "expire_minutes": PASSWORD_RESET_TOKEN_EXPIRE_MINUTES,
        },
    )
async def request_password_reset(email: str) -> bool:
    """Issue a reset token for *email* and queue the notification mail.

    Returns False (revealing nothing to the caller) when the address is empty
    or does not belong to a known account. If the mail cannot be queued the
    freshly minted token is revoked and HTTP 500 is raised.
    """
    addr = _normalize_email(email)
    if not addr:
        return False
    account = await UserAccount.get_or_none(email=addr)
    if not account or not account.email:
        return False
    reset_token = await PasswordResetStore.create(account)
    try:
        await _send_password_reset_email(account, reset_token)
    except Exception as err:  # noqa: BLE001
        # Best-effort rollback: the token must not stay usable if no mail went out.
        await PasswordResetStore.mark_used(reset_token)
        await PasswordResetStore.invalidate_user(account.id)
        await LogService.error(
            "auth",
            f"Failed to enqueue password reset email: {err}",
            details={"user_id": account.id},
            user_id=account.id,
        )
        raise HTTPException(status_code=500, detail="邮件发送失败") from err
    await LogService.action(
        "auth",
        "Password reset requested",
        details={"user_id": account.id},
        user_id=account.id,
    )
    return True
async def verify_password_reset_token(token: str) -> UserAccount:
    """Resolve *token* to its user account or raise HTTP 400.

    Unknown/used tokens and deleted accounts are rejected as invalid; expired
    tokens are marked used and rejected as expired.
    """
    entry = await PasswordResetStore.get(token)
    if not entry:
        raise HTTPException(status_code=400, detail="重置链接无效")
    account = await UserAccount.get_or_none(id=entry.user_id)
    if not account:
        raise HTTPException(status_code=400, detail="重置链接无效")
    if entry.expires_at < _now():
        await PasswordResetStore.mark_used(token)
        raise HTTPException(status_code=400, detail="重置链接已过期")
    return account
async def reset_password_with_token(token: str, new_password: str) -> None:
    """Set a new password for the account identified by reset *token*.

    Token validation is delegated to ``verify_password_reset_token`` so both
    entry points enforce identical rules; previously the record/expiry/user
    checks were duplicated here in a slightly different order. On success the
    token is consumed, all other outstanding tokens for the user are revoked,
    and the action is logged.

    Raises:
        HTTPException: 400 when the token is invalid, already used, or expired.
    """
    # Single shared validation path (raises HTTP 400 on any failure).
    user = await verify_password_reset_token(token)
    user.hashed_password = get_password_hash(new_password)
    # Persist only the changed column.
    await user.save(update_fields=["hashed_password"])
    await PasswordResetStore.mark_used(token)
    await PasswordResetStore.invalidate_user(user.id)
    await LogService.action(
        "auth",
        "Password reset via email",
        details={"user_id": user.id},
        user_id=user.id,
    )
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    """FastAPI dependency: decode the bearer JWT and return the matching user.

    Raises HTTP 401 when the token is malformed, lacks a ``sub`` claim, or
    names a user that no longer exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, await get_secret_key(), algorithms=[ALGORITHM])
    except InvalidTokenError:
        raise unauthorized
    subject = payload.get("sub")
    if subject is None:
        raise unauthorized
    claims = TokenData(username=subject)
    user = await get_user_db(claims.username)
    if user is None:
        raise unauthorized
    return user
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    """FastAPI dependency: reject disabled accounts with HTTP 400."""
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")

View File

@@ -1,80 +0,0 @@
from tortoise.transactions import in_transaction
from models.database import (
StorageAdapter,
UserAccount,
AutomationTask,
ShareLink,
Configuration,
)
from services.config import VERSION
class BackupService:
    """Export and import of the application's database tables as JSON-able dicts."""

    @staticmethod
    async def export_data():
        """Dump all backup-relevant tables inside a single transaction.

        Share-link datetime fields are converted to ISO-8601 strings (or None)
        so the payload is JSON serializable.
        """
        async with in_transaction():
            snapshot = {
                "version": VERSION,
                "storage_adapters": list(await StorageAdapter.all().values()),
                "user_accounts": list(await UserAccount.all().values()),
                "automation_tasks": list(await AutomationTask.all().values()),
                "share_links": list(await ShareLink.all().values()),
                "configurations": list(await Configuration.all().values()),
            }
            for share in snapshot["share_links"]:
                for field in ("created_at", "expires_at"):
                    value = share.get(field)
                    share[field] = value.isoformat() if value else None
            return snapshot

    @staticmethod
    async def import_data(data: dict):
        """Replace the database contents with the rows from a backup payload.

        All existing rows are deleted first, then tables are restored in
        dependency order, all within one transaction.
        """
        async with in_transaction() as conn:
            # Wipe in reverse-dependency order before restoring.
            for model in (ShareLink, AutomationTask, StorageAdapter, UserAccount, Configuration):
                await model.all().using_db(conn).delete()
            restore_order = (
                ("configurations", Configuration),
                ("user_accounts", UserAccount),
                ("storage_adapters", StorageAdapter),
                ("automation_tasks", AutomationTask),
                ("share_links", ShareLink),
            )
            for key, model in restore_order:
                rows = data.get(key)
                if rows:
                    await model.bulk_create(
                        [model(**row) for row in rows],
                        using_db=conn,
                    )

Some files were not shown because too many files have changed in this diff Show More