Compare commits

...

23 Commits

Author SHA1 Message Date
shiyu fc85f21aaa chore: update version to v1.4.0 2025-12-09 15:22:02 +08:00
shiyu 16283dea09 chore(deploy): serve SPA without nginx and slim docker image 2025-12-09 14:21:20 +08:00
shiyu 055c240079 refactor: split virtual fs mapping and search modules 2025-12-09 11:09:24 +08:00
shiyu 12a3bb8efc feat: extend BackupData model 2025-12-08 23:45:43 +08:00
shiyu 050577cf62 feat: improve log detail display 2025-12-08 22:21:06 +08:00
shiyu 394c2f7229 feat: enhance exception handling with additional handlers for HTTP and validation errors 2025-12-08 21:21:57 +08:00
shiyu 8f515aaaf4 refactor: optimize backend module 2025-12-08 17:46:45 +08:00
shiyu cf8d10f71c chore: update version to v1.3.8 2025-11-27 18:07:18 +08:00
shiyu 5c4d3a625b feat: add endpoint to list objects with trailing slash in S3 router 2025-11-26 15:23:44 +08:00
shiyu f0a51c3369 chore: consolidate .gitignore files 2025-11-20 14:11:55 +08:00
shiyu 3278896d4b feat: normalize adapter types and improve validation in adapters 2025-11-20 12:43:41 +08:00
Copilot 219f3e81b8 feat: add Google Drive storage adapter (#50)
* Initial plan

* Add Google Drive storage adapter implementation

Co-authored-by: DrizzleTime <169802108+DrizzleTime@users.noreply.github.com>

* Add optional methods for direct download and thumbnail support

Co-authored-by: DrizzleTime <169802108+DrizzleTime@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: DrizzleTime <169802108+DrizzleTime@users.noreply.github.com>
2025-11-20 11:49:32 +08:00
shiyu 8ef0a34642 fix: handle missing size and modification time in WebDAV file info 2025-11-15 14:17:54 +08:00
shiyu 8aaa2900ef fix: adjust Flex component styles for consistent width in loading states 2025-11-11 12:05:59 +08:00
shiyu e3e68f5397 chore: update version to v1.3.7 2025-11-10 11:02:45 +08:00
shiyu 78dfbac458 feat: enhance object retrieval and upload handling in S3 routes 2025-11-10 10:07:25 +08:00
shiyu 583db651a7 feat: add S3 endpoint to Nginx configuration 2025-11-08 22:44:09 +08:00
shiyu 3a15362422 feat: add S3 mapping configuration and API endpoints 2025-11-07 23:06:51 +08:00
shiyu e55a09d84f feat: add WebDAV mapping configuration and UI in System Settings 2025-11-07 22:32:50 +08:00
shiyu 8957174e6f chore: update version to v1.3.6 2025-11-07 12:06:51 +08:00
shiyu abb6b0ce22 feat: add workflow to clean dangling Docker images 2025-11-07 11:08:09 +08:00
shiyu 74df438053 style: update grid layout for fx-grid and fx-grid-item for better responsiveness 2025-11-07 09:57:47 +08:00
shiyu f271a8bee5 feat: add User ID and Raw Log fields to log details in LogsPage 2025-11-06 16:36:35 +08:00
147 changed files with 8808 additions and 6384 deletions

.github/workflows/docker-clean.yml

@@ -0,0 +1,51 @@
name: Clean dangling Docker images

on:
  workflow_dispatch:

jobs:
  docker-clean:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      - name: Delete untagged GHCR versions
        shell: bash
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          set -euo pipefail
          OWNER="${GITHUB_REPOSITORY_OWNER}"
          PACKAGE="$(echo "${GITHUB_REPOSITORY##*/}" | tr '[:upper:]' '[:lower:]')"
          OWNER_TYPE="$(gh api "/users/${OWNER}" -q '.type')"
          if [[ "${OWNER_TYPE}" == "Organization" ]]; then
            SCOPE="orgs/${OWNER}"
          else
            SCOPE="users/${OWNER}"
          fi
          BASE_PATH="/${SCOPE}/packages/container/${PACKAGE}"
          if ! gh api "${BASE_PATH}" >/dev/null 2>&1; then
            echo "Package ghcr.io/${OWNER}/${PACKAGE} not found or not accessible. Nothing to clean."
            exit 0
          fi
          mapfile -t VERSION_IDS < <(gh api --paginate "${BASE_PATH}/versions?per_page=100" \
            -q '.[] | select(.metadata.container.tags | length == 0) | .id')
          if [[ ${#VERSION_IDS[@]} -eq 0 ]]; then
            echo "No untagged versions to delete."
            exit 0
          fi
          echo "Deleting ${#VERSION_IDS[@]} untagged versions from ghcr.io/${OWNER}/${PACKAGE}..."
          for id in "${VERSION_IDS[@]}"; do
            gh api -X DELETE "${BASE_PATH}/versions/${id}" >/dev/null
            echo "Deleted version ${id}"
          done
          echo "Cleanup complete."
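For reference, the jq selection in the `mapfile` step — keep only package versions whose container tag list is empty — can be sketched as a small Python helper (hypothetical, not part of the workflow):

```python
def untagged_version_ids(versions):
    """Return ids of container package versions with no tags, mirroring
    the jq filter: .[] | select(.metadata.container.tags | length == 0) | .id
    """
    return [
        v["id"]
        for v in versions
        # An absent or empty tag list marks a dangling (untagged) image.
        if not v.get("metadata", {}).get("container", {}).get("tags")
    ]

versions = [
    {"id": 1, "metadata": {"container": {"tags": ["v1.4.0"]}}},
    {"id": 2, "metadata": {"container": {"tags": []}}},
]
print(untagged_version_ids(versions))  # → [2]
```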

.gitignore

@@ -7,4 +7,29 @@ __pycache__/
data/
migrate/
.env
AGENTS.md
AGENTS.md
# Logs
/web/logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
/web/node_modules
/web/dist
/web/dist-ssr
/web/*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?


@@ -137,8 +137,8 @@ Install the following tooling first:
Storage adapters integrate new storage providers (for example S3, FTP, or Alist).
1. Create a new module under [`services/adapters/`](services/adapters/) (for example `my_new_adapter.py`).
2. Implement a class that inherits from [`services.adapters.base.BaseAdapter`](services/adapters/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
1. Create a new module under [`domain/adapters/providers/`](domain/adapters/providers/) (for example `my_new_adapter.py`).
2. Implement a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
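A minimal skeleton following the two steps above might look like the sketch below. The `BaseAdapter` stub and the in-memory storage are stand-ins for illustration only; the real base class lives in the Foxel repository, and its method signatures may differ:

```python
from abc import ABC, abstractmethod
import asyncio

class BaseAdapter(ABC):
    """Stand-in for domain.adapters.providers.base.BaseAdapter
    (method names taken from the contribution guide above)."""
    @abstractmethod
    async def list_dir(self, path: str): ...
    @abstractmethod
    async def get_meta(self, path: str): ...
    @abstractmethod
    async def upload(self, path: str, data: bytes): ...
    @abstractmethod
    async def download(self, path: str) -> bytes: ...

class MyNewAdapter(BaseAdapter):
    """Toy in-memory adapter showing the shape of an implementation."""
    def __init__(self):
        self._files: dict[str, bytes] = {}

    async def list_dir(self, path: str):
        # Children are the stored paths under the directory prefix.
        prefix = path.rstrip("/") + "/"
        return [p for p in self._files if p.startswith(prefix)]

    async def get_meta(self, path: str):
        return {"path": path, "size": len(self._files[path])}

    async def upload(self, path: str, data: bytes):
        self._files[path] = data

    async def download(self, path: str) -> bytes:
        return self._files[path]

async def demo():
    adapter = MyNewAdapter()
    await adapter.upload("/docs/readme.txt", b"hello")
    return await adapter.list_dir("/docs")

print(asyncio.run(demo()))  # → ['/docs/readme.txt']
```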
### Frontend Apps


@@ -143,9 +143,9 @@
Storage adapters are Foxel's core extension point for integrating different storage backends (such as S3, FTP, or Alist).
1. **Create the adapter file**: under the [`services/adapters/`](services/adapters/) directory, create a new file, for example `my_new_adapter.py`.
1. **Create the adapter file**: under the [`domain/adapters/providers/`](domain/adapters/providers/) directory, create a new file, for example `my_new_adapter.py`.
2. **Implement the adapter class**:
- Create a class that inherits from [`services.adapters.base.BaseAdapter`](services/adapters/base.py).
- Create a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py).
- Implement all abstract methods defined on `BaseAdapter`, such as `list_dir`, `get_meta`, `upload`, and `download`. Read the docstrings in the base class carefully to understand each method's purpose and parameters.
### Contributing Frontend Apps


@@ -14,23 +14,26 @@ FROM python:3.13-slim
WORKDIR /app
RUN apt-get update \
&& apt-get install -y --no-install-recommends nginx git ffmpeg \
&& apt-get install -y --no-install-recommends ffmpeg curl ca-certificates \
&& rm -rf /var/lib/apt/lists/*
RUN pip install uv
COPY pyproject.toml uv.lock ./
RUN uv pip install --system . gunicorn
RUN uv pip install --system . gunicorn \
&& rm -rf /root/.cache
RUN git clone https://github.com/DrizzleTime/FoxelUpgrade /app/migrate
RUN curl -L https://github.com/DrizzleTime/FoxelUpgrade/archive/refs/heads/main.tar.gz -o /tmp/migrate.tgz \
&& mkdir -p /app/migrate \
&& tar -xzf /tmp/migrate.tgz --strip-components=1 -C /app/migrate \
&& rm -rf /tmp/migrate.tgz
COPY --from=frontend-builder /app/web/dist /app/web/dist
COPY . .
COPY nginx.conf /etc/nginx/nginx.conf
RUN mkdir -p data/db data/mount && \
chmod 777 data/db data/mount
chmod 777 data/db data/mount && \
rm -rf /var/log/apt /var/cache/apt/archives
EXPOSE 80


@@ -1,25 +1,38 @@
from fastapi import FastAPI
from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search, vector_db, offline_downloads, ai_providers, email
from .routes import webdav
from .routes import plugins
from domain.adapters import api as adapters
from domain.auth import api as auth
from domain.backup import api as backup
from domain.config import api as config
from domain.email import api as email
from domain.offline_downloads import api as offline_downloads
from domain.plugins import api as plugins
from domain.processors import api as processors
from domain.share import api as share
from domain.tasks import api as tasks
from domain.ai import api as ai
from domain.virtual_fs import api as virtual_fs
from domain.virtual_fs.mapping import s3_api, webdav_api
from domain.virtual_fs.search import search_api
from domain.audit import router as audit
def include_routers(app: FastAPI):
    app.include_router(adapters.router)
    app.include_router(virtual_fs.router)
    app.include_router(search.router)
    app.include_router(search_api.router)
    app.include_router(auth.router)
    app.include_router(config.router)
    app.include_router(processors.router)
    app.include_router(tasks.router)
    app.include_router(logs.router)
    app.include_router(share.router)
    app.include_router(share.public_router)
    app.include_router(backup.router)
    app.include_router(vector_db.router)
    app.include_router(ai_providers.router)
    app.include_router(ai.router_vector_db)
    app.include_router(ai.router_ai)
    app.include_router(plugins.router)
    app.include_router(webdav.router)
    app.include_router(webdav_api.router)
    app.include_router(s3_api.router)
    app.include_router(offline_downloads.router)
    app.include_router(email.router)
    app.include_router(audit)


@@ -1,149 +0,0 @@
from fastapi import APIRouter, HTTPException, Depends
from tortoise.transactions import in_transaction
from typing import Annotated
from models import StorageAdapter
from schemas import AdapterCreate, AdapterOut
from services.auth import get_current_active_user, User
from services.adapters.registry import runtime_registry, get_config_schemas
from api.response import success
from services.logging import LogService
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
def validate_and_normalize_config(adapter_type: str, cfg):
    schemas = get_config_schemas()
    if not isinstance(cfg, dict):
        raise HTTPException(400, detail="config must be an object")
    schema = schemas.get(adapter_type)
    if not schema:
        raise HTTPException(400, detail=f"Unsupported adapter type: {adapter_type}")
    out = {}
    missing = []
    for f in schema:
        k = f["key"]
        if k in cfg and cfg[k] not in (None, ""):
            out[k] = cfg[k]
        elif "default" in f:
            out[k] = f["default"]
        elif f.get("required"):
            missing.append(k)
    if missing:
        raise HTTPException(400, detail="Missing required config fields: " + ", ".join(missing))
    return out

@router.post("")
async def create_adapter(
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    norm_path = AdapterCreate.normalize_mount_path(data.path)
    exists = await StorageAdapter.get_or_none(path=norm_path)
    if exists:
        raise HTTPException(400, detail="Mount path already exists")
    adapter_fields = {
        "name": data.name,
        "type": data.type,
        "config": validate_and_normalize_config(data.type, data.config or {}),
        "enabled": data.enabled,
        "path": norm_path,
        "sub_path": data.sub_path,
    }
    rec = await StorageAdapter.create(**adapter_fields)
    await runtime_registry.upsert(rec)
    await LogService.action(
        "route:adapters",
        f"Created adapter {rec.name}",
        details=adapter_fields,
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success(rec)

@router.get("")
async def list_adapters(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    adapters = await StorageAdapter.all()
    out = [AdapterOut.model_validate(a) for a in adapters]
    return success(out)

@router.get("/available")
async def available_adapter_types(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    data = []
    for t, fields in get_config_schemas().items():
        data.append({
            "type": t,
            "name": "Local filesystem" if t == "local" else ("WebDAV" if t == "webdav" else t),
            "config_schema": fields,
        })
    return success(data)

@router.get("/{adapter_id}")
async def get_adapter(
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    rec = await StorageAdapter.get_or_none(id=adapter_id)
    if not rec:
        raise HTTPException(404, detail="Not found")
    return success(AdapterOut.model_validate(rec))

@router.put("/{adapter_id}")
async def update_adapter(
    adapter_id: int,
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    rec = await StorageAdapter.get_or_none(id=adapter_id)
    if not rec:
        raise HTTPException(404, detail="Not found")
    norm_path = AdapterCreate.normalize_mount_path(data.path)
    existing = await StorageAdapter.get_or_none(path=norm_path)
    if existing and existing.id != adapter_id:
        raise HTTPException(400, detail="Mount path already exists")
    rec.name = data.name
    rec.type = data.type
    rec.config = validate_and_normalize_config(data.type, data.config or {})
    rec.enabled = data.enabled
    rec.path = norm_path
    rec.sub_path = data.sub_path
    await rec.save()
    await runtime_registry.upsert(rec)
    await LogService.action(
        "route:adapters",
        f"Updated adapter {rec.name}",
        details=data.model_dump(),
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success(rec)

@router.delete("/{adapter_id}")
async def delete_adapter(
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    deleted = await StorageAdapter.filter(id=adapter_id).delete()
    if not deleted:
        raise HTTPException(404, detail="Not found")
    runtime_registry.remove(adapter_id)
    await LogService.action(
        "route:adapters",
        f"Deleted adapter {adapter_id}",
        details={"adapter_id": adapter_id},
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success({"deleted": True})


@@ -1,177 +0,0 @@
from typing import Annotated, Dict, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Path
from api.response import success
from schemas.ai import (
    AIDefaultsUpdate,
    AIModelCreate,
    AIModelUpdate,
    AIProviderCreate,
    AIProviderUpdate,
)
from services.ai_providers import AIProviderService
from services.auth import User, get_current_active_user
from services.vector_db import VectorDBService

router = APIRouter(prefix="/api/ai", tags=["ai"])
service = AIProviderService()

@router.get("/providers")
async def list_providers(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    providers = await service.list_providers()
    return success({"providers": providers})

@router.post("/providers")
async def create_provider(
    payload: AIProviderCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    provider = await service.create_provider(payload.dict())
    return success(provider)

@router.get("/providers/{provider_id}")
async def get_provider(
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    provider = await service.get_provider(provider_id, with_models=True)
    return success(provider)

@router.put("/providers/{provider_id}")
async def update_provider(
    provider_id: Annotated[int, Path(..., gt=0)],
    payload: AIProviderUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = {k: v for k, v in payload.dict().items() if v is not None}
    if not data:
        raise HTTPException(status_code=400, detail="No fields to update")
    provider = await service.update_provider(provider_id, data)
    return success(provider)

@router.delete("/providers/{provider_id}")
async def delete_provider(
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    await service.delete_provider(provider_id)
    return success({"id": provider_id})

@router.post("/providers/{provider_id}/sync-models")
async def sync_models(
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    try:
        result = await service.sync_models(provider_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return success(result)

@router.get("/providers/{provider_id}/remote-models")
async def fetch_remote_models(
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    try:
        models = await service.fetch_remote_models(provider_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return success({"models": models})

@router.get("/providers/{provider_id}/models")
async def list_models(
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    models = await service.list_models(provider_id)
    return success({"models": models})

@router.post("/providers/{provider_id}/models")
async def create_model(
    provider_id: Annotated[int, Path(..., gt=0)],
    payload: AIModelCreate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    model = await service.create_model(provider_id, payload.dict())
    return success(model)

@router.put("/models/{model_id}")
async def update_model(
    model_id: Annotated[int, Path(..., gt=0)],
    payload: AIModelUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = {k: v for k, v in payload.dict().items() if v is not None}
    if not data:
        raise HTTPException(status_code=400, detail="No fields to update")
    model = await service.update_model(model_id, data)
    return success(model)

@router.delete("/models/{model_id}")
async def delete_model(
    model_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    await service.delete_model(model_id)
    return success({"id": model_id})

def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
    if not entry:
        return None
    value = entry.get("embedding_dimensions")
    return int(value) if value is not None else None

@router.get("/defaults")
async def get_defaults(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    defaults = await service.get_default_models()
    return success(defaults)

@router.put("/defaults")
async def update_defaults(
    payload: AIDefaultsUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    previous = await service.get_default_models()
    try:
        updated = await service.set_default_models(payload.as_mapping())
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    prev_dim = _get_embedding_dimension(previous.get("embedding"))
    next_dim = _get_embedding_dimension(updated.get("embedding"))
    if prev_dim and next_dim and prev_dim != next_dim:
        try:
            await VectorDBService().clear_all_data()
        except Exception as exc:  # noqa: BLE001
            raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
    return success(updated)


@@ -1,155 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, HTTPException, Depends, Form
import hashlib
from fastapi.security import OAuth2PasswordRequestForm
from services.auth import (
    authenticate_user_db,
    create_access_token,
    ACCESS_TOKEN_EXPIRE_MINUTES,
    register_user,
    Token,
    get_current_active_user,
    User,
    request_password_reset,
    verify_password_reset_token,
    reset_password_with_token,
)
from pydantic import BaseModel
from datetime import timedelta
from api.response import success
from models.database import UserAccount
from services.auth import verify_password, get_password_hash

router = APIRouter(prefix="/api/auth", tags=["auth"])

class RegisterRequest(BaseModel):
    username: str
    password: str
    email: str | None = None
    full_name: str | None = None

@router.post("/register", summary="Register the first admin user")
async def register(data: RegisterRequest):
    """
    Registration is only allowed when no users exist in the system.
    """
    user = await register_user(
        username=data.username,
        password=data.password,
        email=data.email,
        full_name=data.full_name,
    )
    return success({"username": user.username}, msg="Initial user registered successfully")

@router.post("/login")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
    user = await authenticate_user_db(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = await create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return Token(access_token=access_token, token_type="bearer")

@router.get("/me", summary="Get current user info")
async def get_me(current_user: Annotated[User, Depends(get_current_active_user)]):
    """
    Return the current user's basic info together with a gravatar avatar URL.
    """
    email = (current_user.email or "").strip().lower()
    md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
    gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
    return success({
        "id": current_user.id,
        "username": current_user.username,
        "email": current_user.email,
        "full_name": current_user.full_name,
        "gravatar_url": gravatar_url,
    })

class UpdateMeRequest(BaseModel):
    email: str | None = None
    full_name: str | None = None
    old_password: str | None = None
    new_password: str | None = None

class PasswordResetRequest(BaseModel):
    email: str

class PasswordResetConfirm(BaseModel):
    token: str
    password: str

@router.put("/me", summary="Update current user info")
async def update_me(
    payload: UpdateMeRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    db_user = await UserAccount.get_or_none(id=current_user.id)
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")
    if payload.email is not None:
        exists = await UserAccount.filter(email=payload.email).exclude(id=db_user.id).exists()
        if exists:
            raise HTTPException(status_code=400, detail="Email already in use")
        db_user.email = payload.email
    if payload.full_name is not None:
        db_user.full_name = payload.full_name
    if payload.new_password:
        if not payload.old_password:
            raise HTTPException(status_code=400, detail="Original password required")
        if not verify_password(payload.old_password, db_user.hashed_password):
            raise HTTPException(status_code=400, detail="Original password is incorrect")
        db_user.hashed_password = get_password_hash(payload.new_password)
    await db_user.save()
    email = (db_user.email or "").strip().lower()
    md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
    gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
    return success({
        "id": db_user.id,
        "username": db_user.username,
        "email": db_user.email,
        "full_name": db_user.full_name,
        "gravatar_url": gravatar_url,
    })

@router.post("/password-reset/request", summary="Request a password reset email")
async def password_reset_request_endpoint(payload: PasswordResetRequest):
    await request_password_reset(payload.email)
    return success(msg="If the email exists, a reset email will be sent")

@router.get("/password-reset/verify", summary="Verify a password reset token")
async def password_reset_verify(token: str):
    user = await verify_password_reset_token(token)
    return success({
        "username": user.username,
        "email": user.email,
    })

@router.post("/password-reset/confirm", summary="Reset password with a token")
async def password_reset_confirm(payload: PasswordResetConfirm):
    await reset_password_with_token(payload.token, payload.password)
    return success(msg="Password has been reset")


@@ -1,50 +0,0 @@
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from services.auth import get_current_active_user
from services.backup import BackupService
from models.database import UserAccount
import json
import datetime

router = APIRouter(
    prefix="/api/backup",
    tags=["Backup & Restore"],
    dependencies=[Depends(get_current_active_user)],
)

@router.get("/export", summary="Export all site data")
async def export_backup():
    """
    Generate and download a JSON file containing all critical data.
    """
    try:
        data = await BackupService.export_data()
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        headers = {
            "Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"
        }
        return JSONResponse(content=data, headers=headers)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/import", summary="Import data")
async def import_backup(file: UploadFile = File(...)):
    """
    Restore data from an uploaded JSON file.
    **Warning**: this will overwrite all existing data!
    """
    if not file.filename.endswith(".json"):
        raise HTTPException(status_code=400, detail="Invalid file type; please upload a .json file")
    try:
        contents = await file.read()
        data = json.loads(contents)
    except Exception:
        raise HTTPException(status_code=400, detail="Unable to parse the JSON file")
    try:
        await BackupService.import_data(data)
        return {"message": "Data imported successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Import failed: {e}")


@@ -1,83 +0,0 @@
import httpx
import time
from fastapi import APIRouter, Depends, Form
from typing import Annotated
from services.config import ConfigCenter, VERSION
from services.auth import get_current_active_user, User, has_users
from api.response import success

router = APIRouter(prefix="/api/config", tags=["config"])

@router.get("/")
async def get_config(
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str
):
    value = await ConfigCenter.get(key)
    return success({"key": key, "value": value})

@router.post("/")
async def set_config(
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str = Form(...),
    value: str = Form(...)
):
    await ConfigCenter.set(key, value)
    return success({"key": key, "value": value})

@router.get("/all")
async def get_all_config(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    configs = await ConfigCenter.get_all()
    return success(configs)

@router.get("/status")
async def get_system_status():
    logo = await ConfigCenter.get("APP_LOGO", "/logo.svg")
    favicon = await ConfigCenter.get("APP_FAVICON", logo)
    system_info = {
        "version": VERSION,
        "title": await ConfigCenter.get("APP_NAME", "Foxel"),
        "logo": logo,
        "favicon": favicon,
        "is_initialized": await has_users(),
        "app_domain": await ConfigCenter.get("APP_DOMAIN"),
        "file_domain": await ConfigCenter.get("FILE_DOMAIN"),
    }
    return success(system_info)

latest_version_cache = {
    "timestamp": 0,
    "data": None
}

@router.get("/latest-version")
async def get_latest_version():
    current_time = time.time()
    if current_time - latest_version_cache["timestamp"] < 3600 and latest_version_cache["data"]:
        return success(latest_version_cache["data"])
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                "https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
                follow_redirects=True,
            )
            resp.raise_for_status()
            data = resp.json()
        version_info = {
            "latest_version": data.get("tag_name"),
            "body": data.get("body")
        }
        latest_version_cache["timestamp"] = current_time
        latest_version_cache["data"] = version_info
        return success(version_info)
    except httpx.RequestError as e:
        if latest_version_cache["data"]:
            return success(latest_version_cache["data"])
        return success({"latest_version": None, "body": None})
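The `latest-version` route above caches the GitHub response for an hour in a module-level dict. The same pattern in isolation (a hypothetical helper, not part of the codebase; the clock is injectable so expiry is easy to exercise):

```python
import time

class TTLCache:
    """Single-value cache with a time-to-live, mirroring the
    module-level latest_version_cache dict used by the route."""
    def __init__(self, ttl_seconds: float, clock=time.time):
        self.ttl = ttl_seconds
        self.clock = clock  # injectable for testing
        self.timestamp = 0.0
        self.data = None

    def get(self):
        # Return cached data only while it is still fresh.
        if self.data is not None and self.clock() - self.timestamp < self.ttl:
            return self.data
        return None

    def set(self, data):
        self.timestamp = self.clock()
        self.data = data

# Usage with a fake clock to show expiry:
now = [0.0]
cache = TTLCache(3600, clock=lambda: now[0])
cache.set({"latest_version": "v1.4.0"})
assert cache.get() == {"latest_version": "v1.4.0"}
now[0] += 3601  # advance past the TTL
assert cache.get() is None
```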


@@ -1,48 +0,0 @@
from typing import Optional
from fastapi import APIRouter, Query
from models.database import Log
from api.response import page, success
from tortoise.expressions import Q
from datetime import datetime

router = APIRouter(prefix="/api/logs", tags=["Logs"])

@router.get("")
async def get_logs(
    page_num: int = Query(1, alias="page"),
    page_size: int = Query(20, alias="page_size"),
    level: Optional[str] = Query(None),
    source: Optional[str] = Query(None),
    start_time: Optional[datetime] = Query(None),
    end_time: Optional[datetime] = Query(None),
):
    """List logs with pagination and filtering."""
    query = Log.all()
    if level:
        query = query.filter(level=level)
    if source:
        query = query.filter(source__icontains=source)
    if start_time:
        query = query.filter(timestamp__gte=start_time)
    if end_time:
        query = query.filter(timestamp__lte=end_time)
    total = await query.count()
    logs = await query.order_by("-timestamp").offset((page_num - 1) * page_size).limit(page_size)
    return success(page([log for log in logs], total, page_num, page_size))

@router.delete("")
async def clear_logs(
    start_time: Optional[datetime] = Query(None),
    end_time: Optional[datetime] = Query(None),
):
    """Clear logs within the given time range."""
    query = Log.all()
    if start_time:
        query = query.filter(timestamp__gte=start_time)
    if end_time:
        query = query.filter(timestamp__lte=end_time)
    deleted_count = await query.delete()
    return success({"deleted_count": deleted_count})


@@ -1,79 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from api.response import success
from schemas.offline_downloads import OfflineDownloadCreate
from services.auth import User, get_current_active_user
from services.logging import LogService
from services.task_queue import task_queue_service, TaskProgress
from services.virtual_fs import path_is_directory

router = APIRouter(
    prefix="/api/offline-downloads",
    tags=["OfflineDownloads"],
)

@router.post("/")
async def create_offline_download(
    payload: OfflineDownloadCreate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    dest_dir = payload.dest_dir
    try:
        is_dir = await path_is_directory(dest_dir)
    except HTTPException:
        is_dir = False
    if not is_dir:
        raise HTTPException(400, detail="Destination directory not found")
    task = await task_queue_service.add_task(
        "offline_http_download",
        {
            "url": str(payload.url),
            "dest_dir": dest_dir,
            "filename": payload.filename,
        },
    )
    await task_queue_service.update_progress(
        task.id,
        TaskProgress(
            stage="queued",
            percent=0.0,
            bytes_total=None,
            bytes_done=0,
            detail="Waiting to start",
        ),
    )
    await LogService.action(
        "route:offline_downloads",
        f"Offline download task created {task.id}",
        details={"url": str(payload.url), "dest_dir": dest_dir, "filename": payload.filename},
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success({"task_id": task.id})

@router.get("/")
async def list_offline_downloads(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
    data = [t.dict() for t in tasks]
    return success(data)

@router.get("/{task_id}")
async def get_offline_download(
    task_id: str,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    task = task_queue_service.get_task(task_id)
    if not task or task.name != "offline_http_download":
        raise HTTPException(status_code=404, detail="Task not found")
    return success(task.dict())


@@ -1,73 +0,0 @@
from typing import List, Any, Dict
from fastapi import APIRouter, HTTPException, Body
from models import database
from schemas import PluginCreate, PluginOut

router = APIRouter(prefix="/api/plugins", tags=["plugins"])

@router.post("", response_model=PluginOut)
async def create_plugin(payload: PluginCreate):
    rec = await database.Plugin.create(
        url=payload.url,
        enabled=payload.enabled,
    )
    return PluginOut.model_validate(rec)

@router.get("", response_model=List[PluginOut])
async def list_plugins():
    rows = await database.Plugin.all().order_by("-id")
    return [PluginOut.model_validate(r) for r in rows]

@router.delete("/{plugin_id}")
async def delete_plugin(plugin_id: int):
    rec = await database.Plugin.get_or_none(id=plugin_id)
    if not rec:
        raise HTTPException(status_code=404, detail="Plugin not found")
    await rec.delete()
    return {"code": 0, "msg": "ok"}

@router.put("/{plugin_id}", response_model=PluginOut)
async def update_plugin(plugin_id: int, payload: PluginCreate):
    rec = await database.Plugin.get_or_none(id=plugin_id)
    if not rec:
        raise HTTPException(status_code=404, detail="Plugin not found")
    rec.url = payload.url
    rec.enabled = payload.enabled
    await rec.save()
    return PluginOut.model_validate(rec)

@router.post("/{plugin_id}/metadata", response_model=PluginOut)
async def update_manifest(plugin_id: int, manifest: Dict[str, Any] = Body(...)):
    rec = await database.Plugin.get_or_none(id=plugin_id)
    if not rec:
        raise HTTPException(status_code=404, detail="Plugin not found")
    key_map = {
        'key': 'key',
        'name': 'name',
        'version': 'version',
        'supported_exts': 'supported_exts',
        'supportedExts': 'supported_exts',
        'default_bounds': 'default_bounds',
        'defaultBounds': 'default_bounds',
        'default_maximized': 'default_maximized',
        'defaultMaximized': 'default_maximized',
        'icon': 'icon',
        'description': 'description',
        'author': 'author',
        'website': 'website',
        'github': 'github',
    }
    for k, v in list(manifest.items()):
        if v is None:
            continue
        attr = key_map.get(k)
        if not attr:
            continue
        setattr(rec, attr, v)
    await rec.save()
    return PluginOut.model_validate(rec)


@@ -1,250 +0,0 @@
from pathlib import Path
from fastapi import APIRouter, Depends, Body, HTTPException
from fastapi.concurrency import run_in_threadpool
from typing import Annotated
from services.processors.registry import (
    get,
    get_config_schema,
    get_config_schemas,
    get_module_path,
    reload_processors,
)
from services.task_queue import task_queue_service
from services.auth import get_current_active_user, User
from api.response import success
from pydantic import BaseModel
from services.virtual_fs import path_is_directory, resolve_adapter_and_rel
from typing import List, Optional, Tuple

router = APIRouter(prefix="/api/processors", tags=["processors"])

@router.get("")
async def list_processors(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    schemas = get_config_schemas()
    out = []
    for t, meta in schemas.items():
        out.append({
            "type": meta["type"],
            "name": meta["name"],
            "supported_exts": meta.get("supported_exts", []),
            "config_schema": meta["config_schema"],
            "produces_file": meta.get("produces_file", False),
            "module_path": meta.get("module_path"),
        })
    return success(out)

class ProcessRequest(BaseModel):
    path: str
    processor_type: str
    config: dict
    save_to: str | None = None
    overwrite: bool = False

class ProcessDirectoryRequest(BaseModel):
    path: str
    processor_type: str
    config: dict
    overwrite: bool = True
    max_depth: Optional[int] = None
    suffix: Optional[str] = None

class UpdateSourceRequest(BaseModel):
    source: str

@router.post("/process")
async def process_file_with_processor(
    current_user: Annotated[User, Depends(get_current_active_user)],
    req: ProcessRequest = Body(...)
):
    is_dir = await path_is_directory(req.path)
    if is_dir and not req.overwrite:
        raise HTTPException(400, detail="Directory processing requires overwrite")
    save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
    task = await task_queue_service.add_task(
        "process_file",
        {
            "path": req.path,
            "processor_type": req.processor_type,
            "config": req.config,
            "save_to": save_to,
            "overwrite": req.overwrite,
        },
    )
    return success({"task_id": task.id})

@router.post("/process-directory")
async def process_directory_with_processor(
    current_user: Annotated[User, Depends(get_current_active_user)],
    req: ProcessDirectoryRequest = Body(...)
):
    if req.max_depth is not None and req.max_depth < 0:
        raise HTTPException(400, detail="max_depth must be >= 0")
is_dir = await path_is_directory(req.path)
if not is_dir:
raise HTTPException(400, detail="Path must be a directory")
schema = get_config_schema(req.processor_type)
_processor = get(req.processor_type)
if not schema or not _processor:
raise HTTPException(404, detail="Processor not found")
produces_file = bool(schema.get("produces_file"))
suffix = req.suffix
if suffix is not None and suffix.strip() == "":
suffix = None
overwrite = req.overwrite
if produces_file:
if not overwrite and not suffix:
raise HTTPException(400, detail="Suffix is required when not overwriting files")
else:
overwrite = False
suffix = None
supported_exts = schema.get("supported_exts") or []
allowed_exts = {
ext.lower().lstrip('.')
for ext in supported_exts
if isinstance(ext, str)
}
def matches_extension(file_rel: str) -> bool:
if not allowed_exts:
return True
if '.' not in file_rel:
return '' in allowed_exts
ext = file_rel.rsplit('.', 1)[-1].lower()
return ext in allowed_exts or f'.{ext}' in allowed_exts
adapter_instance, adapter_model, root, rel = await resolve_adapter_and_rel(req.path)
rel = rel.rstrip('/')
list_dir = getattr(adapter_instance, "list_dir", None)
if not callable(list_dir):
raise HTTPException(501, detail="Adapter does not implement list_dir")
def build_absolute_path(mount_path: str, rel_path: str) -> str:
rel_norm = rel_path.lstrip('/')
mount_norm = mount_path.rstrip('/')
if not mount_norm:
return '/' + rel_norm if rel_norm else '/'
return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm
def apply_suffix(path_str: str, suffix_str: str) -> str:
path_obj = Path(path_str)
name = path_obj.name
if not name:
return path_str
if '.' in name:
base, ext = name.rsplit('.', 1)
new_name = f"{base}{suffix_str}.{ext}"
else:
new_name = f"{name}{suffix_str}"
return str(path_obj.with_name(new_name))
scheduled_tasks: List[str] = []
stack: List[Tuple[str, int]] = [(rel, 0)]
page_size = 200
while stack:
current_rel, depth = stack.pop()
page = 1
while True:
entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
entries = entries or []
if not entries and (total or 0) == 0:
break
for entry in entries:
name = entry.get("name")
if not name:
continue
child_rel = f"{current_rel}/{name}" if current_rel else name
if entry.get("is_dir"):
if req.max_depth is None or depth < req.max_depth:
stack.append((child_rel.rstrip('/'), depth + 1))
continue
if not matches_extension(child_rel):
continue
absolute_path = build_absolute_path(adapter_model.path, child_rel)
save_to = None
if produces_file and not overwrite and suffix:
save_to = apply_suffix(absolute_path, suffix)
task = await task_queue_service.add_task(
"process_file",
{
"path": absolute_path,
"processor_type": req.processor_type,
"config": req.config,
"save_to": save_to,
"overwrite": overwrite,
},
)
scheduled_tasks.append(task.id)
if total is None or page * page_size >= total:
break
page += 1
return success({
"task_ids": scheduled_tasks,
"scheduled": len(scheduled_tasks),
})
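The traversal above avoids recursion: an explicit stack of (path, depth) pairs driven by paged `list_dir` calls. A simplified sketch of the same depth-limited walk over an in-memory tree (the tree and helper names are illustrative, not part of the service):

```python
from typing import Dict, List, Optional, Tuple

# Illustrative tree: directories map names to subtrees, files map to None.
TREE = {
    "a.txt": None,
    "sub": {"b.txt": None, "deep": {"c.txt": None}},
}

def walk_files(tree: Dict, max_depth: Optional[int] = None) -> List[str]:
    """Stack-based, depth-limited walk mirroring the handler's loop."""
    files: List[str] = []
    stack: List[Tuple[Dict, str, int]] = [(tree, "", 0)]
    while stack:
        node, rel, depth = stack.pop()
        for name, child in sorted(node.items()):
            child_rel = f"{rel}/{name}" if rel else name
            if isinstance(child, dict):
                # Descend only while below the limit, like `depth < max_depth`.
                if max_depth is None or depth < max_depth:
                    stack.append((child, child_rel, depth + 1))
                continue
            files.append(child_rel)
    return files
```

`max_depth=0` lists only the top level, matching the endpoint's `max_depth must be >= 0` contract.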
@router.get("/source/{processor_type}")
async def get_processor_source(
processor_type: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
module_path = get_module_path(processor_type)
if not module_path:
raise HTTPException(404, detail="Processor not found")
path_obj = Path(module_path)
if not path_obj.exists():
raise HTTPException(404, detail="Processor source not found")
try:
content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
except Exception as exc:
raise HTTPException(500, detail=f"Failed to read source: {exc}")
return success({"source": content, "module_path": str(path_obj)})
@router.put("/source/{processor_type}")
async def update_processor_source(
processor_type: str,
req: UpdateSourceRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
module_path = get_module_path(processor_type)
if not module_path:
raise HTTPException(404, detail="Processor not found")
path_obj = Path(module_path)
if not path_obj.exists():
raise HTTPException(404, detail="Processor source not found")
try:
await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
except Exception as exc:
raise HTTPException(500, detail=f"Failed to write source: {exc}")
return success(True)
@router.post("/reload")
async def reload_processor_modules(
current_user: Annotated[User, Depends(get_current_active_user)],
):
errors = reload_processors()
if errors:
raise HTTPException(500, detail="; ".join(errors))
return success(True)


@@ -1,217 +0,0 @@
from typing import List, Optional
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from api.response import success
from services.auth import User, get_current_active_user
from services.share import share_service
from services.virtual_fs import stream_file, stat_file
from models.database import ShareLink, UserAccount
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
class ShareCreate(BaseModel):
name: str
paths: List[str]
expires_in_days: Optional[int] = 7
access_type: str = "public"
password: Optional[str] = None
class ShareInfo(BaseModel):
id: int
token: str
name: str
paths: List[str]
created_at: str
expires_at: Optional[str] = None
access_type: str
@classmethod
def from_orm(cls, obj: ShareLink):
return cls(
id=obj.id,
token=obj.token,
name=obj.name,
paths=obj.paths,
created_at=obj.created_at.isoformat(),
expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
access_type=obj.access_type,
)
class ShareInfoWithPassword(ShareInfo):
password: Optional[str] = None
# --- Management Routes ---
@router.post("", response_model=ShareInfoWithPassword)
async def create_share(
payload: ShareCreate,
current_user: User = Depends(get_current_active_user),
):
"""
创建一个新的分享链接。
"""
user_account = await UserAccount.get(id=current_user.id)
share = await share_service.create_share_link(
user=user_account,
name=payload.name,
paths=payload.paths,
expires_in_days=payload.expires_in_days,
access_type=payload.access_type,
password=payload.password,
)
share_info_base = ShareInfo.from_orm(share)
response_data = share_info_base.model_dump()
if payload.access_type == "password" and payload.password:
response_data['password'] = payload.password
return response_data
@router.get("", response_model=List[ShareInfo])
async def get_my_shares(current_user: User = Depends(get_current_active_user)):
"""
获取当前用户的所有分享链接。
"""
user_account = await UserAccount.get(id=current_user.id)
shares = await share_service.get_user_shares(user=user_account)
return [ShareInfo.from_orm(s) for s in shares]
@router.delete("/expired")
async def delete_expired_shares(
current_user: User = Depends(get_current_active_user),
):
"""
删除当前用户的所有已过期分享。
"""
user_account = await UserAccount.get(id=current_user.id)
deleted_count = await share_service.delete_expired_shares(user=user_account)
return success({"deleted_count": deleted_count})
@router.delete("/{share_id}")
async def delete_share(
share_id: int,
current_user: User = Depends(get_current_active_user),
):
"""
删除一个分享链接。
"""
await share_service.delete_share_link(user=current_user, share_id=share_id)
return success(msg="分享已取消")
# --- Public Routes ---
class SharePassword(BaseModel):
password: str
@public_router.post("/{token}/verify")
async def verify_password(token: str, payload: SharePassword):
"""
验证分享链接的密码。
"""
share = await share_service.get_share_by_token(token)
if share.access_type != "password":
raise HTTPException(status_code=400, detail="此分享不需要密码")
if not share_service._verify_password(payload.password, share.hashed_password):
raise HTTPException(status_code=403, detail="密码错误")
# We could return a short-lived token here for follow-up access, but to keep
# things simple the frontend sends the password (or a session identifier)
# with every request, so this endpoint only returns a success status.
return success(msg="验证成功")
@public_router.get("/{token}/ls")
async def list_share_content(token: str, path: str = "/", password: Optional[str] = None):
"""
列出分享链接中的文件和目录。
"""
share = await share_service.get_share_by_token(token)
if share.access_type == "password":
if not password:
raise HTTPException(status_code=401, detail="需要密码")
if not share_service._verify_password(password, share.hashed_password):
raise HTTPException(status_code=403, detail="密码错误")
content = await share_service.get_shared_item_details(share, path)
return success({
"path": path,
"entries": content.get("items", []),
"pagination": {
"total": content.get("total", 0),
"page": content.get("page", 1),
"page_size": content.get("page_size", 1),
"pages": content.get("pages", 1),
}
})
@public_router.get("/{token}")
async def get_share_info(token: str):
"""
获取分享链接的元数据信息。
"""
share = await share_service.get_share_by_token(token)
return success(ShareInfo.from_orm(share))
@public_router.get("/{token}/download")
async def download_shared_file(token: str, path: str, request: Request, password: Optional[str] = None):
"""
下载分享链接中的单个文件。
"""
if not path or path == "/" or ".." in path.split('/'):
raise HTTPException(status_code=400, detail="无效的文件路径")
share = await share_service.get_share_by_token(token)
if share.access_type == "password":
if not password:
raise HTTPException(status_code=401, detail="需要密码")
if not share_service._verify_password(password, share.hashed_password):
raise HTTPException(status_code=403, detail="密码错误")
base_shared_path = share.paths[0]
# Determine whether the shared path is a file or a directory
is_dir = False
try:
stat = await stat_file(base_shared_path)
if stat and stat.get("is_dir"):
is_dir = True
except HTTPException as e:
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
is_dir = True
else:
# The shared path itself doesn't exist, which is an issue.
raise HTTPException(status_code=404, detail="分享的源文件不存在")
if is_dir:
# Directory share: join the requested sub-path onto the shared root
full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
if not full_virtual_path.startswith(base_shared_path):
raise HTTPException(status_code=403, detail="无权访问此路径")
else:
# File share: the requested path must match the shared file itself
shared_filename = base_shared_path.split('/')[-1]
request_filename = path.lstrip('/')
if shared_filename != request_filename:
raise HTTPException(status_code=403, detail="无权访问此路径")
full_virtual_path = base_shared_path
range_header = request.headers.get("Range")
response = await stream_file(full_virtual_path, range_header)
# Set the Content-Disposition header to force a download
filename = full_virtual_path.split('/')[-1]
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
return response
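The download route above uses the RFC 5987 `filename*` form so non-ASCII filenames survive HTTP header encoding. The header construction in isolation:

```python
from urllib.parse import quote

def content_disposition(filename: str) -> str:
    # RFC 5987: percent-encode the UTF-8 bytes of the name so the header
    # stays ASCII-safe even for e.g. Chinese filenames.
    return f"attachment; filename*=UTF-8''{quote(filename)}"
```

Plain ASCII names pass through unchanged; anything else is percent-encoded byte by byte.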


@@ -1,141 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated
from models.database import AutomationTask
from schemas.tasks import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
TaskQueueSettingsResponse,
)
from api.response import success
from services.auth import get_current_active_user, User
from services.logging import LogService
from services.task_queue import task_queue_service
from services.config import ConfigCenter
router = APIRouter(
prefix="/api/tasks",
tags=["Tasks"],
dependencies=[Depends(get_current_active_user)],
responses={404: {"description": "Not found"}},
)
@router.get("/queue")
async def get_task_queue_status(
current_user: Annotated[User, Depends(get_current_active_user)],
):
tasks = task_queue_service.get_all_tasks()
return success([task.dict() for task in tasks])
@router.get("/queue/settings")
async def get_task_queue_settings(
current_user: Annotated[User, Depends(get_current_active_user)],
):
payload = TaskQueueSettingsResponse(
concurrency=task_queue_service.get_concurrency(),
active_workers=task_queue_service.get_active_worker_count(),
)
return success(payload.model_dump())
@router.post("/queue/settings")
async def update_task_queue_settings(
settings: TaskQueueSettings,
current_user: Annotated[User, Depends(get_current_active_user)],
):
await task_queue_service.set_concurrency(settings.concurrency)
await ConfigCenter.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
await LogService.action(
"route:tasks",
"Updated task queue settings",
details={"concurrency": settings.concurrency},
user_id=getattr(current_user, "id", None),
)
payload = TaskQueueSettingsResponse(
concurrency=task_queue_service.get_concurrency(),
active_workers=task_queue_service.get_active_worker_count(),
)
return success(payload.model_dump())
@router.get("/queue/{task_id}")
async def get_task_status(
task_id: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
task = task_queue_service.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return success(task.dict())
@router.post("/")
async def create_task(
task_in: AutomationTaskCreate,
user: User = Depends(get_current_active_user)
):
task = await AutomationTask.create(**task_in.model_dump())
await LogService.action(
"route:tasks",
f"Created task {task.name}",
details=task_in.model_dump(),
user_id=user.id if hasattr(user, "id") else None,
)
return success(task)
@router.get("/{task_id}")
async def get_task(task_id: int):
task = await AutomationTask.get_or_none(id=task_id)
if not task:
raise HTTPException(
status_code=404, detail=f"Task {task_id} not found")
return success(task)
@router.get("/")
async def list_tasks():
tasks = await AutomationTask.all()
return success(tasks)
@router.put("/{task_id}")
async def update_task(
current_user: Annotated[User, Depends(get_current_active_user)],
task_id: int, task_in: AutomationTaskUpdate):
task = await AutomationTask.get_or_none(id=task_id)
if not task:
raise HTTPException(
status_code=404, detail=f"Task {task_id} not found")
update_data = task_in.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(task, key, value)
await task.save()
await LogService.action(
"route:tasks",
f"Updated task {task.name}",
details=task_in.model_dump(),
user_id=current_user.id,
)
return success(task)
@router.delete("/{task_id}")
async def delete_task(
task_id: int,
user: User = Depends(get_current_active_user)
):
deleted_count = await AutomationTask.filter(id=task_id).delete()
if not deleted_count:
raise HTTPException(
status_code=404, detail=f"Task {task_id} not found")
await LogService.action(
"route:tasks",
f"Deleted task {task_id}",
details={"task_id": task_id},
user_id=user.id if hasattr(user, "id") else None,
)
return success(msg="Task deleted")
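`update_task` above writes only the fields the client actually sent, via `model_dump(exclude_unset=True)`. A minimal demonstration of that partial-update behavior (the model fields are illustrative; assumes pydantic v2):

```python
from typing import Optional
from pydantic import BaseModel

class TaskUpdate(BaseModel):
    name: Optional[str] = None
    enabled: Optional[bool] = None

# Only `name` was provided, so `enabled` is absent from the dump entirely --
# an unset field cannot accidentally overwrite the stored value with None.
patch = TaskUpdate(name="nightly-backup").model_dump(exclude_unset=True)
```

An explicitly passed `None` is still included, which lets clients clear a field on purpose.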


@@ -1,91 +0,0 @@
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from services.auth import get_current_active_user
from models.database import UserAccount
from services.vector_db import (
VectorDBService,
VectorDBConfigManager,
list_providers,
get_provider_entry,
)
from services.vector_db.providers import get_provider_class
from api.response import success
router = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
class VectorDBConfigPayload(BaseModel):
type: str = Field(..., description="向量数据库提供者类型")
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
@router.post("/clear-all", summary="清空向量数据库")
async def clear_vector_db(user: UserAccount = Depends(get_current_active_user)):
try:
service = VectorDBService()
await service.clear_all_data()
return success(msg="向量数据库已清空")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stats", summary="获取向量数据库统计")
async def get_vector_db_stats(user: UserAccount = Depends(get_current_active_user)):
try:
service = VectorDBService()
data = await service.get_all_stats()
return success(data=data)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/providers", summary="列出可用向量数据库提供者")
async def list_vector_providers(user: UserAccount = Depends(get_current_active_user)):
return success(list_providers())
@router.get("/config", summary="获取当前向量数据库配置")
async def get_vector_db_config(user: UserAccount = Depends(get_current_active_user)):
service = VectorDBService()
data = await service.current_provider()
return success(data)
@router.post("/config", summary="更新向量数据库配置")
async def update_vector_db_config(payload: VectorDBConfigPayload, user: UserAccount = Depends(get_current_active_user)):
entry = get_provider_entry(payload.type)
if not entry:
raise HTTPException(
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
if not entry.get("enabled", True):
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
provider_cls = get_provider_class(payload.type)
if not provider_cls:
raise HTTPException(
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
# Try to connect first to make sure the configuration is valid
test_provider = provider_cls(payload.config)
try:
await test_provider.initialize()
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
finally:
client = getattr(test_provider, "client", None)
close_fn = getattr(client, "close", None)
if callable(close_fn):
try:
close_fn()
except Exception:
pass
await VectorDBConfigManager.save_config(payload.type, payload.config)
service = VectorDBService()
await service.reload()
config_data = await service.current_provider()
stats = await service.get_all_stats()
return success({"config": config_data, "stats": stats})


@@ -1,376 +0,0 @@
from fastapi import APIRouter, UploadFile, File, HTTPException, Response, Query, Request, Depends
import mimetypes
import re
from typing import Annotated
from services.auth import get_current_active_user, User
from services.virtual_fs import (
list_virtual_dir,
read_file,
write_file,
make_dir,
delete_path,
move_path,
resolve_adapter_and_rel,
stream_file,
generate_temp_link_token,
verify_temp_link_token,
maybe_redirect_download,
)
from services.thumbnail import is_image_filename, get_or_create_thumb, is_raw_filename, is_video_filename
from schemas import MkdirRequest, MoveRequest
from api.response import success
from services.config import ConfigCenter
router = APIRouter(prefix='/api/fs', tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
async def get_file(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)]
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
if is_raw_filename(full_path):
import rawpy
from PIL import Image
import io
try:
raw_data = await read_file(full_path)
with rawpy.imread(io.BytesIO(raw_data)) as raw:
rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
im = Image.fromarray(rgb)
buf = io.BytesIO()
im.save(buf, 'JPEG', quality=90)
content = buf.getvalue()
return Response(content=content, media_type='image/jpeg')
except FileNotFoundError:
raise HTTPException(404, detail="File not found")
except Exception as e:
raise HTTPException(500, detail=f"RAW file processing failed: {e}")
adapter_instance, adapter_model, root, rel = await resolve_adapter_and_rel(full_path)
redirect_response = await maybe_redirect_download(adapter_instance, adapter_model, root, rel)
if redirect_response is not None:
return redirect_response
try:
content = await read_file(full_path)
except FileNotFoundError:
raise HTTPException(404, detail="File not found")
if not isinstance(content, (bytes, bytearray)):
return Response(content=content, media_type="application/octet-stream")
content_length = len(content)
content_type = mimetypes.guess_type(
full_path)[0] or "application/octet-stream"
range_header = request.headers.get('Range')
if range_header:
range_match = re.match(r'bytes=(\d+)-(\d*)', range_header)
if range_match:
start = int(range_match.group(1))
end = int(range_match.group(2)) if range_match.group(
2) else content_length - 1
start = max(0, min(start, content_length - 1))
end = max(start, min(end, content_length - 1))
chunk = content[start:end + 1]
chunk_size = len(chunk)
headers = {
'Content-Range': f'bytes {start}-{end}/{content_length}',
'Accept-Ranges': 'bytes',
'Content-Length': str(chunk_size),
'Content-Type': content_type,
}
return Response(
content=chunk,
status_code=206,
headers=headers
)
headers = {
'Accept-Ranges': 'bytes',
'Content-Length': str(content_length),
'Content-Type': content_type,
}
if content_type.startswith('video/'):
headers['Cache-Control'] = 'public, max-age=3600'
return Response(content=content, headers=headers)
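`get_file` above answers Range requests by clamping the requested window into the valid byte span before slicing and returning 206 with a `Content-Range` header. The clamping arithmetic, isolated:

```python
import re

def parse_range(range_header: str, content_length: int):
    """Return (start, end), clamped into [0, content_length - 1],
    or None if the header is unparsable -- same rules as the endpoint."""
    m = re.match(r'bytes=(\d+)-(\d*)', range_header)
    if not m:
        return None
    start = int(m.group(1))
    # An open-ended range ("bytes=500-") runs to the last byte.
    end = int(m.group(2)) if m.group(2) else content_length - 1
    start = max(0, min(start, content_length - 1))
    end = max(start, min(end, content_length - 1))
    return start, end
```

Note that an over-long range is clamped rather than rejected, so `bytes=900-2000` on a 1000-byte file serves bytes 900-999.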
@router.get("/thumb/{full_path:path}")
async def get_thumb(
full_path: str,
w: int = Query(256, ge=8, le=1024),
h: int = Query(256, ge=8, le=1024),
fit: str = Query("cover"),
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
if fit not in ("cover", "contain"):
raise HTTPException(400, detail="fit must be cover|contain")
adapter, mount, root, rel = await resolve_adapter_and_rel(full_path)
if not rel or rel.endswith('/'):
raise HTTPException(400, detail="Not a file")
if not (is_image_filename(rel) or is_video_filename(rel)):
raise HTTPException(404, detail="Not an image or video")
data, mime, key = await get_or_create_thumb(adapter, mount.id, root, rel, w, h, fit)  # type: ignore
headers = {
'Cache-Control': 'public, max-age=3600',
'ETag': key,
}
return Response(content=data, media_type=mime, headers=headers)
@router.get("/stream/{full_path:path}")
async def stream_endpoint(
full_path: str,
request: Request,
):
"""支持 Range 的视频/大文件流式读取,优先使用底层适配器 Range 能力。"""
full_path = '/' + full_path if not full_path.startswith('/') else full_path
range_header = request.headers.get('Range')
try:
return await stream_file(full_path, range_header)
except HTTPException:
raise
except FileNotFoundError:
raise HTTPException(404, detail="File not found")
except Exception as e:
raise HTTPException(500, detail=f"Stream error: {e}")
@router.get("/temp-link/{full_path:path}")
async def get_temp_link(
full_path: str,
current_user: Annotated[User, Depends(get_current_active_user)],
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久")
):
"""获取文件的临时公开访问令牌"""
full_path = '/' + full_path if not full_path.startswith('/') else full_path
token = await generate_temp_link_token(full_path, expires_in=expires_in)
file_domain = await ConfigCenter.get("FILE_DOMAIN")
if file_domain:
file_domain = file_domain.rstrip('/')
url = f"{file_domain}/api/fs/public/{token}"
else:
url = f"/api/fs/public/{token}"
return success({"token": token, "path": full_path, "url": url})
@router.get("/public/{token}")
async def access_public_file(
token: str,
request: Request,
):
"""通过令牌公开访问文件,支持 Range 请求"""
path = await verify_temp_link_token(token)
range_header = request.headers.get('Range')
try:
return await stream_file(path, range_header)
except FileNotFoundError:
raise HTTPException(404, detail="File not found via token")
except Exception as e:
raise HTTPException(500, detail=f"File access error: {e}")
@router.get("/stat/{full_path:path}")
async def get_file_stat(
full_path: str,
current_user: Annotated[User, Depends(get_current_active_user)]
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
from services.virtual_fs import stat_file
stat = await stat_file(full_path)
return success(stat)
@router.post("/file/{full_path:path}")
async def put_file(
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
file: UploadFile = File(...)
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
data = await file.read()
await write_file(full_path, data)
return success({"written": True, "path": full_path, "size": len(data)})
@router.post("/mkdir")
async def api_mkdir(
current_user: Annotated[User, Depends(get_current_active_user)],
body: MkdirRequest
):
path = body.path if body.path.startswith('/') else '/' + body.path
if not path or path == '/':
raise HTTPException(400, detail="Invalid path")
await make_dir(path)
return success({"created": True, "path": path})
@router.post("/move")
async def api_move(
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
src = body.src if body.src.startswith('/') else '/' + body.src
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
debug_info = await move_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
queued = bool(debug_info.get("queued"))
response = {
"moved": not queued,
"queued": queued,
"src": src,
"dst": dst,
"overwrite": overwrite,
}
if queued:
response["task_id"] = debug_info.get("task_id")
response["task_name"] = debug_info.get("task_name")
return success(response)
@router.post("/rename")
async def api_rename(
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="是否允许覆盖已存在目标")
):
src = body.src if body.src.startswith('/') else '/' + body.src
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
from services.virtual_fs import rename_path
await rename_path(src, dst, overwrite=overwrite, return_debug=False)
return success({
"renamed": True,
"src": src,
"dst": dst,
"overwrite": overwrite,
})
@router.post("/copy")
async def api_copy(
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
):
from services.virtual_fs import copy_path
src = body.src if body.src.startswith('/') else '/' + body.src
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
debug_info = await copy_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
queued = bool(debug_info.get("queued"))
response = {
"copied": not queued,
"queued": queued,
"src": src,
"dst": dst,
"overwrite": overwrite,
}
if queued:
response["task_id"] = debug_info.get("task_id")
response["task_name"] = debug_info.get("task_name")
return success(response)
@router.post("/upload/{full_path:path}")
async def upload_stream(
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
file: UploadFile = File(...),
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
chunk_size: int = Query(1024 * 1024, ge=8 * 1024,
le=8 * 1024 * 1024, description="单次读取块大小")
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
if full_path.endswith('/'):
raise HTTPException(400, detail="Path must be a file")
from services.virtual_fs import write_file_stream, resolve_adapter_and_rel
adapter, _m, root, rel = await resolve_adapter_and_rel(full_path)
exists_func = getattr(adapter, "exists", None)
if not overwrite and callable(exists_func):
try:
if await exists_func(root, rel):
raise HTTPException(409, detail="Destination exists")
except HTTPException:
raise
except Exception:
pass
async def gen():
while True:
chunk = await file.read(chunk_size)
if not chunk:
break
yield chunk
size = await write_file_stream(full_path, gen(), overwrite=overwrite)
return success({"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite})
@router.get("/{full_path:path}")
async def browse_fs(
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
page_num: int = Query(1, alias="page", ge=1, description="页码"),
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
sort_order: str = Query("asc", description="排序顺序: asc, desc")
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
result = await list_virtual_dir(full_path, page_num, page_size, sort_by, sort_order)
return success({
"path": full_path,
"entries": result["items"],
"pagination": {
"total": result["total"],
"page": result["page"],
"page_size": result["page_size"],
"pages": result["pages"]
}
})
@router.delete("/{full_path:path}")
async def api_delete(
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str
):
full_path = '/' + full_path if not full_path.startswith('/') else full_path
await delete_path(full_path)
return success({"deleted": True, "path": full_path})
@router.get("/")
async def root_listing(
current_user: Annotated[User, Depends(get_current_active_user)],
page_num: int = Query(1, alias="page", ge=1, description="页码"),
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
sort_order: str = Query("asc", description="排序顺序: asc, desc")
):
result = await list_virtual_dir("/", page_num, page_size, sort_by, sort_order)
return success({
"path": "/",
"entries": result["items"],
"pagination": {
"total": result["total"],
"page": result["page"],
"page_size": result["page_size"],
"pages": result["pages"]
}
})


@@ -1,6 +1,6 @@
from tortoise import Tortoise
from services.adapters.registry import runtime_registry
from domain.adapters.registry import runtime_registry
TORTOISE_ORM = {
"connections": {"default": "sqlite://data/db/db.sqlite3"},


@@ -0,0 +1 @@

domain/adapters/api.py Normal file

@@ -0,0 +1,85 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.adapters.service import AdapterService
from domain.adapters.types import AdapterCreate
from domain.auth.service import get_current_active_user
from domain.auth.types import User
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
@router.post("")
@audit(
action=AuditAction.CREATE,
description="创建存储适配器",
body_fields=["name", "type", "path", "sub_path", "enabled"],
)
async def create_adapter(
request: Request,
data: AdapterCreate,
current_user: Annotated[User, Depends(get_current_active_user)]
):
adapter = await AdapterService.create_adapter(data, current_user)
return success(adapter)
@router.get("")
@audit(action=AuditAction.READ, description="获取适配器列表")
async def list_adapters(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)]
):
adapters = await AdapterService.list_adapters()
return success(adapters)
@router.get("/available")
@audit(action=AuditAction.READ, description="获取可用适配器类型")
async def available_adapter_types(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)]
):
data = await AdapterService.available_adapter_types()
return success(data)
@router.get("/{adapter_id}")
@audit(action=AuditAction.READ, description="Get adapter details")
async def get_adapter(
request: Request,
adapter_id: int,
current_user: Annotated[User, Depends(get_current_active_user)]
):
adapter = await AdapterService.get_adapter(adapter_id)
return success(adapter)
@router.put("/{adapter_id}")
@audit(
action=AuditAction.UPDATE,
description="Update storage adapter",
body_fields=["name", "type", "path", "sub_path", "enabled"],
)
async def update_adapter(
request: Request,
adapter_id: int,
data: AdapterCreate,
current_user: Annotated[User, Depends(get_current_active_user)]
):
adapter = await AdapterService.update_adapter(adapter_id, data, current_user)
return success(adapter)
@router.delete("/{adapter_id}")
@audit(action=AuditAction.DELETE, description="Delete storage adapter")
async def delete_adapter(
request: Request,
adapter_id: int,
current_user: Annotated[User, Depends(get_current_active_user)]
):
result = await AdapterService.delete_adapter(adapter_id, current_user)
return success(result)


@@ -0,0 +1,3 @@
from .base import BaseAdapter
__all__ = ["BaseAdapter"]


@@ -10,7 +10,6 @@ from ftplib import FTP, error_perm
import mimetypes
from models import StorageAdapter
from services.logging import LogService
def _join_remote(root: str, rel: str) -> str:
@@ -240,11 +239,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_write)
await LogService.info(
"adapter:ftp",
f"Wrote file to {rel}",
details={"adapter_id": self.record.id, "path": path, "size": len(data)},
)
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
# KISS: buffer everything, then write in one shot
@@ -276,7 +270,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_mkdir)
await LogService.info("adapter:ftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})
async def delete(self, root: str, rel: str):
path = _join_remote(root, rel)
@@ -340,7 +333,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_delete)
await LogService.info("adapter:ftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})
async def move(self, root: str, src_rel: str, dst_rel: str):
src = _join_remote(root, src_rel)
@@ -367,7 +359,6 @@ class FTPAdapter:
pass
await asyncio.to_thread(_do_move)
await LogService.info("adapter:ftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
@@ -402,10 +393,6 @@ class FTPAdapter:
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
await self.copy(root, child_src, child_dst, overwrite)
await LogService.info(
"adapter:ftp", f"Copied directory {src_rel} to {dst_rel}",
details={"adapter_id": self.record.id, "src": src, "dst": dst}
)
return
# file
@@ -418,7 +405,6 @@ class FTPAdapter:
except FileNotFoundError:
pass
await self.write_file(root, dst_rel, data)
await LogService.info("adapter:ftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def stat_file(self, root: str, rel: str):
path = _join_remote(root, rel)


@@ -0,0 +1,560 @@
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Tuple, AsyncIterator
import httpx
from fastapi.responses import StreamingResponse, Response
from fastapi import HTTPException
from models import StorageAdapter
GOOGLE_OAUTH_URL = "https://oauth2.googleapis.com/token"
GOOGLE_DRIVE_API_URL = "https://www.googleapis.com/drive/v3"
class GoogleDriveAdapter:
"""Google Drive storage adapter"""
def __init__(self, record: StorageAdapter):
self.record = record
cfg = record.config
self.client_id = cfg.get("client_id")
self.client_secret = cfg.get("client_secret")
self.refresh_token = cfg.get("refresh_token")
self.root_folder_id = cfg.get("root_folder_id", "root")
self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))
if not all([self.client_id, self.client_secret, self.refresh_token]):
raise ValueError(
"Google Drive adapter requires client_id, client_secret and refresh_token")
self._access_token: str | None = None
self._token_expiry: datetime | None = None
def get_effective_root(self, sub_path: str | None) -> str:
"""
Resolve the effective root path.
:param sub_path: sub-path.
:return: the full effective path ("" means the adapter root).
"""
if sub_path:
return sub_path.strip("/").strip()
return ""
async def _get_access_token(self) -> str:
"""
Get the cached access token, refreshing it when expired.
:return: access token.
"""
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
return self._access_token
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"refresh_token": self.refresh_token,
"grant_type": "refresh_token",
}
async with httpx.AsyncClient(timeout=20.0) as client:
resp = await client.post(GOOGLE_OAUTH_URL, data=data)
resp.raise_for_status()
token_data = resp.json()
self._access_token = token_data["access_token"]
self._token_expiry = datetime.now(
timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300)
return self._access_token
async def _request(self, method: str, endpoint: str, **kwargs):
"""
Send a request to the Google Drive API, retrying once after a 401.
:param method: HTTP method.
:param endpoint: API endpoint.
:param kwargs: extra request arguments.
:return: response object.
"""
token = await self._get_access_token()
headers = {"Authorization": f"Bearer {token}"}
if "headers" in kwargs:
headers.update(kwargs.pop("headers"))
url = f"{GOOGLE_DRIVE_API_URL}{endpoint}"
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.request(method, url, headers=headers, **kwargs)
if resp.status_code == 401:
self._access_token = None
token = await self._get_access_token()
headers["Authorization"] = f"Bearer {token}"
resp = await client.request(method, url, headers=headers, **kwargs)
return resp
async def _get_folder_id_by_path(self, path: str) -> str:
"""
Resolve a folder ID from a path.
:param path: path.
:return: folder ID.
"""
if not path or path == "/":
return self.root_folder_id
parts = [p for p in path.strip("/").split("/") if p]
current_id = self.root_folder_id
for part in parts:
escaped = part.replace("'", "\\'")  # escape single quotes for the Drive query language
query = f"name='{escaped}' and '{current_id}' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
params = {"q": query, "fields": "files(id, name)"}
resp = await self._request("GET", "/files", params=params)
resp.raise_for_status()
data = resp.json()
files = data.get("files", [])
if not files:
raise FileNotFoundError(f"Folder not found: {part}")
current_id = files[0]["id"]
return current_id
async def _get_file_id_by_path(self, path: str) -> str | None:
"""
Resolve a file ID from a path.
:param path: file path.
:return: file ID, or None if not found.
"""
if not path or path == "/":
return self.root_folder_id
parts = [p for p in path.strip("/").split("/") if p]
parent_id = self.root_folder_id
for i, part in enumerate(parts):
is_last = i == len(parts) - 1
mime_filter = "" if is_last else "and mimeType='application/vnd.google-apps.folder'"
escaped = part.replace("'", "\\'")  # escape single quotes for the Drive query language
query = f"name='{escaped}' and '{parent_id}' in parents {mime_filter} and trashed=false"
params = {"q": query, "fields": "files(id, name)"}
resp = await self._request("GET", "/files", params=params)
resp.raise_for_status()
data = resp.json()
files = data.get("files", [])
if not files:
return None
parent_id = files[0]["id"]
return parent_id
def _format_item(self, item: Dict) -> Dict:
"""
Normalize an item returned by the Google Drive API into the unified entry format.
:param item: item dict returned by the Google Drive API.
:return: normalized dict.
"""
is_dir = item["mimeType"] == "application/vnd.google-apps.folder"
mtime_str = item.get("modifiedTime", item.get("createdTime", ""))
try:
mtime = int(datetime.fromisoformat(mtime_str.replace("Z", "+00:00")).timestamp())
except (ValueError, AttributeError):
mtime = 0
return {
"name": item["name"],
"is_dir": is_dir,
"size": 0 if is_dir else int(item.get("size", 0)),
"mtime": mtime,
"type": "dir" if is_dir else "file",
}
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
"""
List directory contents.
:param root: root path.
:param rel: relative path.
:param page_num: page number.
:param page_size: page size.
:param sort_by: sort field.
:param sort_order: sort order.
:return: list of file/directory entries and total count.
"""
try:
folder_id = await self._get_folder_id_by_path(rel)
except FileNotFoundError:
return [], 0
query = f"'{folder_id}' in parents and trashed=false"
params = {
"q": query,
"fields": "files(id, name, mimeType, size, modifiedTime, createdTime)",
"pageSize": 1000,
}
all_items = []
page_token = None
while True:
if page_token:
params["pageToken"] = page_token
resp = await self._request("GET", "/files", params=params)
if resp.status_code == 404:
return [], 0
resp.raise_for_status()
data = resp.json()
all_items.extend(data.get("files", []))
page_token = data.get("nextPageToken")
if not page_token:
break
formatted_items = [self._format_item(item) for item in all_items]
# sort: directories first, then by the requested field
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item["is_dir"],)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item["size"],)
elif sort_field == "mtime":
key += (item["mtime"],)
else:
key += (item["name"].lower(),)
return key
formatted_items.sort(key=get_sort_key, reverse=reverse)
total_count = len(formatted_items)
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
return formatted_items[start_idx:end_idx], total_count
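The directories-first sort above can be exercised in isolation. A sketch mirroring `list_dir`'s key (hypothetical `sort_entries` helper; note that, as in the adapter, `reverse=True` inverts the whole tuple, so descending order also puts files before directories):

```python
def sort_entries(entries, sort_by="name", sort_order="asc"):
    """Sort entries directories-first, then by the requested field."""
    reverse = sort_order.lower() == "desc"
    field = sort_by.lower()

    def key(item):
        # primary key: False (dir) sorts before True (file)
        primary = (not item["is_dir"],)
        if field == "size":
            return primary + (item["size"],)
        if field == "mtime":
            return primary + (item["mtime"],)
        # "name" and any unknown field fall back to case-insensitive name
        return primary + (item["name"].lower(),)

    return sorted(entries, key=key, reverse=reverse)
```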
async def read_file(self, root: str, rel: str) -> bytes:
"""
Read file contents.
:param root: root path.
:param rel: relative path.
:return: file contents as bytes.
"""
file_id = await self._get_file_id_by_path(rel)
if not file_id:
raise FileNotFoundError(rel)
resp = await self._request("GET", f"/files/{file_id}", params={"alt": "media"})
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
return resp.content
async def write_file(self, root: str, rel: str, data: bytes):
"""
Write a file.
:param root: root path.
:param rel: relative path.
:param data: file contents as bytes.
"""
parent_path = "/".join(rel.strip("/").split("/")[:-1])
file_name = rel.strip("/").split("/")[-1]
parent_id = await self._get_folder_id_by_path(parent_path)
# check whether the file already exists
existing_id = await self._get_file_id_by_path(rel)
if existing_id:
# update the existing file
async with httpx.AsyncClient(timeout=60.0) as client:
token = await self._get_access_token()
headers = {"Authorization": f"Bearer {token}"}
url = f"https://www.googleapis.com/upload/drive/v3/files/{existing_id}?uploadType=media"
resp = await client.patch(url, headers=headers, content=data)
resp.raise_for_status()
else:
# create a new file
metadata = {
"name": file_name,
"parents": [parent_id]
}
async with httpx.AsyncClient(timeout=60.0) as client:
token = await self._get_access_token()
headers = {"Authorization": f"Bearer {token}"}
# multipart upload: JSON metadata part followed by the raw content part
import json
boundary = "===============boundary==============="
headers["Content-Type"] = f"multipart/related; boundary={boundary}"
body = (
f"--{boundary}\r\n"
f"Content-Type: application/json; charset=UTF-8\r\n\r\n"
f"{json.dumps(metadata)}\r\n"
f"--{boundary}\r\n"
f"Content-Type: application/octet-stream\r\n\r\n"
).encode() + data + f"\r\n--{boundary}--".encode()
url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=multipart"
resp = await client.post(url, headers=headers, content=body)
resp.raise_for_status()
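The create path assembles a multipart/related body by hand. That assembly can be factored into a small pure function (hypothetical `build_multipart_related` helper, mirroring the body format above):

```python
import json

def build_multipart_related(metadata: dict, data: bytes,
                            boundary: str = "===============boundary===============") -> bytes:
    """Build a multipart/related body: a JSON metadata part, then a raw content part."""
    head = (
        f"--{boundary}\r\n"
        "Content-Type: application/json; charset=UTF-8\r\n\r\n"
        f"{json.dumps(metadata)}\r\n"
        f"--{boundary}\r\n"
        "Content-Type: application/octet-stream\r\n\r\n"
    ).encode()
    return head + data + f"\r\n--{boundary}--".encode()
```

The matching request header would be `Content-Type: multipart/related; boundary=<boundary>`.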
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
"""
Write a file from a stream.
:param root: root path.
:param rel: relative path.
:param data_iter: async iterator over file chunks.
:return: file size.
"""
# buffer all chunks first, then upload in a single request
chunks = []
total_size = 0
async for chunk in data_iter:
chunks.append(chunk)
total_size += len(chunk)
data = b"".join(chunks)
await self.write_file(root, rel, data)
return total_size
async def mkdir(self, root: str, rel: str):
"""
Create a directory.
:param root: root path.
:param rel: relative path.
"""
parent_path = "/".join(rel.strip("/").split("/")[:-1])
folder_name = rel.strip("/").split("/")[-1]
parent_id = await self._get_folder_id_by_path(parent_path)
metadata = {
"name": folder_name,
"mimeType": "application/vnd.google-apps.folder",
"parents": [parent_id]
}
resp = await self._request("POST", "/files", json=metadata)
resp.raise_for_status()
async def delete(self, root: str, rel: str):
"""
Delete a file or directory.
:param root: root path.
:param rel: relative path.
"""
file_id = await self._get_file_id_by_path(rel)
if not file_id:
return
resp = await self._request("DELETE", f"/files/{file_id}")
if resp.status_code not in (204, 404):
resp.raise_for_status()
async def move(self, root: str, src_rel: str, dst_rel: str):
"""
Move or rename a file/directory.
:param root: root path.
:param src_rel: source relative path.
:param dst_rel: destination relative path.
"""
file_id = await self._get_file_id_by_path(src_rel)
if not file_id:
raise FileNotFoundError(src_rel)
# fetch the current parent folder(s)
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "parents"})
resp.raise_for_status()
current_parents = resp.json().get("parents", [])
# resolve the destination parent folder and the new name
dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
dst_name = dst_rel.strip("/").split("/")[-1]
dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)
# update the file; omit removeParents when there is nothing to remove
params = {"addParents": dst_parent_id}
if current_parents:
params["removeParents"] = ",".join(current_parents)
metadata = {"name": dst_name}
resp = await self._request("PATCH", f"/files/{file_id}", params=params, json=metadata)
resp.raise_for_status()
async def rename(self, root: str, src_rel: str, dst_rel: str):
"""
Rename a file or directory.
"""
await self.move(root, src_rel, dst_rel)
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
"""
Copy a file or directory.
:param root: root path.
:param src_rel: source relative path.
:param dst_rel: destination relative path.
:param overwrite: whether to overwrite.
"""
file_id = await self._get_file_id_by_path(src_rel)
if not file_id:
raise FileNotFoundError(src_rel)
dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
dst_name = dst_rel.strip("/").split("/")[-1]
dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)
metadata = {
"name": dst_name,
"parents": [dst_parent_id]
}
resp = await self._request("POST", f"/files/{file_id}/copy", json=metadata)
resp.raise_for_status()
async def stream_file(self, root: str, rel: str, range_header: str | None):
"""
Stream a file (supports HTTP range requests).
:param root: root path.
:param rel: relative path.
:param range_header: HTTP Range header.
:return: FastAPI StreamingResponse object.
"""
file_id = await self._get_file_id_by_path(rel)
if not file_id:
raise FileNotFoundError(rel)
# fetch file metadata
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "name, size, mimeType"})
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
item_data = resp.json()
file_size = int(item_data.get("size", 0))
content_type = item_data.get("mimeType", "application/octet-stream")
start = 0
end = file_size - 1
status = 200
headers = {
"Accept-Ranges": "bytes",
"Content-Type": content_type,
"Content-Disposition": f"inline; filename=\"{item_data.get('name')}\""
}
if range_header and range_header.startswith("bytes="):
try:
part = range_header.removeprefix("bytes=")
s, e = part.split("-", 1)
if s.strip():
start = int(s)
if e.strip():
end = int(e)
if start >= file_size:
raise HTTPException(416, "Requested Range Not Satisfiable")
if end >= file_size:
end = file_size - 1
status = 206
except ValueError:
raise HTTPException(400, "Invalid Range header")
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
headers["Content-Length"] = str(end - start + 1)
else:
headers["Content-Length"] = str(file_size)
async def file_iterator():
nonlocal start, end
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=60.0) as client:
req_headers = {
'Authorization': f'Bearer {token}',
'Range': f'bytes={start}-{end}'
}
url = f"{GOOGLE_DRIVE_API_URL}/files/{file_id}?alt=media"
async with client.stream("GET", url, headers=req_headers) as stream_resp:
stream_resp.raise_for_status()
async for chunk in stream_resp.aiter_bytes():
yield chunk
return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)
async def stat_file(self, root: str, rel: str):
"""
Get metadata for a file or directory.
:param root: root path.
:param rel: relative path.
:return: normalized file/directory info.
"""
file_id = await self._get_file_id_by_path(rel)
if not file_id:
raise FileNotFoundError(rel)
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "id, name, mimeType, size, modifiedTime, createdTime"})
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
return self._format_item(resp.json())
async def get_direct_download_response(self, root: str, rel: str):
"""
Build a direct-download response (307 redirect).
:param root: root path.
:param rel: relative path.
:return: 307 redirect response, or None.
"""
if not self.enable_redirect_307:
return None
file_id = await self._get_file_id_by_path(rel)
if not file_id:
raise FileNotFoundError(rel)
# fetch the file's download link
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "webContentLink"})
if resp.status_code == 404:
raise FileNotFoundError(rel)
resp.raise_for_status()
item_data = resp.json()
download_url = item_data.get("webContentLink")
if not download_url:
return None
return Response(status_code=307, headers={"Location": download_url})
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
"""
Fetch a file's thumbnail.
:param root: root path.
:param rel: relative path.
:param size: thumbnail size (currently unused; Google Drive chooses automatically).
:return: thumbnail bytes, or None when unsupported.
"""
file_id = await self._get_file_id_by_path(rel)
if not file_id:
return None
try:
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "thumbnailLink"})
if resp.status_code == 200:
item_data = resp.json()
thumbnail_link = item_data.get("thumbnailLink")
if thumbnail_link:
async with httpx.AsyncClient(timeout=30.0) as client:
thumb_resp = await client.get(thumbnail_link)
thumb_resp.raise_for_status()
return thumb_resp.content
return None
except Exception:
return None
ADAPTER_TYPE = "googledrive"
CONFIG_SCHEMA = [
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
{"key": "client_secret", "label": "Client Secret",
"type": "password", "required": True},
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
"required": True, "help_text": "可以通过 Google OAuth 2.0 Playground 获取"},
{"key": "root_folder_id", "label": "根文件夹 ID (Root Folder ID)", "type": "string",
"required": False, "placeholder": "默认为根目录 (root)", "default": "root"},
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
]
def ADAPTER_FACTORY(rec): return GoogleDriveAdapter(rec)
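The Range handling in `stream_file` reduces to a small pure function. A sketch with the same semantics as the adapter code (hypothetical `parse_range` helper; like the code above, it treats a suffix range such as `bytes=-500` as `0-500` rather than the RFC 7233 "last 500 bytes" meaning):

```python
def parse_range(range_header, file_size):
    """Resolve an HTTP Range header to (start, end, status); raises ValueError on bad input."""
    start, end, status = 0, file_size - 1, 200
    if range_header and range_header.startswith("bytes="):
        s, e = range_header.removeprefix("bytes=").split("-", 1)
        if s.strip():
            start = int(s)
        if e.strip():
            end = int(e)
        if start >= file_size:
            raise ValueError("requested range not satisfiable")  # adapter maps this to 416
        end = min(end, file_size - 1)  # clamp open-ended / oversized ranges
        status = 206
    return start, end, status
```

The adapter then emits `Content-Range: bytes {start}-{end}/{file_size}` and `Content-Length: end - start + 1` for 206 responses.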


@@ -10,7 +10,6 @@ import mimetypes
from fastapi import HTTPException
from fastapi.responses import StreamingResponse, Response
from models import StorageAdapter
from services.logging import LogService
def _safe_join(root: str, rel: str) -> Path:
@@ -115,11 +114,6 @@ class LocalAdapter:
await asyncio.to_thread(fp.write_bytes, data)
if not pre_exists:
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
await LogService.info(
"adapter:local",
f"Wrote file to {rel}",
details={"adapter_id": self.record.id, "path": str(fp), "size": len(data)},
)
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
fp = _safe_join(root, rel)
@@ -140,21 +134,11 @@ class LocalAdapter:
await asyncio.to_thread(f.close)
if not pre_exists:
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
await LogService.info(
"adapter:local",
f"Wrote file stream to {rel}",
details={"adapter_id": self.record.id, "path": str(fp), "size": size},
)
return size
async def mkdir(self, root: str, rel: str):
fp = _safe_join(root, rel)
await asyncio.to_thread(os.makedirs, fp, mode=DEFAULT_DIR_MODE, exist_ok=True)
await LogService.info(
"adapter:local",
f"Created directory {rel}",
details={"adapter_id": self.record.id, "path": str(fp)},
)
async def delete(self, root: str, rel: str):
fp = _safe_join(root, rel)
@@ -164,11 +148,6 @@ class LocalAdapter:
await asyncio.to_thread(shutil.rmtree, fp)
else:
await asyncio.to_thread(fp.unlink)
await LogService.info(
"adapter:local",
f"Deleted {rel}",
details={"adapter_id": self.record.id, "path": str(fp)},
)
async def stat_path(self, root: str, rel: str):
"""新增: 返回路径状态调试信息"""
@@ -203,15 +182,6 @@ class LocalAdapter:
except OSError:
shutil.move(str(src), str(dst))
await asyncio.to_thread(_do_move)
await LogService.info(
"adapter:local",
f"Moved {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src": str(src),
"dst": str(dst),
},
)
async def rename(self, root: str, src_rel: str, dst_rel: str):
src = _safe_join(root, src_rel)
@@ -227,15 +197,6 @@ class LocalAdapter:
except OSError:
os.replace(src, dst)
await asyncio.to_thread(_do_rename)
await LogService.info(
"adapter:local",
f"Renamed {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src": str(src),
"dst": str(dst),
},
)
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
src = _safe_join(root, src_rel)
@@ -258,15 +219,6 @@ class LocalAdapter:
else:
shutil.copy2(src, dst)
await asyncio.to_thread(_do)
await LogService.info(
"adapter:local",
f"Copied {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src": str(src),
"dst": str(dst),
},
)
async def stream_file(self, root: str, rel: str, range_header: str | None):
fp = _safe_join(root, rel)


@@ -445,14 +445,14 @@ class OneDriveAdapter:
return self._format_item(resp.json())
ADAPTER_TYPE = "OneDrive"
ADAPTER_TYPE = "onedrive"
CONFIG_SCHEMA = [
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
{"key": "client_secret", "label": "Client Secret",
"type": "password", "required": True},
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
"required": True, "help_text": "可以通过运行 'python -m services.adapters.onedrive' 获取"},
"required": True, "help_text": "可以通过运行 'python -m domain.adapters.providers.onedrive' 获取"},
{"key": "root", "label": "根目录 (Root Path)", "type": "string",
"required": False, "placeholder": "默认为根目录 /"},
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},


@@ -718,7 +718,7 @@ class QuarkAdapter:
return it["fid"]
ADAPTER_TYPE = "Quark"
ADAPTER_TYPE = "quark"
CONFIG_SCHEMA = [
{"key": "cookie", "label": "Cookie", "type": "password", "required": True, "placeholder": "从 pan.quark.cn 复制"},


@@ -10,7 +10,6 @@ from botocore.exceptions import ClientError
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from models import StorageAdapter
from services.logging import LogService
class S3Adapter:
@@ -127,11 +126,6 @@ class S3Adapter:
key = self._get_s3_key(rel)
async with self._get_client() as s3:
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)
await LogService.info(
"adapter:s3", f"Wrote file to {rel}",
details={"adapter_id": self.record.id,
"bucket": self.bucket_name, "key": key, "size": len(data)}
)
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
key = self._get_s3_key(rel)
@@ -193,10 +187,6 @@ class S3Adapter:
)
raise IOError(f"S3 stream upload failed: {e}") from e
await LogService.info(
"adapter:s3", f"Wrote file stream to {rel}",
details={"adapter_id": self.record.id, "bucket": self.bucket_name, "key": key, "size": total_size}
)
return total_size
async def mkdir(self, root: str, rel: str):
@@ -205,11 +195,6 @@ class S3Adapter:
key += "/"
async with self._get_client() as s3:
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"")
await LogService.info(
"adapter:s3", f"Created directory {rel}",
details={"adapter_id": self.record.id,
"bucket": self.bucket_name, "key": key}
)
async def delete(self, root: str, rel: str):
key = self._get_s3_key(rel)
@@ -237,20 +222,9 @@ class S3Adapter:
else:
await s3.delete_object(Bucket=self.bucket_name, Key=key)
await LogService.info(
"adapter:s3", f"Deleted {rel}",
details={"adapter_id": self.record.id,
"bucket": self.bucket_name, "key": key}
)
async def move(self, root: str, src_rel: str, dst_rel: str):
await self.copy(root, src_rel, dst_rel, overwrite=True)
await self.delete(root, src_rel)
await LogService.info(
"adapter:s3", f"Moved {src_rel} to {dst_rel}",
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
"src_key": self._get_s3_key(src_rel), "dst_key": self._get_s3_key(dst_rel)}
)
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
@@ -270,11 +244,6 @@ class S3Adapter:
copy_source = {"Bucket": self.bucket_name, "Key": src_key}
await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key)
await LogService.info(
"adapter:s3", f"Copied {src_rel} to {dst_rel}",
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
"src_key": src_key, "dst_key": dst_key}
)
async def stat_file(self, root: str, rel: str):
key = self._get_s3_key(rel)
@@ -353,13 +322,12 @@ class S3Adapter:
while chunk := await body.read(65536):
yield chunk
except Exception as e:
LogService.error(
"adapter:s3", f"Error streaming file {key}: {e}")
raise
return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)
ADAPTER_TYPE = "S3"
ADAPTER_TYPE = "s3"
CONFIG_SCHEMA = [
{"key": "bucket_name", "label": "Bucket 名称",


@@ -10,7 +10,6 @@ from fastapi.responses import StreamingResponse
import paramiko
from models import StorageAdapter
from services.logging import LogService
def _join_remote(root: str, rel: str) -> str:
@@ -159,7 +158,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_write)
await LogService.info("adapter:sftp", f"Wrote file to {rel}", details={"adapter_id": self.record.id, "path": path, "size": len(data)})
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
buf = bytearray()
@@ -190,7 +188,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_mkdir)
await LogService.info("adapter:sftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})
async def delete(self, root: str, rel: str):
path = _join_remote(root, rel)
@@ -228,7 +225,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_delete)
await LogService.info("adapter:sftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})
async def move(self, root: str, src_rel: str, dst_rel: str):
src = _join_remote(root, src_rel)
@@ -255,7 +251,6 @@ class SFTPAdapter:
pass
await asyncio.to_thread(_do_move)
await LogService.info("adapter:sftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
@@ -283,7 +278,6 @@ class SFTPAdapter:
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
await self.copy(root, child_src, child_dst, overwrite)
await LogService.info("adapter:sftp", f"Copied directory {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
return
# file copy
@@ -295,7 +289,6 @@ class SFTPAdapter:
except FileNotFoundError:
pass
await self.write_file(root, dst_rel, data)
await LogService.info("adapter:sftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
async def stat_file(self, root: str, rel: str):
path = _join_remote(root, rel)


@@ -8,7 +8,7 @@ from telethon.sessions import StringSession
import socks
# adapter type identifier
ADAPTER_TYPE = "Telegram"
ADAPTER_TYPE = "telegram"
# adapter config schema definition
CONFIG_SCHEMA = [


@@ -9,7 +9,6 @@ import mimetypes
import logging
from fastapi import HTTPException
from fastapi.responses import StreamingResponse, Response
from services.logging import LogService
NS = {"d": "DAV:"}
@@ -148,15 +147,6 @@ class WebDAVAdapter:
async with self._client() as client:
resp = await client.put(url, content=data)
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Wrote file to {rel}",
details={
"adapter_id": self.record.id,
"url": url,
"size": len(data),
},
)
async def mkdir(self, root: str, rel: str):
url = self._build_url(rel.rstrip('/') + '/')
@@ -164,11 +154,6 @@ class WebDAVAdapter:
resp = await client.request("MKCOL", url)
if resp.status_code not in (201, 405):
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Created directory {rel}",
details={"adapter_id": self.record.id, "url": url},
)
async def delete(self, root: str, rel: str):
url = self._build_url(rel)
@@ -176,11 +161,6 @@ class WebDAVAdapter:
resp = await client.delete(url)
if resp.status_code not in (204, 200, 404):
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Deleted {rel}",
details={"adapter_id": self.record.id, "url": url},
)
async def move(self, root: str, src_rel: str, dst_rel: str):
src_url = self._build_url(src_rel)
@@ -188,15 +168,6 @@ class WebDAVAdapter:
async with self._client() as client:
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Moved {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src_url": src_url,
"dst_url": dst_url,
},
)
async def rename(self, root: str, src_rel: str, dst_rel: str):
src_url = self._build_url(src_rel)
@@ -204,15 +175,6 @@ class WebDAVAdapter:
async with self._client() as client:
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Renamed {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src_url": src_url,
"dst_url": dst_url,
},
)
async def get_file_size(self, root: str, rel: str) -> int:
"""获取文件大小"""
@@ -455,8 +417,16 @@ class WebDAVAdapter:
info["type"] = "dir" if is_dir else "file"
if size_el is not None and size_el.text and size_el.text.isdigit():
info["size"] = int(size_el.text)
elif info["size"] is None:
info["size"] = 0
if lm_el is not None and lm_el.text:
info["mtime"] = lm_el.text
from email.utils import parsedate_to_datetime
try:
info["mtime"] = int(parsedate_to_datetime(lm_el.text).timestamp())
except Exception:
info["mtime"] = 0
elif info["mtime"] is None:
info["mtime"] = 0
# EXIF info
exif = None
if not info["is_dir"]:
@@ -510,15 +480,6 @@ class WebDAVAdapter:
if resp.status_code == 404:
raise FileNotFoundError(src_rel)
resp.raise_for_status()
await LogService.info(
"adapter:webdav",
f"Copied {src_rel} to {dst_rel}",
details={
"adapter_id": self.record.id,
"src_url": src_url,
"dst_url": dst_url,
},
)
ADAPTER_TYPE = "webdav"
CONFIG_SCHEMA = [


@@ -1,20 +1,28 @@
from typing import Dict, Callable
import pkgutil
import inspect
import pkgutil
from importlib import import_module
from typing import Callable, Dict
from .base import BaseAdapter
from models import StorageAdapter
from domain.adapters.providers.base import BaseAdapter
AdapterFactory = Callable[[StorageAdapter], object]
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
TYPE_MAP: Dict[str, AdapterFactory] = {}
CONFIG_SCHEMAS: Dict[str, list] = {}
def normalize_adapter_type(value: str | None) -> str | None:
if value is None:
return None
normalized = str(value).strip().lower()
return normalized or None
def discover_adapters():
"""扫描 services.adapters 包, 自动注册适配器类型、工厂与配置 schema。"""
from .. import adapters as adapters_pkg
"""扫描 domain.adapters.providers 包, 自动注册适配器类型、工厂与配置 schema。"""
from domain.adapters import providers as adapters_pkg
TYPE_MAP.clear()
CONFIG_SCHEMAS.clear()
for modinfo in pkgutil.iter_modules(adapters_pkg.__path__):
@@ -25,7 +33,7 @@ def discover_adapters():
module = import_module(full_name)
except Exception:
continue
adapter_type = getattr(module, "ADAPTER_TYPE", None)
adapter_type = normalize_adapter_type(getattr(module, "ADAPTER_TYPE", None))
schema = getattr(module, "CONFIG_SCHEMA", None)
factory = getattr(module, "ADAPTER_FACTORY", None)
@@ -57,22 +65,31 @@ def get_config_schema(adapter_type: str):
class RuntimeRegistry:
def __init__(self):
self._instances: Dict[int, object] = {}
self._instances: Dict[int, BaseAdapter] = {}
async def refresh(self):
discover_adapters()
self._instances.clear()
adapters = await StorageAdapter.filter(enabled=True)
for rec in adapters:
factory = TYPE_MAP.get(rec.type)
normalized_type = normalize_adapter_type(rec.type)
if not normalized_type:
continue
if normalized_type != rec.type:
rec.type = normalized_type
try:
await rec.save(update_fields=["type"])
except Exception:
continue
factory = TYPE_MAP.get(normalized_type)
if not factory:
continue
try:
self._instances[rec.id] = factory(rec)
except Exception:
continue
continue
def get(self, adapter_id: int):
def get(self, adapter_id: int) -> BaseAdapter | None:
return self._instances.get(adapter_id)
def snapshot(self) -> Dict[int, BaseAdapter]:
@@ -88,11 +105,22 @@ class RuntimeRegistry:
if not rec.enabled:
self.remove(rec.id)
return
factory = TYPE_MAP.get(rec.type)
normalized_type = normalize_adapter_type(rec.type)
if not normalized_type:
self.remove(rec.id)
return
if normalized_type != rec.type:
rec.type = normalized_type
try:
await rec.save(update_fields=["type"])
except Exception:
pass
factory = TYPE_MAP.get(normalized_type)
if not factory:
discover_adapters()
factory = TYPE_MAP.get(rec.type)
factory = TYPE_MAP.get(normalized_type)
if not factory:
return
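The registry diff above funnels every stored adapter type through the same helper before any `TYPE_MAP` lookup or persistence. A minimal standalone sketch of that normalization (the function copied as shown, exercised outside the registry):

```python
def normalize_adapter_type(value):
    """Lower-case and trim an adapter type; treat blank values as missing."""
    if value is None:
        return None
    normalized = str(value).strip().lower()
    return normalized or None

# Legacy rows such as "WebDAV" or " S3 " now resolve to the same TYPE_MAP key.
print(normalize_adapter_type(" WebDAV "))  # -> webdav
print(normalize_adapter_type(""))          # -> None
```

Because the registry also writes the normalized value back via `rec.save(update_fields=["type"])`, mixed-case rows self-heal on the next refresh.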

domain/adapters/service.py (new file)

@@ -0,0 +1,111 @@
from typing import Optional
from fastapi import HTTPException
from domain.adapters.registry import (
get_config_schemas,
normalize_adapter_type,
runtime_registry,
)
from domain.adapters.types import AdapterCreate, AdapterOut
from domain.auth.types import User
from models import StorageAdapter
class AdapterService:
@classmethod
def _validate_and_normalize_config(cls, adapter_type: str, cfg):
schemas = get_config_schemas()
adapter_type = normalize_adapter_type(adapter_type)
if not adapter_type:
raise HTTPException(400, detail="不支持的适配器类型")
if not isinstance(cfg, dict):
raise HTTPException(400, detail="config 必须是对象")
schema = schemas.get(adapter_type)
if not schema:
raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
out = {}
missing = []
for f in schema:
k = f["key"]
if k in cfg and cfg[k] not in (None, ""):
out[k] = cfg[k]
elif "default" in f:
out[k] = f["default"]
elif f.get("required"):
missing.append(k)
if missing:
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
return out
@classmethod
async def create_adapter(cls, data: AdapterCreate, current_user: Optional[User]):
norm_path = AdapterCreate.normalize_mount_path(data.path)
exists = await StorageAdapter.get_or_none(path=norm_path)
if exists:
raise HTTPException(400, detail="Mount path already exists")
adapter_fields = {
"name": data.name,
"type": data.type,
"config": cls._validate_and_normalize_config(data.type, data.config or {}),
"enabled": data.enabled,
"path": norm_path,
"sub_path": data.sub_path,
}
rec = await StorageAdapter.create(**adapter_fields)
await runtime_registry.upsert(rec)
return AdapterOut.model_validate(rec)
@classmethod
async def list_adapters(cls):
adapters = await StorageAdapter.all()
return [AdapterOut.model_validate(a) for a in adapters]
@classmethod
async def available_adapter_types(cls):
data = []
for adapter_type, fields in get_config_schemas().items():
data.append({
"type": adapter_type,
"config_schema": fields,
})
return data
@classmethod
async def get_adapter(cls, adapter_id: int):
rec = await StorageAdapter.get_or_none(id=adapter_id)
if not rec:
raise HTTPException(404, detail="Not found")
return AdapterOut.model_validate(rec)
@classmethod
async def update_adapter(cls, adapter_id: int, data: AdapterCreate, current_user: Optional[User]):
rec = await StorageAdapter.get_or_none(id=adapter_id)
if not rec:
raise HTTPException(404, detail="Not found")
norm_path = AdapterCreate.normalize_mount_path(data.path)
existing = await StorageAdapter.get_or_none(path=norm_path)
if existing and existing.id != adapter_id:
raise HTTPException(400, detail="Mount path already exists")
rec.name = data.name
rec.type = data.type
rec.config = cls._validate_and_normalize_config(data.type, data.config or {})
rec.enabled = data.enabled
rec.path = norm_path
rec.sub_path = data.sub_path
await rec.save()
await runtime_registry.upsert(rec)
return AdapterOut.model_validate(rec)
@classmethod
async def delete_adapter(cls, adapter_id: int, current_user: Optional[User]):
deleted = await StorageAdapter.filter(id=adapter_id).delete()
if not deleted:
raise HTTPException(404, detail="Not found")
runtime_registry.remove(adapter_id)
return {"deleted": True}


@@ -1,15 +1,29 @@
import re
from typing import Dict, Optional
from pydantic import BaseModel, Field, field_validator
class AdapterBase(BaseModel):
name: str
type: str = Field(pattern=r"^[a-zA-Z0-9_]+$")
type: str = Field(pattern=r"^[a-z0-9_]+$")
config: Dict = Field(default_factory=dict)
enabled: bool = True
path: Optional[str] = None
sub_path: Optional[str] = None
@field_validator("type", mode="before")
@classmethod
def _normalize_type(cls, v: str):
if not isinstance(v, str):
raise ValueError("type required")
normalized = v.strip().lower()
if not normalized:
raise ValueError("type required")
if not re.fullmatch(r"[a-z0-9_]+", normalized):
raise ValueError("type must be lowercase alphanumeric or underscore")
return normalized
class AdapterCreate(AdapterBase):
@staticmethod

domain/ai/__init__.py (new file)

@@ -0,0 +1,34 @@
from .api import router_ai, router_vector_db
from .service import (
AIProviderService,
VectorDBConfigManager,
VectorDBService,
DEFAULT_VECTOR_DIMENSION,
ABILITIES,
normalize_capabilities,
)
from .types import (
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
AIProviderCreate,
AIProviderUpdate,
VectorDBConfigPayload,
)
__all__ = [
"router_ai",
"router_vector_db",
"AIProviderService",
"VectorDBService",
"VectorDBConfigManager",
"DEFAULT_VECTOR_DIMENSION",
"ABILITIES",
"normalize_capabilities",
"AIDefaultsUpdate",
"AIModelCreate",
"AIModelUpdate",
"AIProviderCreate",
"AIProviderUpdate",
"VectorDBConfigPayload",
]

domain/ai/api.py (new file)

@@ -0,0 +1,305 @@
from typing import Annotated, Dict, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Path, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
from domain.ai.types import (
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
AIProviderCreate,
AIProviderUpdate,
VectorDBConfigPayload,
)
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
from domain.auth.service import get_current_active_user
from domain.auth.types import User
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
@audit(action=AuditAction.READ, description="获取 AI 提供商列表")
@router_ai.get("/providers")
async def list_providers_endpoint(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)]
):
providers = await AIProviderService.list_providers()
return success({"providers": providers})
@audit(
action=AuditAction.CREATE,
description="创建 AI 提供商",
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
redact_fields=["api_key"],
)
@router_ai.post("/providers")
async def create_provider(
request: Request,
payload: AIProviderCreate,
current_user: Annotated[User, Depends(get_current_active_user)]
):
provider = await AIProviderService.create_provider(payload.dict())
return success(provider)
@audit(action=AuditAction.READ, description="获取 AI 提供商详情")
@router_ai.get("/providers/{provider_id}")
async def get_provider(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
provider = await AIProviderService.get_provider(provider_id, with_models=True)
return success(provider)
@audit(
action=AuditAction.UPDATE,
description="更新 AI 提供商",
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
redact_fields=["api_key"],
)
@router_ai.put("/providers/{provider_id}")
async def update_provider(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
payload: AIProviderUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = {k: v for k, v in payload.dict().items() if v is not None}
if not data:
raise HTTPException(status_code=400, detail="No fields to update")
provider = await AIProviderService.update_provider(provider_id, data)
return success(provider)
@audit(action=AuditAction.DELETE, description="删除 AI 提供商")
@router_ai.delete("/providers/{provider_id}")
async def delete_provider(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
await AIProviderService.delete_provider(provider_id)
return success({"id": provider_id})
@audit(action=AuditAction.UPDATE, description="同步模型列表")
@router_ai.post("/providers/{provider_id}/sync-models")
async def sync_models(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
try:
result = await AIProviderService.sync_models(provider_id)
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return success(result)
@audit(action=AuditAction.READ, description="获取远程模型列表")
@router_ai.get("/providers/{provider_id}/remote-models")
async def fetch_remote_models(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
try:
models = await AIProviderService.fetch_remote_models(provider_id)
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return success({"models": models})
@audit(action=AuditAction.READ, description="获取模型列表")
@router_ai.get("/providers/{provider_id}/models")
async def list_models(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
models = await AIProviderService.list_models(provider_id)
return success({"models": models})
@audit(
action=AuditAction.CREATE,
description="创建模型",
body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
)
@router_ai.post("/providers/{provider_id}/models")
async def create_model(
request: Request,
provider_id: Annotated[int, Path(..., gt=0)],
payload: AIModelCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
model = await AIProviderService.create_model(provider_id, payload.dict())
return success(model)
@audit(
action=AuditAction.UPDATE,
description="更新模型",
body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
)
@router_ai.put("/models/{model_id}")
async def update_model(
request: Request,
model_id: Annotated[int, Path(..., gt=0)],
payload: AIModelUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = {k: v for k, v in payload.dict().items() if v is not None}
if not data:
raise HTTPException(status_code=400, detail="No fields to update")
model = await AIProviderService.update_model(model_id, data)
return success(model)
@audit(action=AuditAction.DELETE, description="删除模型")
@router_ai.delete("/models/{model_id}")
async def delete_model(
request: Request,
model_id: Annotated[int, Path(..., gt=0)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
await AIProviderService.delete_model(model_id)
return success({"id": model_id})
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
if not entry:
return None
value = entry.get("embedding_dimensions")
return int(value) if value is not None else None
@audit(action=AuditAction.READ, description="获取默认模型")
@router_ai.get("/defaults")
async def get_defaults(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
defaults = await AIProviderService.get_default_models()
return success(defaults)
@audit(
action=AuditAction.UPDATE,
description="更新默认模型",
body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
)
@router_ai.put("/defaults")
async def update_defaults(
request: Request,
payload: AIDefaultsUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
previous = await AIProviderService.get_default_models()
try:
updated = await AIProviderService.set_default_models(payload.as_mapping())
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
prev_dim = _get_embedding_dimension(previous.get("embedding"))
next_dim = _get_embedding_dimension(updated.get("embedding"))
if prev_dim and next_dim and prev_dim != next_dim:
try:
await VectorDBService().clear_all_data()
except Exception as exc: # noqa: BLE001
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
return success(updated)
@audit(action=AuditAction.UPDATE, description="清空向量数据库")
@router_vector_db.post("/clear-all", summary="清空向量数据库")
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
try:
service = VectorDBService()
await service.clear_all_data()
return success(msg="向量数据库已清空")
except Exception as e: # noqa: BLE001
raise HTTPException(status_code=500, detail=str(e))
@audit(action=AuditAction.READ, description="获取向量数据库统计")
@router_vector_db.get("/stats", summary="获取向量数据库统计")
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
try:
service = VectorDBService()
data = await service.get_all_stats()
return success(data=data)
except Exception as e: # noqa: BLE001
raise HTTPException(status_code=500, detail=str(e))
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
return success(list_providers())
@audit(action=AuditAction.READ, description="获取向量数据库配置")
@router_vector_db.get("/config", summary="获取当前向量数据库配置")
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
service = VectorDBService()
data = await service.current_provider()
return success(data)
@audit(action=AuditAction.UPDATE, description="更新向量数据库配置", body_fields=["type"])
@router_vector_db.post("/config", summary="更新向量数据库配置")
async def update_vector_db_config(
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
):
entry = get_provider_entry(payload.type)
if not entry:
raise HTTPException(
status_code=400, detail=f"Unknown vector database type: {payload.type}")
if not entry.get("enabled", True):
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
provider_cls = get_provider_class(payload.type)
if not provider_cls:
raise HTTPException(
status_code=400, detail=f"No implementation found for type {payload.type}")
test_provider = provider_cls(payload.config)
try:
await test_provider.initialize()
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
finally:
client = getattr(test_provider, "client", None)
close_fn = getattr(client, "close", None)
if callable(close_fn):
try:
close_fn()
except Exception:
pass
await VectorDBConfigManager.save_config(payload.type, payload.config)
service = VectorDBService()
await service.reload()
config_data = await service.current_provider()
stats = await service.get_all_stats()
return success({"config": config_data, "stats": stats})
__all__ = ["router_ai", "router_vector_db"]


@@ -4,10 +4,10 @@ import httpx
from typing import List, Sequence, Tuple
from models.database import AIModel, AIProvider
from services.ai_providers import AIProviderService
from domain.ai.service import AIProviderService
provider_service = AIProviderService()
provider_service = AIProviderService
class MissingModelError(RuntimeError):


@@ -1,5 +1,7 @@
from __future__ import annotations
import asyncio
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple
@@ -7,10 +9,18 @@ import httpx
from tortoise.exceptions import DoesNotExist
from tortoise.transactions import in_transaction
from domain.config.service import ConfigService
from models.database import AIDefaultModel, AIModel, AIProvider
from .types import ABILITIES, normalize_capabilities
from .vector_providers import (
BaseVectorProvider,
get_provider_class,
get_provider_entry,
list_providers,
)
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
DEFAULT_VECTOR_DIMENSION = 4096
OPENAI_EMBEDDING_DIMS = {
"text-embedding-3-large": 3072,
@@ -19,6 +29,43 @@ OPENAI_EMBEDDING_DIMS = {
}
class VectorDBConfigManager:
TYPE_KEY = "VECTOR_DB_TYPE"
CONFIG_KEY = "VECTOR_DB_CONFIG"
DEFAULT_TYPE = "milvus_lite"
@classmethod
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
provider_type = str(raw_type or cls.DEFAULT_TYPE)
raw_config = await ConfigService.get(cls.CONFIG_KEY)
config_dict: Dict[str, Any] = {}
if isinstance(raw_config, str) and raw_config:
try:
config_dict = json.loads(raw_config)
except json.JSONDecodeError:
config_dict = {}
elif isinstance(raw_config, dict):
config_dict = raw_config
return provider_type, config_dict
@classmethod
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
await ConfigService.set(cls.TYPE_KEY, provider_type)
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
@classmethod
async def get_type(cls) -> str:
provider_type, _ = await cls.load_config()
return provider_type
@classmethod
async def get_config(cls) -> Dict[str, Any]:
_, config = await cls.load_config()
return config
def _normalize_embedding_dim(value: Any) -> Optional[int]:
if value is None:
return None
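`VectorDBConfigManager.load_config` above tolerates the stored config being either a JSON string (the persisted form) or an already-decoded dict, falling back to an empty config on bad JSON. That coercion in isolation:

```python
import json

def coerce_config(raw):
    """Accept a JSON string, a dict, or anything else (-> empty config)."""
    if isinstance(raw, str) and raw:
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}
    if isinstance(raw, dict):
        return raw
    return {}

print(coerce_config('{"uri": "./milvus.db"}'))  # -> {'uri': './milvus.db'}
print(coerce_config("not json"))                # -> {}
print(coerce_config(None))                      # -> {}
```

Swallowing the decode error means a corrupted stored config silently resets to defaults rather than failing startup, which matches the `DEFAULT_TYPE = "milvus_lite"` fallback a few lines up.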
@@ -47,17 +94,6 @@ def _apply_embedding_dim_to_metadata(
return data
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
if not items:
return []
normalized = []
for cap in items:
key = str(cap).strip().lower()
if key in ABILITIES and key not in normalized:
normalized.append(key)
return normalized
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
lower = model_id.lower()
caps = set()
@@ -139,40 +175,46 @@ def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = Non
class AIProviderService:
async def list_providers(self) -> List[Dict[str, Any]]:
@classmethod
async def list_providers(cls) -> List[Dict[str, Any]]:
providers = await AIProvider.all().order_by("id").prefetch_related("models")
return [provider_to_dict(p, models=list(p.models)) for p in providers]
async def get_provider(self, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
@classmethod
async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
if with_models:
provider = await AIProvider.get(id=provider_id)
models = await provider.models.all()
return provider_to_dict(provider, models=models)
else:
provider = await AIProvider.get(id=provider_id)
return provider_to_dict(provider)
provider = await AIProvider.get(id=provider_id)
return provider_to_dict(provider)
async def create_provider(self, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
data = payload.copy()
data.setdefault("extra_config", {})
provider = await AIProvider.create(**data)
return provider_to_dict(provider)
async def update_provider(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
provider = await AIProvider.get(id=provider_id)
for field, value in payload.items():
setattr(provider, field, value)
await provider.save()
return provider_to_dict(provider)
async def delete_provider(self, provider_id: int) -> None:
@classmethod
async def delete_provider(cls, provider_id: int) -> None:
await AIProvider.filter(id=provider_id).delete()
async def list_models(self, provider_id: int) -> List[Dict[str, Any]]:
@classmethod
async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
return [model_to_dict(m) for m in models]
async def create_model(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
data = payload.copy()
data["provider_id"] = provider_id
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
@@ -182,7 +224,8 @@ class AIProviderService:
await model.fetch_related("provider")
return model_to_dict(model)
async def update_model(self, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
@classmethod
async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
model = await AIModel.get(id=model_id)
data = payload.copy()
if "capabilities" in data:
@@ -199,14 +242,17 @@ class AIProviderService:
await model.fetch_related("provider")
return model_to_dict(model)
async def delete_model(self, model_id: int) -> None:
@classmethod
async def delete_model(cls, model_id: int) -> None:
await AIModel.filter(id=model_id).delete()
async def fetch_remote_models(self, provider_id: int) -> List[Dict[str, Any]]:
@classmethod
async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
provider = await AIProvider.get(id=provider_id)
return await self._get_remote_models(provider)
return await cls._get_remote_models(provider)
async def _get_remote_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
@classmethod
async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
if not provider.base_url:
raise ValueError("Provider base_url is required for syncing models")
@@ -215,12 +261,13 @@ class AIProviderService:
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
if fmt == "openai":
return await self._fetch_openai_models(provider)
return await self._fetch_gemini_models(provider)
return await cls._fetch_openai_models(provider)
return await cls._fetch_gemini_models(provider)
async def sync_models(self, provider_id: int) -> Dict[str, int]:
@classmethod
async def sync_models(cls, provider_id: int) -> Dict[str, int]:
provider = await AIProvider.get(id=provider_id)
remote_models = await self._get_remote_models(provider)
remote_models = await cls._get_remote_models(provider)
created = 0
updated = 0
@@ -247,14 +294,16 @@ class AIProviderService:
return {"created": created, "updated": updated}
async def get_default_models(self) -> Dict[str, Optional[Dict[str, Any]]]:
@classmethod
async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
for item in defaults:
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
return result
async def set_default_models(self, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
@classmethod
async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
async with in_transaction() as connection:
for ability, model_id in normalized.items():
@@ -271,9 +320,10 @@ class AIProviderService:
await AIDefaultModel.create(ability=ability, model_id=model_id)
elif record:
await record.delete(using_db=connection)
return await self.get_default_models()
return await cls.get_default_models()
async def get_default_model(self, ability: str) -> Optional[AIModel]:
@classmethod
async def get_default_model(cls, ability: str) -> Optional[AIModel]:
ability_key = ability.lower()
if ability_key not in ABILITIES:
return None
@@ -285,7 +335,8 @@ class AIProviderService:
await model.fetch_related("provider")
return model
async def _fetch_openai_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
@classmethod
async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
base_url = provider.base_url.rstrip("/")
url = f"{base_url}/models"
headers = {}
@@ -315,7 +366,8 @@ class AIProviderService:
})
return entries
async def _fetch_gemini_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
@classmethod
async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
base_url = provider.base_url.rstrip("/")
suffix = "/models"
if provider.api_key:
@@ -345,3 +397,105 @@ class AIProviderService:
"metadata": item,
})
return entries
class VectorDBService:
_instance: "VectorDBService" | None = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, "_provider"):
self._provider: Optional[BaseVectorProvider] = None
self._provider_type: Optional[str] = None
self._provider_config: Dict[str, Any] | None = None
self._lock = asyncio.Lock()
async def _ensure_provider(self) -> BaseVectorProvider:
if self._provider is None:
await self.reload()
assert self._provider is not None
return self._provider
async def reload(self) -> BaseVectorProvider:
async with self._lock:
provider_type, provider_config = await VectorDBConfigManager.load_config()
normalized_config = dict(provider_config or {})
if (
self._provider
and self._provider_type == provider_type
and self._provider_config == normalized_config
):
return self._provider
entry = get_provider_entry(provider_type)
if not entry:
raise RuntimeError(f"Unknown vector database provider: {provider_type}")
if not entry.get("enabled", True):
raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")
provider_cls = get_provider_class(provider_type)
if not provider_cls:
raise RuntimeError(f"Provider class not found for '{provider_type}'")
provider = provider_cls(provider_config)
await provider.initialize()
self._provider = provider
self._provider_type = provider_type
self._provider_config = normalized_config
return provider
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
provider = await self._ensure_provider()
provider.ensure_collection(collection_name, vector, dim)
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
provider = await self._ensure_provider()
provider.upsert_vector(collection_name, data)
async def delete_vector(self, collection_name: str, path: str) -> None:
provider = await self._ensure_provider()
provider.delete_vector(collection_name, path)
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
provider = await self._ensure_provider()
return provider.search_vectors(collection_name, query_embedding, top_k)
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
provider = await self._ensure_provider()
return provider.search_by_path(collection_name, query_path, top_k)
async def get_all_stats(self) -> Dict[str, Any]:
provider = await self._ensure_provider()
return provider.get_all_stats()
async def clear_all_data(self) -> None:
provider = await self._ensure_provider()
provider.clear_all_data()
async def current_provider(self) -> Dict[str, Any]:
provider_type, provider_config = await VectorDBConfigManager.load_config()
entry = get_provider_entry(provider_type) or {}
return {
"type": provider_type,
"config": provider_config,
"label": entry.get("label"),
"enabled": entry.get("enabled", True),
}
__all__ = [
"AIProviderService",
"VectorDBService",
"VectorDBConfigManager",
"DEFAULT_VECTOR_DIMENSION",
"list_providers",
"get_provider_entry",
"get_provider_class",
"normalize_capabilities",
"ABILITIES",
]
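`VectorDBService` above combines a `__new__`-based singleton with an `__init__` guard (the `hasattr` check), so repeated construction returns the same object without resetting its cached provider state. A stripped-down sketch of just that pattern:

```python
class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Allocate at most one instance per class.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every Singleton() call; initialize only once.
        if not hasattr(self, "state"):
            self.state = []

a = Singleton()
a.state.append("provider")
b = Singleton()
print(a is b)   # -> True
print(b.state)  # -> ['provider']
```

The guard matters because Python calls `__init__` on every construction even when `__new__` returns an existing instance; without it, each `VectorDBService()` call would drop the loaded provider and lock.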


@@ -1,8 +1,19 @@
from typing import List, Optional
from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field, field_validator
from services.ai_providers import ABILITIES, normalize_capabilities
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
if not items:
return []
normalized: List[str] = []
for cap in items:
key = str(cap).strip().lower()
if key in ABILITIES and key not in normalized:
normalized.append(key)
return normalized
class AIProviderBase(BaseModel):
@@ -16,6 +27,7 @@ class AIProviderBase(BaseModel):
extra_config: Optional[dict] = None
@field_validator("api_format")
@classmethod
def normalize_format(cls, value: str) -> str:
fmt = value.lower()
if fmt not in {"openai", "gemini"}:
@@ -37,6 +49,7 @@ class AIProviderUpdate(BaseModel):
extra_config: Optional[dict] = None
@field_validator("api_format")
@classmethod
def normalize_format(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
@@ -56,6 +69,7 @@ class AIModelBase(BaseModel):
metadata: Optional[dict] = None
@field_validator("capabilities")
@classmethod
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
if items is None:
return None
@@ -79,6 +93,7 @@ class AIModelUpdate(BaseModel):
metadata: Optional[dict] = None
@field_validator("capabilities")
@classmethod
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
if items is None:
return None
@@ -97,5 +112,10 @@ class AIDefaultsUpdate(BaseModel):
voice: Optional[int] = None
tools: Optional[int] = None
def as_mapping(self) -> dict:
def as_mapping(self) -> Dict[str, Optional[int]]:
return {ability: getattr(self, ability) for ability in ABILITIES}
class VectorDBConfigPayload(BaseModel):
type: str = Field(..., description="向量数据库提供者类型")
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
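`normalize_capabilities`, now duplicated into `domain/ai/types.py` to break the import cycle, filters to the known abilities while lower-casing, deduplicating, and preserving input order. Standalone, as defined above:

```python
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]

def normalize_capabilities(items):
    """Lower-case, keep only known abilities, drop duplicates, keep order."""
    if not items:
        return []
    normalized = []
    for cap in items:
        key = str(cap).strip().lower()
        if key in ABILITIES and key not in normalized:
            normalized.append(key)
    return normalized

print(normalize_capabilities(["Chat", "chat", "TOOLS", "telepathy"]))
# -> ['chat', 'tools']
```

Unknown capability names ("telepathy" here) are dropped silently rather than rejected, so stale remote metadata cannot poison a model record.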


@@ -54,3 +54,14 @@ def get_provider_class(provider_type: str) -> Type[BaseVectorProvider] | None:
if not entry:
return None
return entry.get("class") # type: ignore[return-value]
__all__ = [
"BaseVectorProvider",
"MilvusLiteProvider",
"MilvusServerProvider",
"QdrantProvider",
"list_providers",
"get_provider_entry",
"get_provider_class",
]


@@ -155,7 +155,7 @@ class MilvusLiteProvider(BaseVectorProvider):
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
if query_path:
escaped = query_path.replace('"', '\\"')
filter_expr = f'source_path like "%{escaped}%"'
filter_expr = f'source_path like \"%{escaped}%\"'
else:
filter_expr = "source_path like '%%'"
results = self._get_client().query(
@@ -232,7 +232,7 @@ class MilvusLiteProvider(BaseVectorProvider):
for index_name in index_names:
try:
detail = client.describe_index(name, index_name) or {}
detail = client.describe_index(name) or {}
except Exception:
detail = {}
indexes.append(


@@ -162,7 +162,7 @@ class MilvusServerProvider(BaseVectorProvider):
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
if query_path:
escaped = query_path.replace('"', '\\"')
filter_expr = f'source_path like "%{escaped}%"'
filter_expr = f'source_path like \"%{escaped}%\"'
else:
filter_expr = "source_path like '%%'"
results = self._get_client().query(
@@ -239,7 +239,7 @@ class MilvusServerProvider(BaseVectorProvider):
for index_name in index_names:
try:
detail = client.describe_index(name, index_name) or {}
detail = client.describe_index(name) or {}
except Exception:
detail = {}
indexes.append(


@@ -42,7 +42,6 @@ class QdrantProvider(BaseVectorProvider):
api_key = (self.config.get("api_key") or None) or None
try:
client = QdrantClient(url=url, api_key=api_key)
# simple connectivity check
client.get_collections()
self.client = client
except Exception as exc: # pragma: no cover - 依赖外部服务
@@ -70,7 +69,6 @@ class QdrantProvider(BaseVectorProvider):
message = str(exc).lower()
if "already exists" in message or "index exists" in message:
continue
- # older qdrant versions may raise exceptions carrying status codes; tolerate duplicate index creation here
raise
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:

domain/audit/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
from domain.audit.decorator import audit
from domain.audit.types import AuditAction
from domain.audit.api import router
__all__ = ["audit", "AuditAction", "router"]

domain/audit/api.py (new file, 68 lines)

@@ -0,0 +1,68 @@
from datetime import datetime, timezone
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from api import response
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import get_current_active_user
from domain.auth.types import User
CurrentUser = Annotated[User, Depends(get_current_active_user)]
router = APIRouter(prefix="/api/audit", tags=["Audit"])
def _parse_iso(value: Optional[str], field: str):
if not value:
return None
try:
normalized = value.replace("Z", "+00:00")
dt = datetime.fromisoformat(normalized)
if dt.tzinfo:
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
return dt
except ValueError as exc: # noqa: BLE001
raise HTTPException(status_code=400, detail=f"invalid {field}") from exc
@router.get("/logs")
async def list_audit_logs(
current_user: CurrentUser,
page_num: int = Query(1, ge=1, alias="page", description="Page number"),
page_size: int = Query(20, ge=1, le=200, description="Items per page"),
action: AuditAction | None = Query(None, description="Action type"),
success: bool | None = Query(None, description="Whether the request succeeded"),
username: str | None = Query(None, description="Fuzzy match on username"),
path: str | None = Query(None, description="Fuzzy match on path"),
start_time: str | None = Query(None, description="Start time (ISO 8601)"),
end_time: str | None = Query(None, description="End time (ISO 8601)"),
):
start_dt = _parse_iso(start_time, "start_time")
end_dt = _parse_iso(end_time, "end_time")
items, total = await AuditService.list_logs(
page=page_num,
page_size=page_size,
action=str(action) if action else None,
success=success,
username=username,
path=path,
start_time=start_dt,
end_time=end_dt,
)
return response.success(response.page(items, total, page_num, page_size))
@router.delete("/logs")
async def clear_audit_logs(
current_user: CurrentUser,
start_time: str | None = Query(None, description="Start time (ISO 8601)"),
end_time: str | None = Query(None, description="End time (ISO 8601)"),
):
start_dt = _parse_iso(start_time, "start_time")
end_dt = _parse_iso(end_time, "end_time")
if start_dt is None and end_dt is None:
raise HTTPException(status_code=400, detail="Provide at least one of start_time or end_time")
deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
return response.success({"deleted_count": deleted_count})
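The `_parse_iso` helper above accepts a trailing `Z` and converts timezone-aware values to naive UTC before filtering; the same conversion as a standalone sketch:

```python
from datetime import datetime, timezone
from typing import Optional

def parse_iso_utc(value: Optional[str]) -> Optional[datetime]:
    """Parse an ISO 8601 string into a naive UTC datetime (or None)."""
    if not value:
        return None
    # fromisoformat() on Python < 3.11 rejects a literal 'Z' suffix.
    normalized = value.replace("Z", "+00:00")
    dt = datetime.fromisoformat(normalized)
    if dt.tzinfo:
        dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
    return dt
```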

domain/audit/decorator.py (new file, 182 lines)

@@ -0,0 +1,182 @@
import inspect
import time
from functools import wraps
from typing import Any, Dict, Mapping, Optional
import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import ALGORITHM
from domain.config.service import ConfigService
from models.database import UserAccount
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
for value in bound_args.values():
if isinstance(value, Request):
return value
return None
async def _resolve_user(request: Request | None, user_obj: Any | None) -> tuple[Optional[int], Optional[str]]:
user_id: int | None = None
username: str | None = None
if request:
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1]
try:
payload = jwt.decode(token, await ConfigService.get_secret_key("SECRET_KEY"), algorithms=[ALGORITHM])
username = payload.get("sub") or payload.get("username")
if username:
user = await UserAccount.get_or_none(username=username)
user_id = user.id if user else None
except Exception: # includes InvalidTokenError
pass
if user_id is None and username is None and user_obj is not None:
user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
if isinstance(user_obj, dict):
user_id = user_obj.get("id", user_obj.get("user_id", user_id))
username = user_obj.get("username", user_obj.get("name", username))
return user_id, username
def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
if not body_fields:
return None
body: Dict[str, Any] = {}
redacts = set(redact_fields or [])
for value in bound_args.values():
data: Optional[Dict[str, Any]] = None
if hasattr(value, "model_dump"):
try:
data = value.model_dump()
except Exception:
data = None
elif hasattr(value, "dict"):
try:
data = value.dict()
except Exception:
data = None
elif isinstance(value, dict):
data = value
elif hasattr(value, "__dict__"):
data = dict(value.__dict__)
if not isinstance(data, dict):
continue
for field in body_fields:
if field in data and field not in body:
body[field] = data[field]
if not body:
return None
for field in redacts:
if field in body:
body[field] = "<redacted>"
return body
def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
if not request:
return None
params: Dict[str, Any] = {}
query = dict(request.query_params)
if query:
params["query"] = query
path_params = dict(request.path_params or {})
if path_params:
params["path"] = path_params
return params or None
def _status_code_from_response(response: Any) -> int:
if hasattr(response, "status_code"):
try:
return int(getattr(response, "status_code"))
except Exception:
pass
return 200
def audit(
*,
action: AuditAction,
description: str | None = None,
body_fields: list[str] | None = None,
redact_fields: list[str] | None = None,
user_kw: str = "current_user",
):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
bound = inspect.signature(func).bind_partial(*args, **kwargs)
bound.apply_defaults()
request = _extract_request(bound.arguments)
start = time.perf_counter()
user_info = bound.arguments.get(user_kw)
user_id, username = await _resolve_user(request, user_info)
request_params = _build_request_params(request)
request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)
try:
result = func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
status_code = _status_code_from_response(result)
success = True
error = None
except Exception as exc: # noqa: BLE001
status_code = getattr(exc, "status_code", 500)
success = False
error = str(exc)
duration_ms = round((time.perf_counter() - start) * 1000, 2)
try:
await AuditService.log(
action=action,
description=description,
user_id=user_id,
username=username,
client_ip=request.client.host if request and request.client else None,
method=request.method if request else "",
path=request.url.path if request else func.__name__,
status_code=status_code,
duration_ms=duration_ms,
success=success,
request_params=request_params,
request_body=request_body,
error=error,
)
except Exception:
pass
raise
duration_ms = round((time.perf_counter() - start) * 1000, 2)
try:
await AuditService.log(
action=action,
description=description,
user_id=user_id,
username=username,
client_ip=request.client.host if request and request.client else None,
method=request.method if request else "",
path=request.url.path if request else func.__name__,
status_code=status_code,
duration_ms=duration_ms,
success=success,
request_params=request_params,
request_body=request_body,
error=error,
)
except Exception:
pass
return result
return wrapper
return decorator
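`_extract_body_fields` in the decorator whitelists request fields and masks sensitive ones before they reach the audit log; the core pick-and-redact step, sketched standalone (names here are illustrative):

```python
from typing import Any, Dict, Optional

def pick_and_redact(
    data: Dict[str, Any],
    body_fields: list,
    redact_fields: list,
) -> Optional[Dict[str, Any]]:
    """Keep only whitelisted fields, replacing sensitive values with a placeholder."""
    body = {f: data[f] for f in body_fields if f in data}
    if not body:
        return None
    for f in redact_fields:
        if f in body:
            body[f] = "<redacted>"
    return body
```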

domain/audit/service.py (new file, 124 lines)

@@ -0,0 +1,124 @@
from typing import Any, Dict, Optional
from models.database import AuditLog
from domain.audit.types import AuditAction
class AuditService:
@classmethod
async def log(
cls,
*,
action: AuditAction | str,
description: Optional[str],
user_id: Optional[int],
username: Optional[str],
client_ip: Optional[str],
method: str,
path: str,
status_code: int,
duration_ms: Optional[float],
success: bool,
request_params: Optional[Dict[str, Any]] = None,
request_body: Optional[Dict[str, Any]] = None,
error: Optional[str] = None,
) -> None:
await AuditLog.create(
action=str(action),
description=description,
user_id=user_id,
username=username,
client_ip=client_ip,
method=method,
path=path,
status_code=status_code,
duration_ms=duration_ms,
success=success,
request_params=request_params,
request_body=request_body,
error=error,
)
@classmethod
def _serialize(cls, log: AuditLog) -> Dict[str, Any]:
return {
"id": log.id,
"created_at": log.created_at.isoformat() if log.created_at else None,
"action": log.action,
"description": log.description,
"user_id": log.user_id,
"username": log.username,
"client_ip": log.client_ip,
"method": log.method,
"path": log.path,
"status_code": log.status_code,
"duration_ms": log.duration_ms,
"success": log.success,
"request_params": log.request_params,
"request_body": log.request_body,
"error": log.error,
}
@classmethod
def _apply_filters(
cls,
*,
action: str | None = None,
success: bool | None = None,
username: str | None = None,
path: str | None = None,
start_time=None,
end_time=None,
):
qs = AuditLog.all()
if action:
qs = qs.filter(action=action)
if success is not None:
qs = qs.filter(success=success)
if username:
qs = qs.filter(username__icontains=username)
if path:
qs = qs.filter(path__icontains=path)
if start_time:
qs = qs.filter(created_at__gte=start_time)
if end_time:
qs = qs.filter(created_at__lte=end_time)
return qs
@classmethod
async def list_logs(
cls,
*,
page: int,
page_size: int,
action: str | None = None,
success: bool | None = None,
username: str | None = None,
path: str | None = None,
start_time=None,
end_time=None,
):
qs = cls._apply_filters(
action=action,
success=success,
username=username,
path=path,
start_time=start_time,
end_time=end_time,
)
total = await qs.count()
offset = (page - 1) * page_size
items = await qs.order_by("-created_at").offset(offset).limit(page_size)
return [cls._serialize(log) for log in items], total
@classmethod
async def clear_logs(
cls,
*,
start_time=None,
end_time=None,
) -> int:
qs = cls._apply_filters(start_time=start_time, end_time=end_time)
deleted_count = await qs.delete()
return deleted_count
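`list_logs` converts the 1-based page number into an offset/limit pair before querying; as a small sketch of that arithmetic:

```python
def page_slice(page: int, page_size: int) -> tuple:
    """Translate a 1-based page number into (offset, limit) for a query."""
    offset = (page - 1) * page_size
    return offset, page_size
```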

domain/audit/types.py (new file, 16 lines)

@@ -0,0 +1,16 @@
from enum import StrEnum
class AuditAction(StrEnum):
LOGIN = "login"
LOGOUT = "logout"
REGISTER = "register"
READ = "read"
CREATE = "create"
UPDATE = "update"
DELETE = "delete"
RESET_PASSWORD = "reset_password"
SHARE = "share"
DOWNLOAD = "download"
UPLOAD = "upload"
OTHER = "other"
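`AuditAction` uses `StrEnum` (Python 3.11+), so members compare equal to their raw string values, which is why `str(action)` can be stored directly in audit rows. On older interpreters the `str, Enum` mixin gives the same comparison behavior, sketched here:

```python
from enum import Enum

# Illustrative equivalent of the module's StrEnum using the str+Enum mixin.
class AuditAction(str, Enum):
    LOGIN = "login"
    DELETE = "delete"
    OTHER = "other"
```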

domain/auth/api.py (new file, 90 lines)

@@ -0,0 +1,90 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import AuthService, get_current_active_user
from domain.auth.types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
Token,
UpdateMeRequest,
User,
)
router = APIRouter(prefix="/api/auth", tags=["auth"])
@router.post("/register", summary="Register the first admin user")
@audit(
action=AuditAction.REGISTER,
description="Register admin user",
body_fields=["username", "email", "full_name"],
redact_fields=["password"],
)
async def register(request: Request, data: RegisterRequest):
user = await AuthService.register_user(data)
return success({"username": user.username}, msg="Initial user registered successfully")
@router.post("/login")
@audit(action=AuditAction.LOGIN, description="User login", body_fields=["username"], redact_fields=["password"])
async def login_for_access_token(
request: Request,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
return await AuthService.login(form_data)
@router.get("/me", summary="Get current user info")
@audit(action=AuditAction.READ, description="Get current user info")
async def get_me(
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
):
profile = AuthService.get_profile(current_user)
return success(profile)
@router.put("/me", summary="Update current user info")
@audit(
action=AuditAction.UPDATE,
description="Update current user info",
body_fields=["email", "full_name"],
redact_fields=["old_password", "new_password"],
)
async def update_me(
request: Request,
payload: UpdateMeRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
profile = await AuthService.update_me(payload, current_user)
return success(profile)
@router.post("/password-reset/request", summary="Request password reset email")
@audit(action=AuditAction.RESET_PASSWORD, description="Request password reset email", body_fields=["email"])
async def password_reset_request_endpoint(request: Request, payload: PasswordResetRequest):
await AuthService.request_password_reset(payload)
return success(msg="If the email exists, a reset email will be sent")
@router.get("/password-reset/verify", summary="Verify password reset token")
@audit(action=AuditAction.RESET_PASSWORD, description="Verify password reset token", redact_fields=["token"])
async def password_reset_verify(request: Request, token: str):
user = await AuthService.verify_password_reset_token(token)
return success({"username": user.username, "email": user.email})
@router.post("/password-reset/confirm", summary="Reset password with token")
@audit(
action=AuditAction.RESET_PASSWORD,
description="Reset password",
body_fields=["token"],
redact_fields=["token", "password"],
)
async def password_reset_confirm(request: Request, payload: PasswordResetConfirm):
await AuthService.reset_password_with_token(payload)
return success(msg="Password has been reset")

domain/auth/service.py (new file, 356 lines)

@@ -0,0 +1,356 @@
import asyncio
import hashlib
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Annotated
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from domain.auth.types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
Token,
TokenData,
UpdateMeRequest,
User,
UserInDB,
)
from models.database import UserAccount
from domain.config.service import ConfigService
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10
def _now() -> datetime:
return datetime.now(timezone.utc)
@dataclass
class PasswordResetEntry:
user_id: int
email: str
username: str
expires_at: datetime
used: bool = False
class PasswordResetStore:
_tokens: dict[str, PasswordResetEntry] = {}
_lock = asyncio.Lock()
@classmethod
def _cleanup(cls):
now = _now()
for token, record in list(cls._tokens.items()):
if record.used or record.expires_at < now:
cls._tokens.pop(token, None)
@classmethod
async def create(cls, user: UserAccount) -> str:
async with cls._lock:
cls._cleanup()
for key, record in list(cls._tokens.items()):
if record.user_id == user.id:
cls._tokens.pop(key, None)
token = secrets.token_urlsafe(32)
expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
cls._tokens[token] = PasswordResetEntry(
user_id=user.id,
email=user.email or "",
username=user.username,
expires_at=expires_at,
)
return token
@classmethod
async def get(cls, token: str) -> PasswordResetEntry | None:
async with cls._lock:
cls._cleanup()
record = cls._tokens.get(token)
if not record or record.used:
return None
return record
@classmethod
async def mark_used(cls, token: str) -> None:
async with cls._lock:
record = cls._tokens.get(token)
if record:
record.used = True
cls._cleanup()
@classmethod
async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
async with cls._lock:
for key, record in list(cls._tokens.items()):
if record.user_id == user_id and key != except_token:
cls._tokens.pop(key, None)
cls._cleanup()
class AuthService:
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
algorithm = ALGORITHM
access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
@classmethod
async def get_secret_key(cls) -> str:
return await ConfigService.get_secret_key("SECRET_KEY", None)
@classmethod
def _normalize_email(cls, email: str | None) -> str:
return (email or "").strip().lower()
@classmethod
def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
return cls.pwd_context.verify(plain_password, hashed_password)
@classmethod
def get_password_hash(cls, password: str) -> str:
return cls.pwd_context.hash(password)
@classmethod
async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
user = await UserAccount.get_or_none(username=username_or_email)
if not user:
user = await UserAccount.get_or_none(email=username_or_email)
if user:
return UserInDB(
id=user.id,
username=user.username,
email=user.email,
full_name=user.full_name,
disabled=user.disabled,
hashed_password=user.hashed_password,
)
return None
@classmethod
async def authenticate_user_db(cls, username_or_email: str, password: str) -> UserInDB | None:
user = await cls.get_user_db(username_or_email)
if not user:
return None
if not cls.verify_password(password, user.hashed_password):
return None
return user
@classmethod
async def has_users(cls) -> bool:
user_count = await UserAccount.all().count()
return user_count > 0
@classmethod
async def register_user(cls, payload: RegisterRequest):
if await cls.has_users():
raise HTTPException(status_code=403, detail="System already initialized; registration is disabled")
exists = await UserAccount.get_or_none(username=payload.username)
if exists:
raise HTTPException(status_code=400, detail="Username already exists")
hashed = cls.get_password_hash(payload.password)
user = await UserAccount.create(
username=payload.username,
email=payload.email,
full_name=payload.full_name,
hashed_password=hashed,
disabled=False,
)
return user
@classmethod
async def create_access_token(cls, data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if "sub" not in to_encode and "username" in to_encode:
to_encode["sub"] = to_encode["username"]
expire = _now() + (expires_delta or timedelta(minutes=15))
to_encode.update({"exp": expire})
secret_key = await cls.get_secret_key()
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=cls.algorithm)
return encoded_jwt
@classmethod
async def login(cls, form: OAuth2PasswordRequestForm) -> Token:
user = await cls.authenticate_user_db(form.username, form.password)
if not user:
raise HTTPException(
status_code=401,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
access_token = await cls.create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")
@classmethod
def _build_profile(cls, user: User | UserInDB | UserAccount) -> dict:
email = cls._normalize_email(getattr(user, "email", None))
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
return {
"id": user.id,
"username": user.username,
"email": getattr(user, "email", None),
"full_name": getattr(user, "full_name", None),
"gravatar_url": gravatar_url,
}
@classmethod
def get_profile(cls, user: User | UserInDB | UserAccount) -> dict:
return cls._build_profile(user)
@classmethod
async def update_me(cls, payload: UpdateMeRequest, current_user: User) -> dict:
db_user = await UserAccount.get_or_none(id=current_user.id)
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
if payload.email is not None:
exists = (
await UserAccount.filter(email=payload.email)
.exclude(id=db_user.id)
.exists()
)
if exists:
raise HTTPException(status_code=400, detail="Email already in use")
db_user.email = payload.email
if payload.full_name is not None:
db_user.full_name = payload.full_name
if payload.new_password:
if not payload.old_password:
raise HTTPException(status_code=400, detail="Please provide the current password")
if not cls.verify_password(payload.old_password, db_user.hashed_password):
raise HTTPException(status_code=400, detail="Current password is incorrect")
db_user.hashed_password = cls.get_password_hash(payload.new_password)
await db_user.save()
return cls._build_profile(db_user)
@classmethod
async def request_password_reset(cls, payload: PasswordResetRequest) -> bool:
normalized = cls._normalize_email(payload.email)
if not normalized:
return False
user = await UserAccount.get_or_none(email=normalized)
if not user or not user.email:
return False
token = await PasswordResetStore.create(user)
try:
await cls._send_password_reset_email(user, token)
except Exception as exc: # noqa: BLE001
await PasswordResetStore.mark_used(token)
await PasswordResetStore.invalidate_user(user.id)
raise HTTPException(status_code=500, detail="Failed to send email") from exc
return True
@classmethod
async def verify_password_reset_token(cls, token: str) -> UserAccount:
record = await PasswordResetStore.get(token)
if not record:
raise HTTPException(status_code=400, detail="Invalid reset link")
user = await UserAccount.get_or_none(id=record.user_id)
if not user:
raise HTTPException(status_code=400, detail="Invalid reset link")
if record.expires_at < _now():
await PasswordResetStore.mark_used(token)
raise HTTPException(status_code=400, detail="Reset link has expired")
return user
@classmethod
async def reset_password_with_token(cls, payload: PasswordResetConfirm) -> None:
record = await PasswordResetStore.get(payload.token)
if not record:
raise HTTPException(status_code=400, detail="Invalid reset link")
if record.expires_at < _now():
await PasswordResetStore.mark_used(payload.token)
raise HTTPException(status_code=400, detail="Reset link has expired")
user = await UserAccount.get_or_none(id=record.user_id)
if not user:
raise HTTPException(status_code=400, detail="Invalid reset link")
user.hashed_password = cls.get_password_hash(payload.password)
await user.save(update_fields=["hashed_password"])
await PasswordResetStore.mark_used(payload.token)
await PasswordResetStore.invalidate_user(user.id)
@classmethod
async def get_current_user(cls, token: str):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
secret_key = await cls.get_secret_key()
payload = jwt.decode(token, secret_key, algorithms=[cls.algorithm])
username = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except InvalidTokenError:
raise credentials_exception
user = await cls.get_user_db(token_data.username)
if user is None:
raise credentials_exception
return user
@classmethod
async def get_current_active_user(cls, current_user: User):
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
@classmethod
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
from domain.email.service import EmailService
app_domain = await ConfigService.get("APP_DOMAIN", None)
base_url = (app_domain or "http://localhost:5173").rstrip("/")
reset_link = f"{base_url}/reset-password?token={token}"
await EmailService.enqueue_email(
recipients=[user.email],
subject="Foxel Password Reset",
template="password_reset",
context={
"username": user.username,
"reset_link": reset_link,
"expire_minutes": cls.password_reset_token_expire_minutes,
},
)
async def _current_user_dep(token: Annotated[str, Depends(AuthService.oauth2_scheme)]):
return await AuthService.get_current_user(token)
async def _current_active_user_dep(
current_user: Annotated[User, Depends(_current_user_dep)],
):
return await AuthService.get_current_active_user(current_user)
# convenience aliases for dependency injection and external use
get_current_user = _current_user_dep
get_current_active_user = _current_active_user_dep
authenticate_user_db = AuthService.authenticate_user_db
create_access_token = AuthService.create_access_token
register_user = AuthService.register_user
request_password_reset = AuthService.request_password_reset
verify_password_reset_token = AuthService.verify_password_reset_token
reset_password_with_token = AuthService.reset_password_with_token
has_users = AuthService.has_users
verify_password = AuthService.verify_password
get_password_hash = AuthService.get_password_hash
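`create_access_token` signs the payload with HS256 via PyJWT. To make the signing step concrete, here is a dependency-free sketch of HS256 JWT encoding and verification (illustrative only; the project uses the `jwt` library):

```python
import base64
import hashlib
import hmac
import json

def b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWTs require."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

def encode_hs256(payload: dict, secret: str) -> str:
    """Build a JWT by hand: base64url(header).base64url(payload).signature."""
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    body = b64url(json.dumps(payload, separators=(",", ":")).encode())
    signing_input = f"{header}.{body}".encode()
    sig = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    return f"{header}.{body}.{sig}"

def verify_hs256(token: str, secret: str) -> dict:
    """Verify the HMAC signature and return the decoded payload."""
    header, body, sig = token.split(".")
    expected = b64url(hmac.new(secret.encode(), f"{header}.{body}".encode(), hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        raise ValueError("bad signature")
    padded = body + "=" * (-len(body) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))
```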

domain/auth/types.py (new file, 45 lines)

@@ -0,0 +1,45 @@
from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None = None
class User(BaseModel):
id: int
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
class UserInDB(User):
hashed_password: str
class RegisterRequest(BaseModel):
username: str
password: str
email: str | None = None
full_name: str | None = None
class UpdateMeRequest(BaseModel):
email: str | None = None
full_name: str | None = None
old_password: str | None = None
new_password: str | None = None
class PasswordResetRequest(BaseModel):
email: str
class PasswordResetConfirm(BaseModel):
token: str
password: str


@@ -0,0 +1 @@

domain/backup/api.py (new file, 30 lines)

@@ -0,0 +1,30 @@
import datetime
from fastapi import APIRouter, Depends, File, Request, UploadFile
from fastapi.responses import JSONResponse
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.backup.service import BackupService
router = APIRouter(
prefix="/api/backup",
tags=["Backup & Restore"],
dependencies=[Depends(get_current_active_user)],
)
@router.get("/export", summary="Export all site data")
@audit(action=AuditAction.DOWNLOAD, description="Export backup")
async def export_backup(request: Request):
data = await BackupService.export_data()
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
return JSONResponse(content=data.model_dump(), headers=headers)
@router.post("/import", summary="Import data")
@audit(action=AuditAction.UPLOAD, description="Import backup")
async def import_backup(request: Request, file: UploadFile = File(...)):
await BackupService.import_from_bytes(file.filename, await file.read())
return {"message": "Data imported successfully."}
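`export_backup` stamps the Content-Disposition filename with the current local time; the format produces names like `foxel_backup_2025-12-09_15-22-02.json`, sketched here as a pure function:

```python
import datetime

def backup_filename(now: datetime.datetime) -> str:
    """Timestamped filename used in the Content-Disposition header."""
    return f"foxel_backup_{now.strftime('%Y-%m-%d_%H-%M-%S')}.json"
```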

domain/backup/service.py (new file, 203 lines)

@@ -0,0 +1,203 @@
import json
from datetime import datetime
from fastapi import HTTPException
from tortoise.transactions import in_transaction
from domain.backup.types import BackupData
from domain.config.service import VERSION
from models.database import (
AIDefaultModel,
AIModel,
AIProvider,
AutomationTask,
Configuration,
Plugin,
ShareLink,
StorageAdapter,
UserAccount,
)
class BackupService:
@classmethod
async def export_data(cls) -> BackupData:
async with in_transaction():
adapters = await StorageAdapter.all().values()
users = await UserAccount.all().values()
tasks = await AutomationTask.all().values()
shares = await ShareLink.all().values()
configs = await Configuration.all().values()
providers = await AIProvider.all().values()
models = await AIModel.all().values()
default_models = await AIDefaultModel.all().values()
plugins = await Plugin.all().values()
share_links = cls._serialize_datetime_fields(
shares, ["created_at", "expires_at"]
)
ai_providers = cls._serialize_datetime_fields(
providers, ["created_at", "updated_at"]
)
ai_models = cls._serialize_datetime_fields(
models, ["created_at", "updated_at"]
)
ai_default_models = cls._serialize_datetime_fields(
default_models, ["created_at", "updated_at"]
)
plugin_items = cls._serialize_datetime_fields(
plugins, ["created_at", "updated_at"]
)
return BackupData(
version=VERSION,
storage_adapters=list(adapters),
user_accounts=list(users),
automation_tasks=list(tasks),
share_links=share_links,
configurations=list(configs),
ai_providers=ai_providers,
ai_models=ai_models,
ai_default_models=ai_default_models,
plugins=plugin_items,
)
@classmethod
async def import_from_bytes(cls, filename: str, content: bytes) -> None:
if not filename.endswith(".json"):
raise HTTPException(status_code=400, detail="Invalid file type; please upload a .json file")
try:
raw_data = json.loads(content)
except Exception as exc:
raise HTTPException(status_code=400, detail="Unable to parse JSON file") from exc
await cls.import_data(BackupData(**raw_data))
@classmethod
async def import_data(cls, payload: BackupData) -> None:
async with in_transaction() as conn:
await ShareLink.all().using_db(conn).delete()
await AutomationTask.all().using_db(conn).delete()
await StorageAdapter.all().using_db(conn).delete()
await UserAccount.all().using_db(conn).delete()
await Configuration.all().using_db(conn).delete()
await AIDefaultModel.all().using_db(conn).delete()
await AIModel.all().using_db(conn).delete()
await AIProvider.all().using_db(conn).delete()
await Plugin.all().using_db(conn).delete()
if payload.configurations:
await Configuration.bulk_create(
[Configuration(**config) for config in payload.configurations],
using_db=conn,
)
if payload.user_accounts:
await UserAccount.bulk_create(
[UserAccount(**user) for user in payload.user_accounts],
using_db=conn,
)
if payload.storage_adapters:
await StorageAdapter.bulk_create(
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
using_db=conn,
)
if payload.automation_tasks:
await AutomationTask.bulk_create(
[AutomationTask(**task) for task in payload.automation_tasks],
using_db=conn,
)
if payload.share_links:
await ShareLink.bulk_create(
[
ShareLink(**share)
for share in cls._parse_datetime_fields(
payload.share_links, ["created_at", "expires_at"]
)
],
using_db=conn,
)
if payload.ai_providers:
await AIProvider.bulk_create(
[
AIProvider(**item)
for item in cls._parse_datetime_fields(
payload.ai_providers, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if payload.ai_models:
await AIModel.bulk_create(
[
AIModel(**item)
for item in cls._parse_datetime_fields(
payload.ai_models, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if payload.ai_default_models:
await AIDefaultModel.bulk_create(
[
AIDefaultModel(**item)
for item in cls._parse_datetime_fields(
payload.ai_default_models, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if payload.plugins:
await Plugin.bulk_create(
[
Plugin(**item)
for item in cls._parse_datetime_fields(
payload.plugins, ["created_at", "updated_at"]
)
],
using_db=conn,
)
@staticmethod
def _serialize_datetime_fields(
records: list[dict], fields: list[str]
) -> list[dict]:
serialized: list[dict] = []
for record in records:
item = dict(record)
for field in fields:
value = item.get(field)
if isinstance(value, datetime):
item[field] = value.isoformat()
serialized.append(item)
return serialized
@staticmethod
def _parse_datetime_fields(
records: list[dict], fields: list[str]
) -> list[dict]:
parsed: list[dict] = []
for record in records:
item = dict(record)
for field in fields:
value = item.get(field)
if isinstance(value, str):
item[field] = BackupService._from_iso(value)
parsed.append(item)
return parsed
@staticmethod
def _from_iso(value: str) -> datetime | None:
if not value:
return None
normalized = value.replace("Z", "+00:00")
try:
return datetime.fromisoformat(normalized)
except ValueError as exc: # noqa: BLE001
raise HTTPException(status_code=400, detail="Invalid date format") from exc
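`_serialize_datetime_fields` and `_parse_datetime_fields` round-trip datetime columns through ISO 8601 strings for the JSON backup; a minimal sketch of that round trip:

```python
from datetime import datetime

def serialize_fields(record: dict, fields: list) -> dict:
    """Render datetime values as ISO strings for JSON export."""
    item = dict(record)
    for f in fields:
        if isinstance(item.get(f), datetime):
            item[f] = item[f].isoformat()
    return item

def parse_fields(record: dict, fields: list) -> dict:
    """Parse ISO strings back into datetimes on import."""
    item = dict(record)
    for f in fields:
        if isinstance(item.get(f), str):
            item[f] = datetime.fromisoformat(item[f].replace("Z", "+00:00"))
    return item
```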

domain/backup/types.py (new file, 16 lines)

@@ -0,0 +1,16 @@
from typing import Any
from pydantic import BaseModel, Field
class BackupData(BaseModel):
version: str | None = None
storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
user_accounts: list[dict[str, Any]] = Field(default_factory=list)
automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
share_links: list[dict[str, Any]] = Field(default_factory=list)
configurations: list[dict[str, Any]] = Field(default_factory=list)
ai_providers: list[dict[str, Any]] = Field(default_factory=list)
ai_models: list[dict[str, Any]] = Field(default_factory=list)
ai_default_models: list[dict[str, Any]] = Field(default_factory=list)
plugins: list[dict[str, Any]] = Field(default_factory=list)

domain/config/api.py (new file, 59 lines)

@@ -0,0 +1,59 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Form, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.config.types import ConfigItem
router = APIRouter(prefix="/api/config", tags=["config"])
@router.get("/")
@audit(action=AuditAction.READ, description="获取配置")
async def get_config(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
key: str,
):
value = await ConfigService.get(key)
return success(ConfigItem(key=key, value=value).model_dump())
@router.post("/")
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
async def set_config(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
key: str = Form(...),
value: str = Form(...),
):
await ConfigService.set(key, value)
return success(ConfigItem(key=key, value=value).model_dump())
@router.get("/all")
@audit(action=AuditAction.READ, description="获取全部配置")
async def get_all_config(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
configs = await ConfigService.get_all()
return success(configs)
@router.get("/status")
@audit(action=AuditAction.READ, description="获取系统状态")
async def get_system_status(request: Request):
status_data = await ConfigService.get_system_status()
return success(status_data.model_dump())
@router.get("/latest-version")
@audit(action=AuditAction.READ, description="获取最新版本")
async def get_latest_version(request: Request):
info = await ConfigService.get_latest_version()
return success(info.model_dump())

domain/config/service.py Normal file

@@ -0,0 +1,111 @@
import os
import time
from typing import Any, Dict, Optional
import httpx
from dotenv import load_dotenv
from domain.config.types import LatestVersionInfo, SystemStatus
from models.database import Configuration, UserAccount
load_dotenv(dotenv_path=".env")
VERSION = "v1.4.0"
class ConfigService:
_cache: Dict[str, Any] = {}
_latest_version_cache: Dict[str, Any] = {"timestamp": 0.0, "data": None}
@classmethod
async def get(cls, key: str, default: Optional[Any] = None) -> Any:
if key in cls._cache:
return cls._cache[key]
try:
config = await Configuration.get_or_none(key=key)
if config:
cls._cache[key] = config.value
return config.value
except Exception:
pass
env_value = os.getenv(key)
if env_value is not None:
cls._cache[key] = env_value
return env_value
return default
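The lookup order in `ConfigService.get` is process-local cache, then the `Configuration` table, then an environment variable, then the caller's default. A sketch of that chain with the database swapped for an in-memory dict (the class and dict here are hypothetical stand-ins, not the project's API):

```python
import os

class MiniConfig:
    # Stand-in for ConfigService: the Configuration table is replaced by
    # an in-memory dict so the lookup order is visible.
    _cache: dict = {}
    _db: dict = {"APP_NAME": "Foxel"}

    @classmethod
    def get(cls, key, default=None):
        if key in cls._cache:          # 1. process-local cache
            return cls._cache[key]
        if key in cls._db:             # 2. database entry
            cls._cache[key] = cls._db[key]
            return cls._db[key]
        env_value = os.getenv(key)     # 3. environment variable
        if env_value is not None:
            cls._cache[key] = env_value
            return env_value
        return default                 # 4. caller-supplied default

os.environ["APP_NAME"] = "FromEnv"
assert MiniConfig.get("APP_NAME") == "Foxel"   # DB entry shadows the env
os.environ["ONLY_ENV"] = "hello"
assert MiniConfig.get("ONLY_ENV") == "hello"
assert MiniConfig.get("MISSING", "fallback") == "fallback"
```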
@classmethod
async def get_secret_key(cls, key: str, default: Optional[Any] = None) -> bytes:
value = await cls.get(key, default)
if isinstance(value, bytes):
return value
if isinstance(value, str):
return value.encode("utf-8")
if value is None:
raise ValueError(f"Secret key '{key}' not found in config or environment.")
return str(value).encode("utf-8")
@classmethod
async def set(cls, key: str, value: Any):
obj, _ = await Configuration.get_or_create(key=key, defaults={"value": value})
obj.value = value
await obj.save()
cls._cache[key] = value
@classmethod
async def get_all(cls) -> Dict[str, Any]:
try:
configs = await Configuration.all()
result = {}
for config in configs:
result[config.key] = config.value
cls._cache[config.key] = config.value
return result
except Exception:
return {}
@classmethod
def clear_cache(cls):
cls._cache.clear()
@classmethod
async def get_system_status(cls) -> SystemStatus:
logo = await cls.get("APP_LOGO", "/logo.svg")
favicon = await cls.get("APP_FAVICON", logo)
user_count = await UserAccount.all().count()
return SystemStatus(
version=VERSION,
title=await cls.get("APP_NAME", "Foxel"),
logo=logo,
favicon=favicon,
is_initialized=user_count > 0,
app_domain=await cls.get("APP_DOMAIN"),
file_domain=await cls.get("FILE_DOMAIN"),
)
@classmethod
async def get_latest_version(cls) -> LatestVersionInfo:
current_time = time.time()
cache = cls._latest_version_cache
if current_time - cache["timestamp"] < 3600 and cache["data"]:
return cache["data"]
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
follow_redirects=True,
)
resp.raise_for_status()
data = resp.json()
version_info = LatestVersionInfo(
latest_version=data.get("tag_name"),
body=data.get("body"),
)
cache["timestamp"] = current_time
cache["data"] = version_info
return version_info
except httpx.RequestError:
if cache["data"]:
return cache["data"]
return LatestVersionInfo()
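`get_latest_version` caches the GitHub response for an hour and falls back to the stale cached entry on a network error. The caching skeleton, with the HTTP call abstracted into a `fetch` callable and `OSError` standing in for `httpx.RequestError` (both substitutions are assumptions for the sketch):

```python
import time

CACHE = {"timestamp": 0.0, "data": None}
TTL = 3600.0  # one hour, as in the service above

def get_latest(fetch):
    now = time.time()
    if now - CACHE["timestamp"] < TTL and CACHE["data"]:
        return CACHE["data"]            # fresh hit: no network call
    try:
        data = fetch()
        CACHE["timestamp"] = now
        CACHE["data"] = data
        return data
    except OSError:
        return CACHE["data"]            # network error: serve stale data

assert get_latest(lambda: "v1.4.0") == "v1.4.0"

def offline():
    raise OSError("network down")

CACHE["timestamp"] -= 7200              # force the entry to expire
assert get_latest(offline) == "v1.4.0"  # stale entry beats an empty answer
```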

domain/config/types.py Normal file

@@ -0,0 +1,23 @@
from typing import Any, Optional
from pydantic import BaseModel
class ConfigItem(BaseModel):
key: str
value: Optional[Any] = None
class SystemStatus(BaseModel):
version: str
title: str
logo: str
favicon: str
is_initialized: bool
app_domain: Optional[str] = None
file_domain: Optional[str] = None
class LatestVersionInfo(BaseModel):
latest_version: Optional[str] = None
body: Optional[str] = None


@@ -1,20 +1,24 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from services.auth import User, get_current_active_user
from services.email import EmailService, EmailTemplateRenderer
from schemas.email import EmailTestRequest, EmailTemplateUpdate, EmailTemplatePreviewPayload
from api.response import success
from services.logging import LogService
router = APIRouter(
prefix="/api/email",
tags=["email"],
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.email.service import EmailService, EmailTemplateRenderer
from domain.email.types import (
EmailTemplatePreviewPayload,
EmailTemplateUpdate,
EmailTestRequest,
)
router = APIRouter(prefix="/api/email", tags=["email"])
@router.post("/test")
@audit(action=AuditAction.CREATE, description="发送测试邮件")
async def trigger_test_email(
request: Request,
payload: EmailTestRequest,
current_user: User = Depends(get_current_active_user),
):
@@ -27,17 +31,13 @@ async def trigger_test_email(
)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
await LogService.action(
"route:email",
"Triggered email test",
details={"task_id": task.id, "template": payload.template, "to": str(payload.to)},
user_id=getattr(current_user, "id", None),
)
return success({"task_id": task.id})
@router.get("/templates")
@audit(action=AuditAction.READ, description="获取邮件模板列表")
async def list_email_templates(
request: Request,
current_user: User = Depends(get_current_active_user),
):
templates = await EmailTemplateRenderer.list_templates()
@@ -45,7 +45,9 @@ async def list_email_templates(
@router.get("/templates/{name}")
@audit(action=AuditAction.READ, description="查看邮件模板")
async def get_email_template(
request: Request,
name: str,
current_user: User = Depends(get_current_active_user),
):
@@ -59,7 +61,9 @@ async def get_email_template(
@router.post("/templates/{name}")
@audit(action=AuditAction.UPDATE, description="更新邮件模板")
async def update_email_template(
request: Request,
name: str,
payload: EmailTemplateUpdate,
current_user: User = Depends(get_current_active_user),
@@ -68,17 +72,13 @@ async def update_email_template(
await EmailTemplateRenderer.save(name, payload.content)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
await LogService.action(
"route:email",
"Updated email template",
details={"template": name},
user_id=getattr(current_user, "id", None),
)
return success({"name": name})
@router.post("/templates/{name}/preview")
@audit(action=AuditAction.READ, description="预览邮件模板")
async def preview_email_template(
request: Request,
name: str,
payload: EmailTemplatePreviewPayload,
current_user: User = Depends(get_current_active_user),


@@ -1,42 +1,14 @@
import asyncio
import json
import re
import smtplib
from email.message import EmailMessage
from email.utils import formataddr
from enum import Enum
from pathlib import Path
from string import Template
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, EmailStr, Field, ValidationError
from services.config import ConfigCenter
from services.logging import LogService
class EmailSecurity(str, Enum):
NONE = "none"
SSL = "ssl"
STARTTLS = "starttls"
class EmailConfig(BaseModel):
host: str
port: int = Field(..., gt=0)
username: Optional[str] = None
password: Optional[str] = None
sender_email: EmailStr
sender_name: Optional[str] = None
security: EmailSecurity = EmailSecurity.NONE
timeout: float = Field(default=30.0, gt=0.0)
class EmailSendPayload(BaseModel):
recipients: List[EmailStr] = Field(..., min_length=1)
subject: str = Field(..., min_length=1)
template: str = Field(..., min_length=1)
context: Dict[str, Any] = Field(default_factory=dict)
from domain.config.service import ConfigService
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload
class EmailTemplateRenderer:
@@ -52,9 +24,7 @@ class EmailTemplateRenderer:
async def list_templates(cls) -> list[str]:
cls.ROOT.mkdir(parents=True, exist_ok=True)
return sorted(
path.stem
for path in cls.ROOT.glob("*.html")
if path.is_file()
path.stem for path in cls.ROOT.glob("*.html") if path.is_file()
)
@classmethod
@@ -82,22 +52,8 @@ class EmailService:
@classmethod
async def _load_config(cls) -> EmailConfig:
raw_config = await ConfigCenter.get(cls.CONFIG_KEY)
if raw_config is None:
raise ValueError("Email configuration not found")
if isinstance(raw_config, str):
raw_config = raw_config.strip()
data: Any = json.loads(raw_config) if raw_config else {}
elif isinstance(raw_config, dict):
data = raw_config
else:
raise ValueError("Invalid email configuration format")
try:
return EmailConfig(**data)
except ValidationError as exc:
raise ValueError(f"Invalid email configuration: {exc}") from exc
raw_config = await ConfigService.get(cls.CONFIG_KEY)
return EmailConfig.parse_config(raw_config)
@staticmethod
def _html_to_text(html: str) -> str:
@@ -108,7 +64,9 @@ class EmailService:
async def _deliver(cls, config: EmailConfig, payload: EmailSendPayload, html_body: str):
message = EmailMessage()
message["Subject"] = payload.subject
message["From"] = formataddr((config.sender_name or str(config.sender_email), str(config.sender_email)))
message["From"] = formataddr(
(config.sender_name or str(config.sender_email), str(config.sender_email))
)
message["To"] = ", ".join([str(addr) for addr in payload.recipients])
plain_body = cls._html_to_text(html_body)
@@ -120,7 +78,9 @@ class EmailService:
@staticmethod
def _deliver_sync(config: EmailConfig, message: EmailMessage):
if config.security == EmailSecurity.SSL:
smtp: smtplib.SMTP = smtplib.SMTP_SSL(config.host, config.port, timeout=config.timeout)
smtp: smtplib.SMTP = smtplib.SMTP_SSL(
config.host, config.port, timeout=config.timeout
)
else:
smtp = smtplib.SMTP(config.host, config.port, timeout=config.timeout)
@@ -144,7 +104,7 @@ class EmailService:
template: str,
context: Optional[Dict[str, Any]] = None,
):
from services.task_queue import TaskProgress, task_queue_service
from domain.tasks.task_queue import TaskProgress, task_queue_service
payload = EmailSendPayload(
recipients=recipients,
@@ -162,16 +122,11 @@ class EmailService:
task.id,
TaskProgress(stage="queued", percent=0.0, detail="Waiting to send"),
)
await LogService.action(
"email_service",
"Email task enqueued",
details={"task_id": task.id, "subject": subject, "template": template},
)
return task
@classmethod
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
from services.task_queue import TaskProgress, task_queue_service
from domain.tasks.task_queue import TaskProgress, task_queue_service
payload = EmailSendPayload(**data)
@@ -194,8 +149,3 @@ class EmailService:
task_id,
TaskProgress(stage="completed", percent=100.0, detail="Email sent"),
)
await LogService.info(
"email_service",
"Email sent",
details={"task_id": task_id, "subject": payload.subject},
)

domain/email/types.py Normal file

@@ -0,0 +1,63 @@
import json
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, EmailStr, Field, ValidationError
class EmailSecurity(str, Enum):
NONE = "none"
SSL = "ssl"
STARTTLS = "starttls"
class EmailConfig(BaseModel):
host: str
port: int = Field(..., gt=0)
username: Optional[str] = None
password: Optional[str] = None
sender_email: EmailStr
sender_name: Optional[str] = None
security: EmailSecurity = EmailSecurity.NONE
timeout: float = Field(default=30.0, gt=0.0)
@classmethod
def parse_config(cls, raw_config: Any) -> "EmailConfig":
"""Accept a string or dict config and parse it into an EmailConfig."""
if raw_config is None:
raise ValueError("Email configuration not found")
if isinstance(raw_config, str):
raw_config = raw_config.strip()
data: Any = json.loads(raw_config) if raw_config else {}
elif isinstance(raw_config, dict):
data = raw_config
else:
raise ValueError("Invalid email configuration format")
try:
return cls(**data)
except ValidationError as exc:
raise ValueError(f"Invalid email configuration: {exc}") from exc
class EmailSendPayload(BaseModel):
recipients: List[EmailStr] = Field(..., min_length=1)
subject: str = Field(..., min_length=1)
template: str = Field(..., min_length=1)
context: Dict[str, Any] = Field(default_factory=dict)
class EmailTestRequest(BaseModel):
to: EmailStr
subject: str = Field(..., min_length=1)
template: str = Field(default="test", min_length=1)
context: Dict[str, Any] = Field(default_factory=dict)
class EmailTemplateUpdate(BaseModel):
content: str
class EmailTemplatePreviewPayload(BaseModel):
context: Dict[str, Any] = Field(default_factory=dict)
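`EmailConfig.parse_config` accepts either a raw JSON string (as stored in the config table) or an already-decoded dict. A pydantic-free sketch of that normalization, with validation reduced to a required-key check (the required-key list is an assumption for illustration):

```python
import json

def parse_config(raw):
    # Accept a JSON string or a dict; anything else is a config error.
    if raw is None:
        raise ValueError("Email configuration not found")
    if isinstance(raw, str):
        raw = raw.strip()
        data = json.loads(raw) if raw else {}
    elif isinstance(raw, dict):
        data = raw
    else:
        raise ValueError("Invalid email configuration format")
    for required in ("host", "port", "sender_email"):
        if required not in data:
            raise ValueError(f"Invalid email configuration: missing {required}")
    return data

cfg = parse_config('{"host": "smtp.example.com", "port": 465, "sender_email": "a@b.c"}')
assert cfg["port"] == 465
assert parse_config(cfg)["host"] == "smtp.example.com"  # dicts pass through
```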


@@ -0,0 +1,42 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.service import OfflineDownloadService
from domain.offline_downloads.types import OfflineDownloadCreate
CurrentUser = Annotated[User, Depends(get_current_active_user)]
router = APIRouter(
prefix="/api/offline-downloads",
tags=["OfflineDownloads"],
)
@router.post("/")
@audit(
action=AuditAction.CREATE,
description="创建离线下载任务",
body_fields=["url", "dest_dir", "filename"],
)
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
data = await OfflineDownloadService.create_download(payload, current_user)
return success(data)
@router.get("/")
@audit(action=AuditAction.READ, description="获取离线下载列表")
async def list_offline_downloads(request: Request, current_user: CurrentUser):
data = OfflineDownloadService.list_downloads()
return success(data)
@router.get("/{task_id}")
@audit(action=AuditAction.READ, description="获取离线下载详情")
async def get_offline_download(task_id: str, request: Request, current_user: CurrentUser):
data = OfflineDownloadService.get_download(task_id)
return success(data)


@@ -0,0 +1,252 @@
import os
import time
from pathlib import Path
from typing import Annotated, AsyncIterator
import aiofiles
import aiohttp
from fastapi import Depends, HTTPException
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service
class OfflineDownloadService:
current_user_dep = Annotated[User, Depends(get_current_active_user)]
temp_root = Path("data/tmp/offline_downloads")
@classmethod
async def create_download(cls, payload: OfflineDownloadCreate, current_user: User) -> dict:
await cls._ensure_destination(payload.dest_dir)
task = await task_queue_service.add_task(
"offline_http_download",
{
"url": str(payload.url),
"dest_dir": payload.dest_dir,
"filename": payload.filename,
},
)
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="queued",
percent=0.0,
bytes_total=None,
bytes_done=0,
detail="Waiting to start",
),
)
return {"task_id": task.id}
@classmethod
def list_downloads(cls) -> list[dict]:
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
return [t.dict() for t in tasks]
@classmethod
def get_download(cls, task_id: str) -> dict:
task = task_queue_service.get_task(task_id)
if not task or task.name != "offline_http_download":
raise HTTPException(status_code=404, detail="Task not found")
return task.dict()
@classmethod
async def run_http_download(cls, task: Task):
params = task.task_info
url = params.get("url")
dest_dir = params.get("dest_dir")
filename = params.get("filename")
if not url or not dest_dir or not filename:
raise ValueError("Missing required parameters for offline download")
cls.temp_root.mkdir(parents=True, exist_ok=True)
temp_dir = cls.temp_root / task.id
temp_dir.mkdir(parents=True, exist_ok=True)
temp_file = temp_dir / "payload"
bytes_total: int | None = None
bytes_done = 0
last_update = time.monotonic()
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="downloading",
percent=0.0,
bytes_total=None,
bytes_done=0,
detail="HTTP downloading",
),
)
async def report_download(delta: int, total: int | None):
nonlocal bytes_done, bytes_total, last_update
if total is not None:
bytes_total = total
bytes_done += delta
now = time.monotonic()
if delta and now - last_update < 0.5:
return
last_update = now
percent = None
total_for_display = bytes_total
if bytes_total:
percent = min(100.0, round(bytes_done / bytes_total * 100, 2))
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="downloading",
percent=percent,
bytes_total=total_for_display,
bytes_done=bytes_done,
detail="HTTP downloading",
),
)
timeout = aiohttp.ClientTimeout(total=None, connect=30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url) as resp:
if resp.status != 200:
raise ValueError(f"HTTP {resp.status} for {url}")
content_length = resp.headers.get("Content-Length")
total_size = int(content_length) if content_length else None
bytes_done = 0
async with aiofiles.open(temp_file, "wb") as f:
async for chunk in resp.content.iter_chunked(512 * 1024):
if not chunk:
continue
await f.write(chunk)
await report_download(len(chunk), total_size)
await report_download(0, total_size)
file_size = os.path.getsize(temp_file)
bytes_done_transfer = 0
async def report_transfer(delta: int):
nonlocal bytes_done_transfer
bytes_done_transfer += delta
percent = min(100.0, round(bytes_done_transfer / file_size * 100, 2)) if file_size else None
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="transferring",
percent=percent,
bytes_total=file_size or None,
bytes_done=bytes_done_transfer,
detail="Saving to storage",
),
)
async def chunk_iter() -> AsyncIterator[bytes]:
async for chunk in cls._iter_file(temp_file, 512 * 1024, report_transfer):
yield chunk
final_path, resolved_name = await cls._allocate_destination(dest_dir, filename)
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="transferring",
percent=0.0,
bytes_total=file_size or None,
bytes_done=0,
detail="Saving to storage",
),
)
await VirtualFSService.write_file_stream(final_path, chunk_iter())
await task_queue_service.update_progress(
task.id,
TaskProgress(
stage="completed",
percent=100.0,
bytes_total=file_size or None,
bytes_done=file_size,
detail="Completed",
),
)
await task_queue_service.update_meta(task.id, {"final_path": final_path, "filename": resolved_name})
try:
os.remove(temp_file)
temp_dir.rmdir()
except Exception:
pass
return final_path
@classmethod
async def _ensure_destination(cls, dest_dir: str) -> None:
try:
is_dir = await VirtualFSService.path_is_directory(dest_dir)
except HTTPException:
is_dir = False
if not is_dir:
raise HTTPException(400, detail="Destination directory not found")
@staticmethod
def _normalize_path(path: str) -> str:
if not path:
return "/"
if not path.startswith("/"):
path = "/" + path
if len(path) > 1 and path.endswith("/"):
path = path.rstrip("/")
return path or "/"
@staticmethod
async def _path_exists(full_path: str) -> bool:
try:
await VirtualFSService.stat_file(full_path)
return True
except FileNotFoundError:
return False
except HTTPException as exc:
if exc.status_code == 404:
return False
raise
@classmethod
async def _allocate_destination(cls, dest_dir: str, filename: str) -> tuple[str, str]:
dest_dir = cls._normalize_path(dest_dir)
stem, suffix = cls._split_filename(filename)
candidate = filename
base = "" if dest_dir == "/" else dest_dir
attempt = 0
while await cls._path_exists(f"{base}/{candidate}" if base else f"/{candidate}"):
attempt += 1
if stem:
candidate = f"{stem} ({attempt}){suffix}"
else:
candidate = f"file ({attempt}){suffix}" if suffix else f"file ({attempt})"
full_path = f"{base}/{candidate}" if base else f"/{candidate}"
return full_path, candidate
@staticmethod
def _split_filename(filename: str) -> tuple[str, str]:
if not filename:
return "", ""
if filename.startswith(".") and filename.count(".") == 1:
return filename, ""
if "." not in filename:
return filename, ""
stem, ext = filename.rsplit(".", 1)
return stem, f".{ext}"
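`_split_filename` keeps single-dot dotfiles such as `.gitignore` intact, and `_allocate_destination` appends ` (n)` to the stem until the candidate name is free. The naming logic in isolation, with the virtual-FS existence check replaced by a plain set (an assumption for the sketch):

```python
def split_filename(filename: str) -> tuple[str, str]:
    if not filename:
        return "", ""
    # A lone dotfile such as ".gitignore" has no extension to split off.
    if filename.startswith(".") and filename.count(".") == 1:
        return filename, ""
    if "." not in filename:
        return filename, ""
    stem, ext = filename.rsplit(".", 1)
    return stem, f".{ext}"

def allocate(filename: str, existing: set[str]) -> str:
    # Append " (1)", " (2)", ... to the stem until the name is unused.
    stem, suffix = split_filename(filename)
    candidate, attempt = filename, 0
    while candidate in existing:
        attempt += 1
        candidate = f"{stem} ({attempt}){suffix}" if stem else f"file ({attempt}){suffix}"
    return candidate

assert split_filename(".gitignore") == (".gitignore", "")
assert split_filename("a.tar.gz") == ("a.tar", ".gz")
assert allocate("a.txt", {"a.txt", "a (1).txt"}) == "a (2).txt"
```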
@staticmethod
async def _iter_file(path: Path, chunk_size: int, report_cb) -> AsyncIterator[bytes]:
async with aiofiles.open(path, "rb") as f:
while True:
chunk = await f.read(chunk_size)
if not chunk:
break
await report_cb(len(chunk))
yield chunk
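`report_download` above throttles progress updates to at most one every 0.5 s, while zero-delta calls (the final flush) always go through. The throttling idea, stripped of the task queue and driven by an injectable clock (class and names hypothetical):

```python
import time

class Throttled:
    # Collects progress reports, dropping updates that arrive within
    # `interval` seconds of the last emitted one; zero-delta flushes
    # always pass, mirroring report_download above.
    def __init__(self, interval=0.5, clock=time.monotonic):
        self.interval = interval
        self.clock = clock
        self.last = clock()
        self.emitted = []

    def report(self, delta, done):
        now = self.clock()
        if delta and now - self.last < self.interval:
            return                      # too soon: drop this update
        self.last = now
        self.emitted.append(done)

fake_now = [0.0]
t = Throttled(clock=lambda: fake_now[0])
t.report(100, 100)   # suppressed: 0 s since the last emission
fake_now[0] = 0.6
t.report(100, 200)   # 0.6 s elapsed, emitted
t.report(0, 200)     # zero-delta flush, always emitted
assert t.emitted == [200, 200]
```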


@@ -0,0 +1 @@

domain/plugins/api.py Normal file

@@ -0,0 +1,66 @@
from typing import List
from fastapi import APIRouter, Body, Request
from domain.audit import AuditAction, audit
from domain.plugins.service import PluginService
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
@router.post("", response_model=PluginOut)
@audit(
action=AuditAction.CREATE,
description="创建插件",
body_fields=["url", "enabled"],
)
async def create_plugin(request: Request, payload: PluginCreate):
return await PluginService.create(payload)
@router.get("", response_model=List[PluginOut])
@audit(action=AuditAction.READ, description="获取插件列表")
async def list_plugins(request: Request):
return await PluginService.list_plugins()
@router.delete("/{plugin_id}")
@audit(action=AuditAction.DELETE, description="删除插件")
async def delete_plugin(request: Request, plugin_id: int):
await PluginService.delete(plugin_id)
return {"code": 0, "msg": "ok"}
@router.put("/{plugin_id}", response_model=PluginOut)
@audit(
action=AuditAction.UPDATE,
description="更新插件",
body_fields=["url", "enabled"],
)
async def update_plugin(request: Request, plugin_id: int, payload: PluginCreate):
return await PluginService.update(plugin_id, payload)
@router.post("/{plugin_id}/metadata", response_model=PluginOut)
@audit(
action=AuditAction.UPDATE,
description="更新插件 manifest",
body_fields=[
"key",
"name",
"version",
"supported_exts",
"default_bounds",
"default_maximized",
"icon",
"description",
"author",
"website",
"github",
],
)
async def update_manifest(
request: Request, plugin_id: int, manifest: PluginManifestUpdate = Body(...)
):
return await PluginService.update_manifest(plugin_id, manifest)

domain/plugins/service.py Normal file

@@ -0,0 +1,48 @@
from fastapi import HTTPException
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from models.database import Plugin
class PluginService:
@classmethod
async def create(cls, payload: PluginCreate) -> PluginOut:
rec = await Plugin.create(**payload.model_dump())
return PluginOut.model_validate(rec)
@classmethod
async def list_plugins(cls) -> list[PluginOut]:
rows = await Plugin.all().order_by("-id")
return [PluginOut.model_validate(r) for r in rows]
@classmethod
async def _get_or_404(cls, plugin_id: int) -> Plugin:
rec = await Plugin.get_or_none(id=plugin_id)
if not rec:
raise HTTPException(status_code=404, detail="Plugin not found")
return rec
@classmethod
async def delete(cls, plugin_id: int) -> None:
rec = await cls._get_or_404(plugin_id)
await rec.delete()
@classmethod
async def update(cls, plugin_id: int, payload: PluginCreate) -> PluginOut:
rec = await cls._get_or_404(plugin_id)
rec.url = payload.url
rec.enabled = payload.enabled
await rec.save()
return PluginOut.model_validate(rec)
@classmethod
async def update_manifest(
cls, plugin_id: int, manifest: PluginManifestUpdate
) -> PluginOut:
rec = await cls._get_or_404(plugin_id)
updates = manifest.model_dump(exclude_none=True)
if updates:
for key, value in updates.items():
setattr(rec, key, value)
await rec.save()
return PluginOut.model_validate(rec)

domain/plugins/types.py Normal file

@@ -0,0 +1,52 @@
from typing import Any, Dict, List, Optional
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
class PluginCreate(BaseModel):
url: str = Field(min_length=1)
enabled: bool = True
class PluginManifestUpdate(BaseModel):
model_config = ConfigDict(populate_by_name=True, extra="ignore")
key: Optional[str] = None
name: Optional[str] = None
version: Optional[str] = None
supported_exts: Optional[List[str]] = Field(
default=None,
validation_alias=AliasChoices("supported_exts", "supportedExts"),
)
default_bounds: Optional[Dict[str, Any]] = Field(
default=None,
validation_alias=AliasChoices("default_bounds", "defaultBounds"),
)
default_maximized: Optional[bool] = Field(
default=None,
validation_alias=AliasChoices("default_maximized", "defaultMaximized"),
)
icon: Optional[str] = None
description: Optional[str] = None
author: Optional[str] = None
website: Optional[str] = None
github: Optional[str] = None
class PluginOut(BaseModel):
id: int
url: str
enabled: bool
key: Optional[str] = None
name: Optional[str] = None
version: Optional[str] = None
supported_exts: Optional[List[str]] = None
default_bounds: Optional[Dict[str, Any]] = None
default_maximized: Optional[bool] = None
icon: Optional[str] = None
description: Optional[str] = None
author: Optional[str] = None
website: Optional[str] = None
github: Optional[str] = None
model_config = ConfigDict(from_attributes=True)

domain/processors/api.py Normal file

@@ -0,0 +1,89 @@
from typing import Annotated
from fastapi import APIRouter, Body, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.processors.service import ProcessorService
from domain.processors.types import (
ProcessDirectoryRequest,
ProcessRequest,
UpdateSourceRequest,
)
router = APIRouter(prefix="/api/processors", tags=["processors"])
@router.get("")
@audit(action=AuditAction.READ, description="获取处理器列表")
async def list_processors(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = ProcessorService.list_processors()
return success(data)
@router.post("/process")
@audit(
action=AuditAction.CREATE,
description="处理单个文件",
body_fields=["path", "processor_type", "save_to", "overwrite"],
)
async def process_file_with_processor(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessRequest = Body(...),
):
data = await ProcessorService.process_file(req)
return success(data)
@router.post("/process-directory")
@audit(
action=AuditAction.CREATE,
description="批量处理目录",
body_fields=["path", "processor_type", "overwrite", "max_depth", "suffix"],
)
async def process_directory_with_processor(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessDirectoryRequest = Body(...),
):
data = await ProcessorService.process_directory(req)
return success(data)
@router.get("/source/{processor_type}")
@audit(action=AuditAction.READ, description="获取处理器源码")
async def get_processor_source(
request: Request,
processor_type: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = await ProcessorService.get_source(processor_type)
return success(data)
@router.put("/source/{processor_type}")
@audit(action=AuditAction.UPDATE, description="更新处理器源码")
async def update_processor_source(
request: Request,
processor_type: str,
req: UpdateSourceRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = await ProcessorService.update_source(processor_type, req)
return success(data)
@router.post("/reload")
@audit(action=AuditAction.UPDATE, description="重载处理器模块")
async def reload_processor_modules(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = ProcessorService.reload()
return success(data)


@@ -0,0 +1 @@
# Built-in processors package


@@ -1,9 +1,11 @@
from .base import BaseProcessor
from typing import Dict, Any
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from fastapi.responses import Response
from services.logging import LogService
from ..base import BaseProcessor
class ImageWatermarkProcessor:
name = "图片水印"
@@ -26,10 +28,11 @@ class ImageWatermarkProcessor:
]
produces_file = True
async def process(self, input_bytes: bytes,path: str, config: Dict[str, Any]) -> Response:
async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
text = config.get("text", "")
position = config.get("position", "bottom-right")
font_size = int(config.get("font_size", 24))
img = Image.open(BytesIO(input_bytes)).convert("RGBA")
watermark = Image.new("RGBA", img.size)
draw = ImageDraw.Draw(watermark)
@@ -37,29 +40,29 @@ class ImageWatermarkProcessor:
font = ImageFont.truetype("arial.ttf", font_size)
except Exception:
font = ImageFont.load_default()
w, h = img.size
try:
text_w, text_h = font.getsize(text)
except AttributeError:
bbox = draw.textbbox((0, 0), text, font=font)
text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
if position == "bottom-right":
xy = (w - text_w - 10, h - text_h - 10)
elif position == "top-left":
xy = (10, 10)
else:
xy = (w // 2 - text_w // 2, h // 2 - text_h // 2)
draw.text(xy, text, font=font, fill=(255, 255, 255, 128))
out = Image.alpha_composite(img, watermark)
buf = BytesIO()
out.convert("RGB").save(buf, format="JPEG")
await LogService.info(
"processor:image_watermark",
f"Watermarked image {path}",
details={"path": path, "config": config},
)
return Response(content=buf.getvalue(), media_type="image/jpeg")
PROCESSOR_TYPE = "image_watermark"
PROCESSOR_NAME = ImageWatermarkProcessor.name
SUPPORTED_EXTS = ImageWatermarkProcessor.supported_exts
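The watermark placement above reduces to a small coordinate calculation: anchor the text 10 px inside the chosen corner, or center it for any other position value. A standalone sketch of that math (function name and `margin` parameter are illustrative, not part of the processor):

```python
def watermark_xy(img_size, text_size, position, margin=10):
    w, h = img_size
    tw, th = text_size
    if position == "bottom-right":
        return (w - tw - margin, h - th - margin)
    if position == "top-left":
        return (margin, margin)
    # Anything else falls back to centering, as in the processor above.
    return (w // 2 - tw // 2, h // 2 - th // 2)

assert watermark_xy((800, 600), (100, 20), "bottom-right") == (690, 570)
assert watermark_xy((800, 600), (100, 20), "top-left") == (10, 10)
assert watermark_xy((800, 600), (100, 20), "center") == (350, 290)
```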


@@ -1,15 +1,15 @@
from typing import Dict, Any, List, Tuple
from fastapi.responses import Response
import base64
import mimetypes
import os
from io import BytesIO
from typing import Dict, Any, List, Tuple
from services.ai import describe_image_base64, get_text_embedding, provider_service
from services.vector_db import VectorDBService, DEFAULT_VECTOR_DIMENSION
from services.logging import LogService
from fastapi.responses import Response
from PIL import Image
from ..base import BaseProcessor
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
from domain.ai.service import VectorDBService, DEFAULT_VECTOR_DIMENSION
CHUNK_SIZE = 800
@@ -116,11 +116,6 @@ class VectorIndexProcessor:
if action == "destroy":
await vector_db.delete_vector(collection_name, path)
await LogService.info(
"processor:vector_index",
f"Destroyed {index_type} index for {path}",
details={"path": path, "action": "destroy", "index_type": index_type},
)
return Response(content=f"文件 {path} 的 {index_type} 索引已销毁", media_type="text/plain")
mime_type = _guess_mime(path)
@@ -136,11 +131,6 @@ class VectorIndexProcessor:
"type": "filename",
"name": os.path.basename(path),
})
await LogService.info(
"processor:vector_index",
f"Created simple index for {path}",
details={"path": path, "action": "create", "index_type": "simple"},
)
return Response(content=f"文件 {path} 的普通索引已创建", media_type="text/plain")
file_ext = path.split('.')[-1].lower()
@@ -177,11 +167,6 @@ class VectorIndexProcessor:
details["description"] = description
if compression:
details["image_compression"] = compression
await LogService.info(
"processor:vector_index",
f"Indexed image {path}",
details=details,
)
return Response(content=f"图片已索引,描述:{description}", media_type="text/plain")
if file_ext in ["txt", "md"]:
@@ -204,11 +189,6 @@ class VectorIndexProcessor:
"end_offset": len(text),
})
details["chunks"] = 1
await LogService.info(
"processor:vector_index",
f"Indexed text file {path}",
details=details,
)
return Response(content="文本文件已索引", media_type="text/plain")
chunk_count = 0
@@ -230,11 +210,6 @@ class VectorIndexProcessor:
details["chunks"] = chunk_count
sample = chunks[0][1]
details["sample"] = sample[:120]
await LogService.info(
"processor:vector_index",
f"Indexed text file {path}",
details=details,
)
return Response(content="文本文件已索引", media_type="text/plain")
# Other file types do not support vector indexing yet; fall back to a filename index
@@ -248,11 +223,6 @@ class VectorIndexProcessor:
"name": os.path.basename(path),
"embedding": [0.0] * vector_dim,
})
await LogService.info(
"processor:vector_index",
f"File type fallback to simple index for {path}",
details={"path": path, "action": "create", "index_type": "simple", "original_type": file_ext},
)
return Response(content="暂不支持该类型的向量索引,已创建文件名索引", media_type="text/plain")

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, Optional
from .base import BaseProcessor
from domain.processors.base import BaseProcessor
ProcessorFactory = Callable[[], BaseProcessor]
TYPE_MAP: Dict[str, ProcessorFactory] = {}
@@ -15,10 +15,9 @@ LAST_DISCOVERY_ERRORS: list[str] = []
def discover_processors(force_reload: bool = False) -> list[str]:
"""Discover available processor modules and cache their metadata."""
import services.processors  # deferred import to avoid a circular dependency
"""Scan and cache the available processor modules."""
from domain.processors import builtin as processors_pkg
processors_pkg = services.processors
TYPE_MAP.clear()
CONFIG_SCHEMAS.clear()
MODULE_MAP.clear()
@@ -51,8 +50,10 @@ def discover_processors(force_reload: bool = False) -> list[str]:
if factory is None:
for attr in module.__dict__.values():
if inspect.isclass(attr) and attr.__name__.endswith("Processor"):
def _mk(cls=attr):
return lambda: cls()
factory = _mk()
break
@@ -114,7 +115,7 @@ def get_config_schema(processor_type: str):
return CONFIG_SCHEMAS.get(processor_type)
def get(processor_type: str) -> BaseProcessor:
def get(processor_type: str) -> BaseProcessor | None:
factory = TYPE_MAP.get(processor_type)
if factory:
return factory()

View File

@@ -0,0 +1,217 @@
from pathlib import Path
from typing import List, Tuple
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from domain.processors.registry import (
get,
get_config_schema,
get_config_schemas,
get_module_path,
reload_processors,
)
from domain.processors.types import (
ProcessDirectoryRequest,
ProcessRequest,
UpdateSourceRequest,
)
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import task_queue_service
class ProcessorService:
@classmethod
def get_processor(cls, processor_type: str):
return get(processor_type)
@classmethod
def list_processors(cls):
schemas = get_config_schemas()
out = []
for t, meta in schemas.items():
out.append({
"type": meta["type"],
"name": meta["name"],
"supported_exts": meta.get("supported_exts", []),
"config_schema": meta["config_schema"],
"produces_file": meta.get("produces_file", False),
"module_path": meta.get("module_path"),
})
return out
@classmethod
async def process_file(cls, req: ProcessRequest):
is_dir = await VirtualFSService.path_is_directory(req.path)
if is_dir and not req.overwrite:
raise HTTPException(400, detail="Directory processing requires overwrite")
save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
task = await task_queue_service.add_task(
"process_file",
{
"path": req.path,
"processor_type": req.processor_type,
"config": req.config,
"save_to": save_to,
"overwrite": req.overwrite,
},
)
return {"task_id": task.id}
@classmethod
async def process_directory(cls, req: ProcessDirectoryRequest):
if req.max_depth is not None and req.max_depth < 0:
raise HTTPException(400, detail="max_depth must be >= 0")
is_dir = await VirtualFSService.path_is_directory(req.path)
if not is_dir:
raise HTTPException(400, detail="Path must be a directory")
schema = get_config_schema(req.processor_type)
_processor = get(req.processor_type)
if not schema or not _processor:
raise HTTPException(404, detail="Processor not found")
produces_file = bool(schema.get("produces_file"))
raw_suffix = req.suffix if req.suffix is not None else None
if raw_suffix is not None and raw_suffix.strip() == "":
raw_suffix = None
suffix = raw_suffix
overwrite = req.overwrite
if produces_file:
if not overwrite and not suffix:
raise HTTPException(400, detail="Suffix is required when not overwriting files")
else:
overwrite = False
suffix = None
supported_exts = schema.get("supported_exts") or []
allowed_exts = {
ext.lower().lstrip('.')
for ext in supported_exts
if isinstance(ext, str)
}
def matches_extension(file_rel: str) -> bool:
if not allowed_exts:
return True
if '.' not in file_rel:
return '' in allowed_exts
ext = file_rel.rsplit('.', 1)[-1].lower()
return ext in allowed_exts or f'.{ext}' in allowed_exts
adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(req.path)
rel = rel.rstrip('/')
list_dir = getattr(adapter_instance, "list_dir", None)
if not callable(list_dir):
raise HTTPException(501, detail="Adapter does not implement list_dir")
def build_absolute_path(mount_path: str, rel_path: str) -> str:
rel_norm = rel_path.lstrip('/')
mount_norm = mount_path.rstrip('/')
if not mount_norm:
return '/' + rel_norm if rel_norm else '/'
return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm
def apply_suffix(path_str: str, suffix_str: str) -> str:
path_obj = Path(path_str)
name = path_obj.name
if not name:
return path_str
if '.' in name:
base, ext = name.rsplit('.', 1)
new_name = f"{base}{suffix_str}.{ext}"
else:
new_name = f"{name}{suffix_str}"
return str(path_obj.with_name(new_name))
scheduled_tasks: List[str] = []
stack: List[Tuple[str, int]] = [(rel, 0)]
page_size = 200
while stack:
current_rel, depth = stack.pop()
page = 1
while True:
entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
entries = entries or []
if not entries and (total or 0) == 0:
break
for entry in entries:
name = entry.get("name")
if not name:
continue
child_rel = f"{current_rel}/{name}" if current_rel else name
if entry.get("is_dir"):
if req.max_depth is None or depth < req.max_depth:
stack.append((child_rel.rstrip('/'), depth + 1))
continue
if not matches_extension(child_rel):
continue
absolute_path = build_absolute_path(adapter_model.path, child_rel)
save_to = None
if produces_file and not overwrite and suffix:
save_to = apply_suffix(absolute_path, suffix)
task = await task_queue_service.add_task(
"process_file",
{
"path": absolute_path,
"processor_type": req.processor_type,
"config": req.config,
"save_to": save_to,
"overwrite": overwrite,
},
)
scheduled_tasks.append(task.id)
if total is None or page * page_size >= total:
break
page += 1
return {
"task_ids": scheduled_tasks,
"scheduled": len(scheduled_tasks),
}
@classmethod
async def get_source(cls, processor_type: str):
module_path = get_module_path(processor_type)
if not module_path:
raise HTTPException(404, detail="Processor not found")
path_obj = Path(module_path)
if not path_obj.exists():
raise HTTPException(404, detail="Processor source not found")
try:
content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
except Exception as exc:
raise HTTPException(500, detail=f"Failed to read source: {exc}")
return {"source": content, "module_path": str(path_obj)}
@classmethod
async def update_source(cls, processor_type: str, req: UpdateSourceRequest):
module_path = get_module_path(processor_type)
if not module_path:
raise HTTPException(404, detail="Processor not found")
path_obj = Path(module_path)
if not path_obj.exists():
raise HTTPException(404, detail="Processor source not found")
try:
await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
except Exception as exc:
raise HTTPException(500, detail=f"Failed to write source: {exc}")
return True
@classmethod
def reload(cls):
errors = reload_processors()
if errors:
raise HTTPException(500, detail="; ".join(errors))
return True
get_processor = ProcessorService.get_processor
list_processors = ProcessorService.list_processors
reload_processor_modules = ProcessorService.reload

View File

@@ -0,0 +1,24 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel
class ProcessRequest(BaseModel):
path: str
processor_type: str
config: Dict[str, Any]
save_to: Optional[str] = None
overwrite: bool = False
class ProcessDirectoryRequest(BaseModel):
path: str
processor_type: str
config: Dict[str, Any]
overwrite: bool = True
max_depth: Optional[int] = None
suffix: Optional[str] = None
class UpdateSourceRequest(BaseModel):
source: str

domain/share/api.py Normal file
View File

@@ -0,0 +1,129 @@
from typing import Annotated, List, Optional
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.share.service import ShareService
from domain.share.types import (
ShareCreate,
ShareInfo,
ShareInfoWithPassword,
SharePassword,
)
from models.database import UserAccount
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
@router.post("", response_model=ShareInfoWithPassword)
@audit(
action=AuditAction.SHARE,
description="创建分享链接",
body_fields=["name", "paths", "expires_in_days", "access_type"],
)
async def create_share(
request: Request,
payload: ShareCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
):
user_account = await UserAccount.get(id=current_user.id)
share = await ShareService.create_share_link(
user=user_account,
name=payload.name,
paths=payload.paths,
expires_in_days=payload.expires_in_days,
access_type=payload.access_type,
password=payload.password,
)
share_info = ShareInfo.from_orm(share).model_dump()
if payload.access_type == "password" and payload.password:
share_info["password"] = payload.password
return share_info
@router.get("", response_model=List[ShareInfo])
@audit(action=AuditAction.READ, description="获取我的分享列表")
async def get_my_shares(
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
):
user_account = await UserAccount.get(id=current_user.id)
shares = await ShareService.get_user_shares(user=user_account)
return [ShareInfo.from_orm(s) for s in shares]
@router.delete("/expired")
@audit(action=AuditAction.DELETE, description="删除已过期分享")
async def delete_expired_shares(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
user_account = await UserAccount.get(id=current_user.id)
deleted_count = await ShareService.delete_expired_shares(user=user_account)
return success({"deleted_count": deleted_count})
@router.delete("/{share_id}")
@audit(action=AuditAction.DELETE, description="删除分享链接")
async def delete_share(
share_id: int,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
user_account = await UserAccount.get(id=current_user.id)
await ShareService.delete_share_link(user=user_account, share_id=share_id)
return success(msg="分享已取消")
@public_router.post("/{token}/verify")
@audit(
action=AuditAction.SHARE,
description="校验分享密码",
body_fields=["password"],
redact_fields=["password"],
)
async def verify_password(request: Request, token: str, payload: SharePassword):
await ShareService.verify_share_password(token, payload.password)
return success(msg="验证成功")
@public_router.get("/{token}/ls")
@audit(action=AuditAction.SHARE, description="浏览分享内容")
async def list_share_content(
request: Request, token: str, path: str = "/", password: Optional[str] = None
):
share = await ShareService.ensure_share_access(token, password)
content = await ShareService.get_shared_item_details(share, path)
return success(
{
"path": path,
"entries": content.get("items", []),
"pagination": {
"total": content.get("total", 0),
"page": content.get("page", 1),
"page_size": content.get("page_size", 1),
"pages": content.get("pages", 1),
},
}
)
@public_router.get("/{token}")
@audit(action=AuditAction.SHARE, description="获取分享信息")
async def get_share_info(request: Request, token: str):
share = await ShareService.get_share_by_token(token)
return success(ShareInfo.from_orm(share))
@public_router.get("/{token}/download")
@audit(action=AuditAction.DOWNLOAD, description="下载分享文件")
async def download_shared_file(
token: str,
path: str,
request: Request,
password: Optional[str] = None,
):
return await ShareService.stream_shared_file(token, path, request.headers.get("Range"), password)

domain/share/service.py Normal file
View File

@@ -0,0 +1,187 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from urllib.parse import quote
import bcrypt
from fastapi import HTTPException, status
from fastapi.responses import Response
from domain.virtual_fs.service import VirtualFSService
from models.database import ShareLink, UserAccount
class ShareService:
@classmethod
def _hash_password(cls, password: str) -> str:
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
@classmethod
def _verify_password(cls, plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
@classmethod
def _calc_expires_at(cls, expires_in_days: Optional[int]) -> Optional[datetime]:
if expires_in_days is None or expires_in_days <= 0:
return None
return datetime.now(timezone.utc) + timedelta(days=expires_in_days)
@classmethod
def _ensure_password_if_needed(cls, share: ShareLink, password: Optional[str]) -> None:
if share.access_type != "password":
return
if not password:
raise HTTPException(status_code=401, detail="需要密码")
if not share.hashed_password:
raise HTTPException(status_code=403, detail="密码错误")
if not cls._verify_password(password, share.hashed_password):
raise HTTPException(status_code=403, detail="密码错误")
@classmethod
async def create_share_link(
cls,
user: UserAccount,
name: str,
paths: List[str],
expires_in_days: Optional[int] = 7,
access_type: str = "public",
password: Optional[str] = None,
) -> ShareLink:
if not paths:
raise HTTPException(status_code=400, detail="分享路径不能为空")
if access_type == "password" and not password:
raise HTTPException(status_code=400, detail="密码不能为空")
token = secrets.token_urlsafe(16)
expires_at = cls._calc_expires_at(expires_in_days)
hashed_password = None
if access_type == "password" and password:
hashed_password = cls._hash_password(password)
share = await ShareLink.create(
token=token,
name=name,
paths=paths,
user=user,
expires_at=expires_at,
access_type=access_type,
hashed_password=hashed_password,
)
return share
@classmethod
async def get_share_by_token(cls, token: str) -> ShareLink:
share = await ShareLink.get_or_none(token=token).prefetch_related("user")
if not share:
raise HTTPException(status_code=404, detail="分享链接不存在")
if share.expires_at and share.expires_at < datetime.now(timezone.utc):
raise HTTPException(status_code=410, detail="分享链接已过期")
return share
@classmethod
async def verify_share_password(cls, token: str, password: str) -> ShareLink:
share = await cls.get_share_by_token(token)
if share.access_type != "password":
raise HTTPException(status_code=400, detail="此分享不需要密码")
cls._ensure_password_if_needed(share, password)
return share
@classmethod
async def ensure_share_access(cls, token: str, password: Optional[str]) -> ShareLink:
share = await cls.get_share_by_token(token)
cls._ensure_password_if_needed(share, password)
return share
@classmethod
async def get_user_shares(cls, user: UserAccount) -> List[ShareLink]:
return await ShareLink.filter(user=user).order_by("-created_at")
@classmethod
async def delete_share_link(cls, user: UserAccount, share_id: int) -> None:
share = await ShareLink.get_or_none(id=share_id, user_id=user.id)
if not share:
raise HTTPException(status_code=404, detail="分享链接不存在")
await share.delete()
@classmethod
async def delete_expired_shares(cls, user: UserAccount) -> int:
now = datetime.now(timezone.utc)
deleted_count = await ShareLink.filter(user=user, expires_at__lte=now).delete()
return deleted_count
@classmethod
async def get_shared_item_details(cls, share: ShareLink, sub_path: str = ""):
if not share.paths:
raise HTTPException(status_code=404, detail="分享内容为空")
base_shared_path = share.paths[0]
if sub_path and sub_path != "/":
full_path = f"{base_shared_path.rstrip('/')}/{sub_path.lstrip('/')}".rstrip("/")
if not full_path.startswith(base_shared_path):
raise HTTPException(status_code=403, detail="无权访问此路径")
try:
return await VirtualFSService.list_virtual_dir(full_path)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="目录未找到")
try:
stat = await VirtualFSService.stat_file(base_shared_path)
if stat.get("is_dir"):
return await VirtualFSService.list_virtual_dir(base_shared_path)
stat["name"] = base_shared_path.split("/")[-1]
return {"items": [stat], "total": 1, "page": 1, "page_size": 1, "pages": 1}
except HTTPException as e:
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
return await VirtualFSService.list_virtual_dir(base_shared_path)
raise e
@classmethod
async def stream_shared_file(
cls,
token: str,
path: str,
range_header: str | None,
password: Optional[str] = None,
) -> Response:
if not path or path == "/" or ".." in path.split("/"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无效的文件路径")
share = await cls.ensure_share_access(token, password)
if not share.paths:
raise HTTPException(status_code=404, detail="分享的源文件不存在")
base_shared_path = share.paths[0]
is_dir = False
try:
stat = await VirtualFSService.stat_file(base_shared_path)
if stat and stat.get("is_dir"):
is_dir = True
except HTTPException as e:
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
is_dir = True
elif e.status_code == 404:
raise HTTPException(status_code=404, detail="分享的源文件不存在")
else:
raise
if is_dir:
full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
if not full_virtual_path.startswith(base_shared_path):
raise HTTPException(status_code=403, detail="无权访问此路径")
else:
shared_filename = base_shared_path.split("/")[-1]
request_filename = path.lstrip("/")
if shared_filename != request_filename:
raise HTTPException(status_code=403, detail="无权访问此路径")
full_virtual_path = base_shared_path
response = await VirtualFSService.stream_file(full_virtual_path, range_header)
filename = full_virtual_path.split("/")[-1]
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
return response

domain/share/types.py Normal file
View File

@@ -0,0 +1,43 @@
from typing import List, Optional
from pydantic import BaseModel
from models.database import ShareLink
class ShareCreate(BaseModel):
name: str
paths: List[str]
expires_in_days: Optional[int] = 7
access_type: str = "public"
password: Optional[str] = None
class SharePassword(BaseModel):
password: str
class ShareInfo(BaseModel):
id: int
token: str
name: str
paths: List[str]
created_at: str
expires_at: Optional[str] = None
access_type: str
@classmethod
def from_orm(cls, obj: ShareLink):
return cls(
id=obj.id,
token=obj.token,
name=obj.name,
paths=obj.paths,
created_at=obj.created_at.isoformat(),
expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
access_type=obj.access_type,
)
class ShareInfoWithPassword(ShareInfo):
password: Optional[str] = None

domain/tasks/api.py Normal file
View File

@@ -0,0 +1,112 @@
from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.tasks.service import TaskService
from domain.tasks.types import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
)
CurrentUser = TaskService.current_user_dep
router = APIRouter(
prefix="/api/tasks",
tags=["Tasks"],
dependencies=[Depends(get_current_active_user)],
responses={404: {"description": "Not found"}},
)
@router.get("/queue")
@audit(action=AuditAction.READ, description="获取任务队列状态")
async def get_task_queue_status(request: Request, current_user: CurrentUser):
payload = TaskService.get_queue_tasks()
return success(payload)
@router.get("/queue/settings")
@audit(action=AuditAction.READ, description="获取任务队列设置")
async def get_task_queue_settings(request: Request, current_user: CurrentUser):
payload = TaskService.get_queue_settings()
return success(payload.model_dump())
@router.post("/queue/settings")
@audit(
action=AuditAction.UPDATE,
description="更新任务队列设置",
body_fields=["concurrency"],
)
async def update_task_queue_settings(request: Request, settings: TaskQueueSettings, current_user: CurrentUser):
payload = await TaskService.update_queue_settings(settings, getattr(current_user, "id", None))
return success(payload.model_dump())
@router.get("/queue/{task_id}")
@audit(action=AuditAction.READ, description="获取队列任务状态")
async def get_task_status(task_id: str, request: Request, current_user: CurrentUser):
payload = TaskService.get_queue_task(task_id)
return success(payload)
@router.post("/")
@audit(
action=AuditAction.CREATE,
description="创建自动化任务",
body_fields=[
"name",
"event",
"path_pattern",
"filename_regex",
"processor_type",
"processor_config",
"enabled",
],
user_kw="user",
)
async def create_task(request: Request, task_in: AutomationTaskCreate, user: CurrentUser):
task = await TaskService.create_task(task_in, user)
return success(task)
@router.get("/{task_id}")
@audit(action=AuditAction.READ, description="获取自动化任务详情")
async def get_task(task_id: int, request: Request, current_user: CurrentUser):
task = await TaskService.get_task(task_id)
return success(task)
@router.get("/")
@audit(action=AuditAction.READ, description="获取自动化任务列表")
async def list_tasks(request: Request, current_user: CurrentUser):
tasks = await TaskService.list_tasks()
return success(tasks)
@router.put("/{task_id}")
@audit(
action=AuditAction.UPDATE,
description="更新自动化任务",
body_fields=[
"name",
"event",
"path_pattern",
"filename_regex",
"processor_type",
"processor_config",
"enabled",
],
)
async def update_task(request: Request, current_user: CurrentUser, task_id: int, task_in: AutomationTaskUpdate):
task = await TaskService.update_task(task_id, task_in, current_user)
return success(task)
@router.delete("/{task_id}")
@audit(action=AuditAction.DELETE, description="删除自动化任务", user_kw="user")
async def delete_task(task_id: int, request: Request, user: CurrentUser):
await TaskService.delete_task(task_id, user)
return success(msg="Task deleted")

domain/tasks/service.py Normal file
View File

@@ -0,0 +1,109 @@
import re
from typing import Annotated, Any, Dict, Optional
from fastapi import Depends, HTTPException
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.tasks.types import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
TaskQueueSettingsResponse,
)
from models.database import AutomationTask
from domain.tasks.task_queue import task_queue_service
class TaskService:
current_user_dep = Annotated[User, Depends(get_current_active_user)]
@classmethod
def get_queue_tasks(cls) -> list[dict[str, Any]]:
tasks = task_queue_service.get_all_tasks()
return [task.dict() for task in tasks]
@classmethod
def get_queue_settings(cls) -> TaskQueueSettingsResponse:
return TaskQueueSettingsResponse(
concurrency=task_queue_service.get_concurrency(),
active_workers=task_queue_service.get_active_worker_count(),
)
@classmethod
async def update_queue_settings(cls, settings: TaskQueueSettings, user_id: Optional[int]) -> TaskQueueSettingsResponse:
await task_queue_service.set_concurrency(settings.concurrency)
await ConfigService.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
return cls.get_queue_settings()
@classmethod
def get_queue_task(cls, task_id: str) -> dict[str, Any]:
task = task_queue_service.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task.dict()
@classmethod
async def create_task(cls, payload: AutomationTaskCreate, user: Optional[User]) -> AutomationTask:
task = await AutomationTask.create(**payload.model_dump())
return task
@classmethod
async def get_task(cls, task_id: int) -> AutomationTask:
task = await AutomationTask.get_or_none(id=task_id)
if not task:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
return task
@classmethod
async def list_tasks(cls) -> list[AutomationTask]:
tasks = await AutomationTask.all()
return tasks
@classmethod
async def update_task(cls, task_id: int, payload: AutomationTaskUpdate, current_user: User) -> AutomationTask:
task = await AutomationTask.get_or_none(id=task_id)
if not task:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
update_data = payload.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(task, key, value)
await task.save()
return task
@classmethod
async def delete_task(cls, task_id: int, user: Optional[User]) -> None:
deleted_count = await AutomationTask.filter(id=task_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
@classmethod
async def trigger_tasks(cls, event: str, path: str):
tasks = await AutomationTask.filter(event=event, enabled=True)
for task in tasks:
if cls.match(task, path):
await cls.execute(task, path)
@classmethod
def match(cls, task: AutomationTask, path: str) -> bool:
if task.path_pattern and not path.startswith(task.path_pattern):
return False
if task.filename_regex:
filename = path.split("/")[-1]
if not re.match(task.filename_regex, filename):
return False
return True
@classmethod
async def execute(cls, task: AutomationTask, path: str):
await task_queue_service.add_task(
task.processor_type,
{
"task_id": task.id,
"path": path,
},
)
task_service = TaskService

View File

@@ -2,7 +2,6 @@ import asyncio
from typing import Dict, Any
from pydantic import BaseModel, Field
import uuid
from services.logging import LogService
from enum import Enum
@@ -47,7 +46,6 @@ class TaskQueueService:
task = Task(name=name, task_info=task_info)
self._tasks[task.id] = task
await self._queue.put(task)
await LogService.info("task_queue", f"Task {name} ({task.id}) enqueued", {"task_id": task.id, "name": name})
return task
def get_task(self, task_id: str) -> Task | None:
@@ -72,15 +70,15 @@ class TaskQueueService:
task.meta = (task.meta or {}) | meta
async def _execute_task(self, task: Task):
from services.virtual_fs import process_file
task.status = TaskStatus.RUNNING
await LogService.info("task_queue", f"Task {task.name} ({task.id}) started", {"task_id": task.id, "name": task.name})
try:
# Local import to avoid circular dependency during module load.
from domain.virtual_fs.service import VirtualFSService
if task.name == "process_file":
params = task.task_info
result = await process_file(
result = await VirtualFSService.process_file(
path=params["path"],
processor_type=params["processor_type"],
config=params["config"],
@@ -90,8 +88,7 @@ class TaskQueueService:
task.result = result
elif task.name == "automation_task" or self._is_processor_task(task.name):
from models.database import AutomationTask
from services.processors.registry import get as get_processor
from services.virtual_fs import read_file, write_file
from domain.processors.service import get_processor
params = task.task_info
auto_task = await AutomationTask.get(id=params["task_id"])
@@ -103,54 +100,45 @@ class TaskQueueService:
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
if processor_type != auto_task.processor_type:
await LogService.warning(
"task_queue",
"Processor type mismatch; falling back to stored type",
{"task_id": auto_task.id, "expected": auto_task.processor_type, "got": processor_type},
)
processor_type = auto_task.processor_type
processor = get_processor(processor_type)
if not processor:
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
file_content = await read_file(path)
file_content = await VirtualFSService.read_file(path)
result = await processor.process(file_content, path, auto_task.processor_config)
save_to = auto_task.processor_config.get("save_to")
if save_to and getattr(processor, "produces_file", False):
await write_file(save_to, result)
await VirtualFSService.write_file(save_to, result)
task.result = "Automation task completed"
elif task.name == "offline_http_download":
from services.offline_download import run_http_download
from domain.offline_downloads.service import OfflineDownloadService
result_path = await run_http_download(task)
result_path = await OfflineDownloadService.run_http_download(task)
task.result = {"path": result_path}
elif task.name == "cross_mount_transfer":
from services.virtual_fs import run_cross_mount_transfer_task
result = await run_cross_mount_transfer_task(task)
result = await VirtualFSService.run_cross_mount_transfer_task(task)
task.result = result
elif task.name == "send_email":
from services.email import EmailService
from domain.email.service import EmailService
await EmailService.send_from_task(task.id, task.task_info)
task.result = "Email sent"
else:
raise ValueError(f"Unknown task name: {task.name}")
task.status = TaskStatus.SUCCESS
await LogService.info("task_queue", f"Task {task.name} ({task.id}) succeeded", {"task_id": task.id, "name": task.name})
except Exception as e:
task.status = TaskStatus.FAILED
task.error = str(e)
await LogService.error("task_queue", f"Task {task.name} ({task.id}) failed: {e}", {"task_id": task.id, "name": task.name})
def _cleanup_workers(self):
self._worker_tasks = [task for task in self._worker_tasks if not task.done()]
def _is_processor_task(self, task_name: str) -> bool:
try:
from services.processors.registry import get as get_processor
from domain.processors.service import get_processor
return get_processor(task_name) is not None
except Exception:
@@ -165,15 +153,12 @@ class TaskQueueService:
worker_id = self._worker_seq
worker_task = asyncio.create_task(self._worker_loop(worker_id))
self._worker_tasks.append(worker_task)
await LogService.info("task_queue", "Task workers adjusted", {"active_workers": len(self._worker_tasks), "target": self._concurrency})
elif current > self._concurrency:
for _ in range(current - self._concurrency):
await self._queue.put(_SENTINEL)
await LogService.info("task_queue", "Task workers scaling down", {"active_workers": len(self._worker_tasks), "target": self._concurrency})
async def _worker_loop(self, worker_id: int):
current_task = asyncio.current_task()
await LogService.info("task_queue", f"Worker {worker_id} started")
try:
while True:
job = await self._queue.get()
@@ -183,23 +168,18 @@ class TaskQueueService:
try:
await self._execute_task(job)
except Exception as e:
await LogService.error(
"task_queue",
f"Error executing task {job.id}: {e}",
{"task_id": job.id, "name": job.name},
)
pass
finally:
self._queue.task_done()
finally:
if current_task in self._worker_tasks:
self._worker_tasks.remove(current_task) # type: ignore[arg-type]
await LogService.info("task_queue", f"Worker {worker_id} stopped")
async def start_worker(self, concurrency: int | None = None):
if concurrency is None:
from services.config import ConfigCenter
from domain.config.service import ConfigService
stored_value = await ConfigCenter.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
try:
concurrency = int(stored_value)
except (TypeError, ValueError):
@@ -219,7 +199,6 @@ class TaskQueueService:
if self._worker_tasks:
await asyncio.gather(*self._worker_tasks, return_exceptions=True)
self._worker_tasks.clear()
await LogService.info("task_queue", "Task workers have been stopped.")
def get_concurrency(self) -> int:
return self._concurrency
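The scale-down branch above enqueues sentinel objects so that surplus workers drain the queue and exit on their own rather than being cancelled mid-task. A self-contained sketch of that pattern (simplified stand-ins, not the project's actual classes):

```python
import asyncio

_SENTINEL = object()  # unique marker; a worker that dequeues it exits

async def worker(queue: asyncio.Queue, results: list) -> None:
    while True:
        job = await queue.get()
        try:
            if job is _SENTINEL:
                return  # scale-down signal: leave the loop cleanly
            results.append(job * 2)  # stand-in for real task execution
        finally:
            queue.task_done()

async def main() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    results: list = []
    workers = [asyncio.create_task(worker(queue, results)) for _ in range(3)]
    for job in (1, 2, 3, 4):
        await queue.put(job)
    for _ in range(3):  # scale down to zero: one sentinel per worker
        await queue.put(_SENTINEL)
    await queue.join()              # all jobs and sentinels consumed
    await asyncio.gather(*workers)  # every worker has returned
    return sorted(results)

out = asyncio.run(main())
```

Because the queue is FIFO, sentinels enqueued after real jobs are only seen once the jobs are drained, so no work is lost during a scale-down.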


@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
class AutomationTaskBase(BaseModel):

domain/virtual_fs/api.py Normal file

@@ -0,0 +1,189 @@
from typing import Annotated
from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.types import MkdirRequest, MoveRequest
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="Get file")
async def get_file(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
return await VirtualFSService.serve_file(full_path, request.headers.get("Range"))
@router.get("/thumb/{full_path:path}")
@audit(action=AuditAction.READ, description="Get thumbnail")
async def get_thumb(
full_path: str,
request: Request,
w: int = Query(256, ge=8, le=1024),
h: int = Query(256, ge=8, le=1024),
fit: str = Query("cover"),
):
return await VirtualFSService.get_thumbnail(full_path, w, h, fit)
@router.get("/stream/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="Stream file")
async def stream_endpoint(
full_path: str,
request: Request,
):
return await VirtualFSService.stream_response(full_path, request.headers.get("Range"))
@router.get("/temp-link/{full_path:path}")
@audit(action=AuditAction.SHARE, description="Create temporary link")
async def get_temp_link(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
expires_in: int = Query(3600, description="Validity in seconds; 0 or a negative value means permanent"),
):
data = await VirtualFSService.create_temp_link(full_path, expires_in)
return success(data)
@router.get("/public/{token}")
@audit(action=AuditAction.DOWNLOAD, description="Access file via temporary link")
async def access_public_file(
token: str,
request: Request,
):
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
@router.get("/stat/{full_path:path}")
@audit(action=AuditAction.READ, description="View file info")
async def get_file_stat(
full_path: str,
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
stat = await VirtualFSService.stat(full_path)
return success(stat)
@router.post("/file/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="Upload file")
async def put_file(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
file: UploadFile = File(...),
):
data = await file.read()
result = await VirtualFSService.write_uploaded_file(full_path, data)
return success(result)
@router.post("/mkdir")
@audit(action=AuditAction.CREATE, description="Create directory", body_fields=["path"])
async def api_mkdir(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MkdirRequest,
):
result = await VirtualFSService.mkdir(body.path)
return success(result)
@router.post("/move")
@audit(action=AuditAction.UPDATE, description="Move path", body_fields=["src", "dst"])
async def api_move(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="Whether to overwrite an existing destination"),
):
result = await VirtualFSService.move(body.src, body.dst, overwrite)
return success(result)
@router.post("/rename")
@audit(action=AuditAction.UPDATE, description="Rename path", body_fields=["src", "dst"])
async def api_rename(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="Whether to overwrite an existing destination"),
):
result = await VirtualFSService.rename(body.src, body.dst, overwrite)
return success(result)
@router.post("/copy")
@audit(action=AuditAction.CREATE, description="Copy path", body_fields=["src", "dst"])
async def api_copy(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
body: MoveRequest,
overwrite: bool = Query(False, description="Whether to overwrite an existing destination"),
):
result = await VirtualFSService.copy(body.src, body.dst, overwrite)
return success(result)
@router.post("/upload/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="Stream-upload file")
async def upload_stream(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
file: UploadFile = File(...),
overwrite: bool = Query(True, description="Whether to overwrite an existing file"),
chunk_size: int = Query(1024 * 1024, ge=8 * 1024, le=8 * 1024 * 1024, description="Chunk size per read"),
):
result = await VirtualFSService.upload_stream_from_upload_file(full_path, file, chunk_size, overwrite)
return success(result)
@router.get("/{full_path:path}")
@audit(action=AuditAction.READ, description="Browse directory")
async def browse_fs(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
page_num: int = Query(1, alias="page", ge=1, description="Page number"),
page_size: int = Query(50, ge=1, le=500, description="Items per page"),
sort_by: str = Query("name", description="Sort field: name, size, mtime"),
sort_order: str = Query("asc", description="Sort order: asc, desc"),
):
data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
return success(data)
@router.delete("/{full_path:path}")
@audit(action=AuditAction.DELETE, description="Delete path")
async def api_delete(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
full_path: str,
):
result = await VirtualFSService.delete(full_path)
return success(result)
@router.get("/")
@audit(action=AuditAction.READ, description="Browse root directory")
async def root_listing(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
page_num: int = Query(1, alias="page", ge=1, description="Page number"),
page_size: int = Query(50, ge=1, le=500, description="Items per page"),
sort_by: str = Query("name", description="Sort field: name, size, mtime"),
sort_order: str = Query("asc", description="Sort order: asc, desc"),
):
data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
return success(data)


@@ -0,0 +1,44 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
from fastapi import HTTPException
class VirtualFSCommonMixin:
CROSS_TRANSFER_TEMP_ROOT = Path("data/tmp/cross_transfer")
DIRECT_REDIRECT_CONFIG_KEY = "enable_direct_download_307"
@staticmethod
def _normalize_path(path: str) -> str:
return path if path.startswith("/") else f"/{path}"
@staticmethod
def _build_absolute_path(mount_path: str, rel_path: str) -> str:
rel_norm = rel_path.lstrip("/")
mount_norm = mount_path.rstrip("/")
if not mount_norm:
return "/" + rel_norm if rel_norm else "/"
return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm
@staticmethod
def _join_rel(base: str, name: str) -> str:
if not base:
return name.lstrip("/")
if not name:
return base
return f"{base.rstrip('/')}/{name.lstrip('/')}"
@staticmethod
def _parent_rel(rel: str) -> str:
if not rel or "/" not in rel:
return ""
return rel.rsplit("/", 1)[0]
@staticmethod
async def _ensure_method(adapter: Any, method: str):
func = getattr(adapter, method, None)
if not callable(func):
raise HTTPException(501, detail=f"Adapter does not implement {method}")
return func
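The mixin's path helpers are pure string functions and can be exercised standalone. A sketch with simplified copies of their logic (local names, not imports from the project):

```python
def normalize_path(path: str) -> str:
    # ensure a leading slash, mirroring _normalize_path
    return path if path.startswith("/") else f"/{path}"

def build_absolute_path(mount_path: str, rel_path: str) -> str:
    # join a mount point and an adapter-relative path, mirroring _build_absolute_path
    rel_norm = rel_path.lstrip("/")
    mount_norm = mount_path.rstrip("/")
    if not mount_norm:
        return "/" + rel_norm if rel_norm else "/"
    return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm

def parent_rel(rel: str) -> str:
    # parent of a relative path; "" for top-level entries, mirroring _parent_rel
    if not rel or "/" not in rel:
        return ""
    return rel.rsplit("/", 1)[0]
```

Note the edge cases: an empty mount resolves to the virtual root "/", and a top-level relative path has the empty string as its parent.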


@@ -0,0 +1,126 @@
from __future__ import annotations
import mimetypes
from typing import Any, AsyncIterator, Union
from fastapi import HTTPException
from fastapi.responses import Response
from domain.tasks.service import TaskService
from domain.virtual_fs.thumbnail import is_raw_filename
from .listing import VirtualFSListingMixin
class VirtualFSFileOpsMixin(VirtualFSListingMixin):
@classmethod
async def read_file(cls, path: str) -> Union[bytes, Any]:
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
if rel.endswith("/") or rel == "":
raise HTTPException(400, detail="Path is a directory")
read_func = await cls._ensure_method(adapter_instance, "read_file")
return await read_func(root, rel)
@classmethod
async def write_file(cls, path: str, data: bytes):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
if rel.endswith("/"):
raise HTTPException(400, detail="Invalid file path")
write_func = await cls._ensure_method(adapter_instance, "write_file")
await write_func(root, rel, data)
await TaskService.trigger_tasks("file_written", path)
@classmethod
async def write_file_stream(cls, path: str, data_iter: AsyncIterator[bytes], overwrite: bool = True):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
if rel.endswith("/"):
raise HTTPException(400, detail="Invalid file path")
exists_func = getattr(adapter_instance, "exists", None)
if not overwrite and callable(exists_func):
try:
if await exists_func(root, rel):
raise HTTPException(409, detail="Destination exists")
except HTTPException:
raise
except Exception:
pass
size = 0
stream_func = getattr(adapter_instance, "write_file_stream", None)
if callable(stream_func):
size = await stream_func(root, rel, data_iter)
else:
buf = bytearray()
async for chunk in data_iter:
if chunk:
buf.extend(chunk)
write_func = await cls._ensure_method(adapter_instance, "write_file")
await write_func(root, rel, bytes(buf))
size = len(buf)
await TaskService.trigger_tasks("file_written", path)
return size
@classmethod
async def make_dir(cls, path: str):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
if not rel:
return
mkdir_func = await cls._ensure_method(adapter_instance, "mkdir")
await mkdir_func(root, rel)
@classmethod
async def delete_path(cls, path: str):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
if not rel:
raise HTTPException(400, detail="Cannot delete root")
delete_func = await cls._ensure_method(adapter_instance, "delete")
await delete_func(root, rel)
await TaskService.trigger_tasks("file_deleted", path)
@classmethod
async def stream_file(cls, path: str, range_header: str | None):
adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(path)
if not rel or rel.endswith("/"):
raise HTTPException(400, detail="Path is a directory")
if is_raw_filename(rel):
import io
import rawpy
from PIL import Image
try:
raw_data = await cls.read_file(path)
try:
with rawpy.imread(io.BytesIO(raw_data)) as raw:
try:
thumb = raw.extract_thumb()
except rawpy.LibRawNoThumbnailError:
thumb = None
if thumb is not None and thumb.format in [rawpy.ThumbFormat.JPEG, rawpy.ThumbFormat.BITMAP]:
im = Image.open(io.BytesIO(thumb.data))
else:
rgb = raw.postprocess(use_camera_wb=False, use_auto_wb=True, output_bps=8)
im = Image.fromarray(rgb)
except Exception as exc:
print(f"rawpy processing failed: {exc}")
raise exc
buf = io.BytesIO()
im.save(buf, "JPEG", quality=90)
content = buf.getvalue()
return Response(content=content, media_type="image/jpeg")
except Exception as exc:
raise HTTPException(500, detail=f"RAW file processing failed: {exc}")
redirect_response = await cls.maybe_redirect_download(adapter_instance, adapter_model, root, rel)
if redirect_response is not None:
return redirect_response
stream_impl = getattr(adapter_instance, "stream_file", None)
if callable(stream_impl):
return await stream_impl(root, rel, range_header)
data = await cls.read_file(path)
mime, _ = mimetypes.guess_type(rel)
return Response(content=data, media_type=mime or "application/octet-stream")


@@ -0,0 +1,237 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from fastapi import HTTPException
from api.response import page
from domain.adapters.registry import runtime_registry
from domain.ai.service import VectorDBService
from domain.virtual_fs.thumbnail import is_image_filename, is_video_filename
from models import StorageAdapter
from .resolver import VirtualFSResolverMixin
class VirtualFSListingMixin(VirtualFSResolverMixin):
@classmethod
async def path_is_directory(cls, path: str) -> bool:
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
rel = rel.rstrip("/")
if rel == "":
return True
stat_func = getattr(adapter_instance, "stat_file", None)
if not callable(stat_func):
raise HTTPException(501, detail="Adapter does not implement stat_file")
try:
info = await stat_func(root, rel)
except FileNotFoundError:
raise HTTPException(404, detail="Path not found")
if isinstance(info, dict):
return bool(info.get("is_dir"))
return False
@classmethod
async def list_virtual_dir(
cls,
path: str,
page_num: int = 1,
page_size: int = 50,
sort_by: str = "name",
sort_order: str = "asc",
) -> Dict:
norm = cls._normalize_path(path).rstrip("/") or "/"
adapters = await StorageAdapter.filter(enabled=True)
child_mount_entries: List[str] = []
norm_prefix = norm.rstrip("/")
for adapter in adapters:
if adapter.path == norm:
continue
if adapter.path.startswith(norm_prefix + "/"):
tail = adapter.path[len(norm_prefix) :].lstrip("/")
if "/" not in tail:
child_mount_entries.append(tail)
child_mount_entries = sorted(set(child_mount_entries))
sort_field = sort_by.lower()
reverse = sort_order.lower() == "desc"
def build_sort_key(item: Dict) -> Tuple:
key = (not bool(item.get("is_dir")),)
if sort_field == "name":
key += (str(item.get("name", "")).lower(),)
elif sort_field == "size":
key += (int(item.get("size", 0)),)
elif sort_field == "mtime":
key += (int(item.get("mtime", 0)),)
else:
key += (str(item.get("name", "")).lower(),)
return key
def annotate_entry(entry: Dict) -> None:
if not entry.get("is_dir"):
name = entry.get("name", "")
entry["has_thumbnail"] = bool(is_image_filename(name) or is_video_filename(name))
else:
entry["has_thumbnail"] = False
try:
adapter_model, rel = await cls.resolve_adapter_by_path(norm)
adapter_instance = runtime_registry.get(adapter_model.id)
if not adapter_instance:
await runtime_registry.refresh()
adapter_instance = runtime_registry.get(adapter_model.id)
if adapter_instance:
effective_root = adapter_instance.get_effective_root(adapter_model.sub_path)
else:
adapter_model = None
effective_root = ""
rel = ""
except HTTPException:
adapter_model = None
adapter_instance = None
effective_root = ""
rel = ""
adapter_entries_for_merge: List[Dict] = []
adapter_entries_page: List[Dict] | None = None
adapter_total: int | None = None
if adapter_model and adapter_instance:
list_dir = getattr(adapter_instance, "list_dir", None)
if callable(list_dir):
adapter_entries_page, adapter_total = await list_dir(
effective_root, rel, page_num, page_size, sort_by, sort_order
)
if rel:
parent_rel = cls._parent_rel(rel)
if rel:
stat_file = getattr(adapter_instance, "stat_file", None)
if callable(stat_file):
try:
parent_info = await stat_file(effective_root, rel)
if isinstance(parent_info, dict):
parent_info.setdefault("name", rel.split("/")[-1])
parent_info["is_dir"] = bool(parent_info.get("is_dir", True))
adapter_entries_for_merge.append(parent_info)
except Exception:
pass
if parent_rel:
stat_file = getattr(adapter_instance, "stat_file", None)
if callable(stat_file):
try:
parent_info = await stat_file(effective_root, parent_rel)
if isinstance(parent_info, dict):
parent_info.setdefault("name", parent_rel.split("/")[-1])
parent_info["is_dir"] = bool(parent_info.get("is_dir", True))
adapter_entries_for_merge.append(parent_info)
except Exception:
pass
if adapter_entries_page:
adapter_entries_for_merge.extend(adapter_entries_page)
covered = set()
if adapter_entries_for_merge:
for item in adapter_entries_for_merge:
covered.add(item["name"])
mount_entries = []
for name in child_mount_entries:
if name not in covered:
mount_entries.append(
{"name": name, "is_dir": True, "size": 0, "mtime": 0, "type": "mount", "has_thumbnail": False}
)
if mount_entries:
for ent in adapter_entries_for_merge:
annotate_entry(ent)
combined_entries = adapter_entries_for_merge + [{**ent, "has_thumbnail": False} for ent in mount_entries]
combined_entries.sort(key=build_sort_key, reverse=reverse)
total_entries = len(combined_entries)
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
page_entries = combined_entries[start_idx:end_idx]
return page(page_entries, total_entries, page_num, page_size)
annotate_entry_list = adapter_entries_page or []
for ent in annotate_entry_list:
annotate_entry(ent)
return page(adapter_entries_page, adapter_total, page_num, page_size)
@classmethod
async def _gather_vector_index(cls, full_path: str, limit: int = 20):
vector_db = VectorDBService()
try:
raw_results = await vector_db.search_by_path("vector_collection", full_path, max(limit * 2, 20))
except Exception:
return None
matched = []
if raw_results:
buckets = raw_results if isinstance(raw_results, list) else [raw_results]
for bucket in buckets:
if not bucket:
continue
for record in bucket:
entity = dict((record or {}).get("entity") or {})
source_path = entity.get("source_path") or entity.get("path") or ""
if source_path != full_path:
continue
entry = {
"chunk_id": str(entity.get("chunk_id")) if entity.get("chunk_id") is not None else None,
"type": entity.get("type"),
"mime": entity.get("mime"),
"name": entity.get("name"),
"start_offset": entity.get("start_offset"),
"end_offset": entity.get("end_offset"),
"vector_id": entity.get("vector_id"),
}
text = entity.get("text") or entity.get("description")
if text:
preview_limit = 400
entry["preview"] = text[:preview_limit]
entry["preview_truncated"] = len(text) > preview_limit
matched.append(entry)
if not matched:
return {"total": 0, "entries": [], "by_type": {}, "has_more": False}
type_counts: Dict[str, int] = {}
for item in matched:
key = item.get("type") or "unknown"
type_counts[key] = type_counts.get(key, 0) + 1
has_more = len(matched) > limit
return {
"total": len(matched),
"entries": matched[:limit],
"by_type": type_counts,
"has_more": has_more,
"limit": limit,
}
@classmethod
async def stat_file(cls, path: str):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
stat_func = getattr(adapter_instance, "stat_file", None)
if not callable(stat_func):
raise HTTPException(501, detail="Adapter does not implement stat_file")
info = await stat_func(root, rel)
if isinstance(info, dict):
info.setdefault("path", path)
try:
is_dir = bool(info.get("is_dir"))
except Exception:
is_dir = False
rel_name = rel.rstrip("/").split("/")[-1] if rel else path.rstrip("/").split("/")[-1]
name_hint = str(info.get("name") or rel_name or "")
info["has_thumbnail"] = bool(not is_dir and (is_image_filename(name_hint) or is_video_filename(name_hint)))
if not is_dir:
vector_index = await cls._gather_vector_index(path)
if vector_index is not None:
info["vector_index"] = vector_index
return info
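The directories-first ordering produced by the nested `build_sort_key` above can be illustrated with a standalone sketch (a simplified copy, with hypothetical sample entries):

```python
def sort_key(item: dict, field: str = "name"):
    # directories sort before files, then by the requested field
    key = (not bool(item.get("is_dir")),)
    if field == "size":
        return key + (int(item.get("size", 0)),)
    if field == "mtime":
        return key + (int(item.get("mtime", 0)),)
    return key + (str(item.get("name", "")).lower(),)

entries = [
    {"name": "b.txt", "is_dir": False, "size": 2},
    {"name": "Docs", "is_dir": True},
    {"name": "a.txt", "is_dir": False, "size": 1},
]
ordered = [e["name"] for e in sorted(entries, key=sort_key)]
```

Lowercasing the name gives case-insensitive ordering, and the leading boolean keeps directories grouped ahead of files regardless of the sort field.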


@@ -0,0 +1,537 @@
from __future__ import annotations
import base64
import datetime as dt
import hashlib
import hmac
import uuid
from typing import Dict, Iterable, List, Optional, Tuple
from fastapi import APIRouter, Request, Response
from fastapi import HTTPException
from domain.config.service import ConfigService
from domain.virtual_fs.service import VirtualFSService
router = APIRouter(prefix="/s3", tags=["s3"])
FALSEY = {"0", "false", "off", "no"}
_XML_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3Settings(Dict[str, str]):
bucket: str
region: str
base_path: str
access_key: str
secret_key: str
def _now_iso() -> str:
return dt.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000Z")
def _etag(key: str, size: Optional[int], mtime: Optional[int]) -> str:
raw = f"{key}|{size or 0}|{mtime or 0}".encode("utf-8")
return '"' + hashlib.md5(raw).hexdigest() + '"'
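`_etag` derives a stable pseudo-ETag from the key, size, and mtime rather than hashing file contents, so listing stays cheap. A standalone sketch of the same construction (local name, not the project's function):

```python
import hashlib

def make_etag(key: str, size, mtime) -> str:
    # quoted MD5 over "key|size|mtime"; file content is never read
    raw = f"{key}|{size or 0}|{mtime or 0}".encode("utf-8")
    return '"' + hashlib.md5(raw).hexdigest() + '"'

tag = make_etag("docs/a.txt", 1024, 1700000000)
```

The tag is deterministic, treats missing size/mtime as 0, and changes whenever size or mtime changes, which is enough for most sync clients even though it is not a content hash.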
def _meta_headers() -> Tuple[str, Dict[str, str]]:
req_id = uuid.uuid4().hex
headers = {
"x-amz-request-id": req_id,
"x-amz-id-2": uuid.uuid4().hex,
"Server": "FoxelS3",
}
return req_id, headers
def _s3_error(code: str, message: str, resource: str = "", status: int = 400) -> Response:
req_id, headers = _meta_headers()
xml = (
f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
f"<Error>"
f"<Code>{code}</Code>"
f"<Message>{message}</Message>"
f"<Resource>{resource}</Resource>"
f"<RequestId>{req_id}</RequestId>"
f"</Error>"
)
return Response(content=xml, status_code=status, media_type="application/xml", headers=headers)
async def _ensure_enabled() -> Optional[Response]:
flag = await ConfigService.get("S3_MAPPING_ENABLED", "1")
if str(flag).strip().lower() in FALSEY:
return _s3_error("ServiceUnavailable", "S3 mapping disabled", status=503)
return None
async def _get_settings() -> Tuple[Optional[S3Settings], Optional[Response]]:
bucket = (await ConfigService.get("S3_MAPPING_BUCKET", "foxel")) or "foxel"
region = (await ConfigService.get("S3_MAPPING_REGION", "us-east-1")) or "us-east-1"
base_path = (await ConfigService.get("S3_MAPPING_BASE_PATH", "/")) or "/"
access_key = (await ConfigService.get("S3_MAPPING_ACCESS_KEY")) or ""
secret_key = (await ConfigService.get("S3_MAPPING_SECRET_KEY")) or ""
if not access_key or not secret_key:
return None, _s3_error(
"InvalidAccessKeyId",
"S3 mapping access key/secret are not configured.",
status=403,
)
settings: S3Settings = {
"bucket": bucket,
"region": region,
"base_path": base_path,
"access_key": access_key,
"secret_key": secret_key,
}
return settings, None
def _canonical_uri(path: str) -> str:
from urllib.parse import quote
if not path:
return "/"
return quote(path, safe="/-_.~")
def _canonical_query(params: Iterable[Tuple[str, str]]) -> str:
from urllib.parse import quote
encoded = []
for key, value in params:
enc_key = quote(key, safe="-_.~")
enc_val = quote(value or "", safe="-_.~")
encoded.append((enc_key, enc_val))
encoded.sort()
return "&".join(f"{k}={v}" for k, v in encoded)
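`_canonical_query` implements the SigV4 canonical query string: each key and value is percent-encoded with only the unreserved characters left bare, then the pairs are sorted. A standalone sketch:

```python
from urllib.parse import quote

def canonical_query(params) -> str:
    # SigV4 canonical query: encode with the unreserved set "-_.~", then sort
    encoded = []
    for key, value in params:
        encoded.append((quote(key, safe="-_.~"), quote(value or "", safe="-_.~")))
    encoded.sort()
    return "&".join(f"{k}={v}" for k, v in encoded)

q = canonical_query([("prefix", "a b/"), ("list-type", "2"), ("delimiter", "/")])
```

Note that `/` is encoded here (unlike `quote`'s default), which matters: a delimiter of `/` must appear as `%2F` in the string the client signed, or signatures will not match.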
def _normalize_ws(value: str) -> str:
return " ".join(value.strip().split())
def _sign(key: bytes, msg: str) -> bytes:
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
async def _authorize_sigv4(request: Request, settings: S3Settings) -> Optional[Response]:
auth = request.headers.get("authorization")
if not auth:
return _s3_error("AccessDenied", "Missing Authorization header", status=403)
scheme = "AWS4-HMAC-SHA256"
if not auth.startswith(scheme + " "):
return _s3_error("InvalidRequest", "Signature Version 4 is required", status=400)
parts: Dict[str, str] = {}
for segment in auth[len(scheme) + 1 :].split(","):
k, _, v = segment.strip().partition("=")
parts[k] = v
credential = parts.get("Credential")
signed_headers = parts.get("SignedHeaders")
signature = parts.get("Signature")
if not credential or not signed_headers or not signature:
return _s3_error("InvalidRequest", "Authorization header is malformed", status=400)
cred_parts = credential.split("/")
if len(cred_parts) != 5 or cred_parts[-1] != "aws4_request":
return _s3_error("InvalidRequest", "Credential scope is invalid", status=400)
access_key, datestamp, region, service, _ = cred_parts
if access_key != settings["access_key"]:
return _s3_error("InvalidAccessKeyId", "The AWS Access Key Id you provided does not exist in our records.", status=403)
if service != "s3":
return _s3_error("InvalidRequest", "Only service 's3' is supported", status=400)
if region != settings["region"]:
return _s3_error("AuthorizationHeaderMalformed", f"Region '{region}' is invalid", status=400)
amz_date = request.headers.get("x-amz-date")
if not amz_date or not amz_date.startswith(datestamp):
return _s3_error("AuthorizationHeaderMalformed", "x-amz-date does not match credential scope", status=400)
payload_hash = request.headers.get("x-amz-content-sha256")
if not payload_hash:
return _s3_error("AuthorizationHeaderMalformed", "Missing x-amz-content-sha256", status=400)
if payload_hash.upper().startswith("STREAMING-AWS4-HMAC-SHA256"):
return _s3_error("NotImplemented", "Chunked uploads are not supported", status=400)
signed_header_names = [h.strip().lower() for h in signed_headers.split(";") if h.strip()]
headers = {k.lower(): v for k, v in request.headers.items()}
canonical_headers = []
for name in signed_header_names:
value = headers.get(name)
if value is None:
return _s3_error("AuthorizationHeaderMalformed", f"Signed header '{name}' missing", status=400)
canonical_headers.append(f"{name}:{_normalize_ws(value)}\n")
canonical_request = "\n".join(
[
request.method,
_canonical_uri(request.url.path),
_canonical_query(request.query_params.multi_items()),
"".join(canonical_headers),
";".join(signed_header_names),
payload_hash,
]
)
hashed_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
scope = "/".join([datestamp, region, "s3", "aws4_request"])
string_to_sign = "\n".join([scheme, amz_date, scope, hashed_request])
k_date = _sign(("AWS4" + settings["secret_key"]).encode("utf-8"), datestamp)
k_region = hmac.new(k_date, region.encode("utf-8"), hashlib.sha256).digest()
k_service = hmac.new(k_region, b"s3", hashlib.sha256).digest()
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
expected = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
if expected != signature:
return _s3_error("SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", status=403)
return None
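The verification above recomputes the standard SigV4 signing-key chain. A compact sketch of that derivation on its own (helper names are illustrative):

```python
import hashlib
import hmac

def signing_key(secret: str, datestamp: str, region: str, service: str = "s3") -> bytes:
    # SigV4 key derivation: chained HMAC-SHA256 over date, region, service, "aws4_request"
    k_date = hmac.new(("AWS4" + secret).encode(), datestamp.encode(), hashlib.sha256).digest()
    k_region = hmac.new(k_date, region.encode(), hashlib.sha256).digest()
    k_service = hmac.new(k_region, service.encode(), hashlib.sha256).digest()
    return hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()

def sign(secret: str, datestamp: str, region: str, string_to_sign: str) -> str:
    # final hex signature over the string-to-sign
    key = signing_key(secret, datestamp, region)
    return hmac.new(key, string_to_sign.encode(), hashlib.sha256).hexdigest()
```

Because the key is scoped to date/region/service, a leaked signature cannot be replayed against a different scope, which is why the route checks the credential scope fields before verifying.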
def _virtual_path(settings: S3Settings, key: str) -> str:
key_norm = key.strip("/")
base_norm = settings["base_path"].strip("/")
segments = [seg for seg in [base_norm, key_norm] if seg]
if not segments:
return "/"
return "/" + "/".join(segments)
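`_virtual_path` maps an S3 object key under the configured base path, dropping empty segments so `/` and `""` collapse cleanly. A standalone sketch of the same mapping:

```python
def virtual_path(base_path: str, key: str) -> str:
    # join base path and key, skipping empty segments, mirroring _virtual_path
    key_norm = key.strip("/")
    base_norm = base_path.strip("/")
    segments = [seg for seg in (base_norm, key_norm) if seg]
    return "/" + "/".join(segments) if segments else "/"
```

With a base path of `/`, keys map directly onto the virtual filesystem root; a non-root base path confines the whole bucket under one subtree.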
def _join_virtual(base: str, name: str) -> str:
if not base or base == "/":
return "/" + name.strip("/")
return base.rstrip("/") + "/" + name.strip("/")
async def _list_dir_all(path: str) -> List[Dict]:
items: List[Dict] = []
page_num = 1
page_size = 1000
while True:
try:
res = await VirtualFSService.list_virtual_dir(path, page_num=page_num, page_size=page_size)
except HTTPException as exc: # directory missing
if exc.status_code in (400, 404):
return []
raise
chunk = res.get("items", [])
items.extend(chunk)
total = int(res.get("total", len(items)))
if len(items) >= total or not chunk or len(chunk) < page_size:
break
page_num += 1
return items
async def _collect_objects(path: str, key_prefix: str, recursive: bool, collect_prefixes: bool) -> Tuple[List[Tuple[str, Dict]], List[str]]:
entries = await _list_dir_all(path)
files: List[Tuple[str, Dict]] = []
prefixes: List[str] = []
for entry in entries:
name = entry.get("name")
if not name:
continue
if entry.get("is_dir"):
dir_key = f"{key_prefix}{name.strip('/')}/"
if collect_prefixes:
prefixes.append(dir_key)
if recursive:
sub_path = _join_virtual(path, name)
sub_files, _ = await _collect_objects(sub_path, dir_key, True, False)
files.extend(sub_files)
else:
key = f"{key_prefix}{name}"
files.append((key, entry))
files.sort(key=lambda item: item[0])
prefixes.sort()
return files, prefixes
def _encode_token(key: str) -> str:
raw = base64.urlsafe_b64encode(key.encode("utf-8")).decode("ascii")
return raw.rstrip("=")
def _decode_token(token: str) -> Optional[str]:
if not token:
return None
padding = "=" * (-len(token) % 4)
try:
return base64.urlsafe_b64decode(token + padding).decode("utf-8")
except Exception:
return None
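Continuation tokens above are just the last emitted key, URL-safe base64 encoded with the `=` padding stripped so the token survives a query string untouched. The round-trip in isolation:

```python
import base64
from typing import Optional

def encode_token(key: str) -> str:
    # URL-safe base64 with "=" padding stripped, query-string friendly
    return base64.urlsafe_b64encode(key.encode("utf-8")).decode("ascii").rstrip("=")

def decode_token(token: str) -> Optional[str]:
    if not token:
        return None
    padding = "=" * (-len(token) % 4)  # restore the stripped padding
    try:
        return base64.urlsafe_b64decode(token + padding).decode("utf-8")
    except Exception:
        return None  # malformed token: treat as no marker

tok = encode_token("photos/2025/img_001.jpg")
```

Returning `None` for malformed input lets the caller fall back to listing from the start instead of erroring.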
def _apply_pagination(entries: List[Tuple[str, Dict]], prefixes: List[str], max_keys: int, start_after: Optional[str], continuation_token: Optional[str]) -> Tuple[List[Tuple[str, Dict]], List[str], bool, Optional[str]]:
combined = [(key, data, True) for key, data in entries] + [(prefix, None, False) for prefix in prefixes]
combined.sort(key=lambda item: item[0])
start_key = start_after or _decode_token(continuation_token or "")
if start_key:
combined = [item for item in combined if item[0] > start_key]
is_truncated = len(combined) > max_keys
sliced = combined[:max_keys]
next_token = _encode_token(sliced[-1][0]) if is_truncated and sliced else None
contents = [(key, data) for key, data, is_file in sliced if is_file]
next_prefixes = [key for key, _, is_file in sliced if not is_file]
return contents, next_prefixes, is_truncated, next_token
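`_apply_pagination` merges files and common prefixes into one lexicographically sorted stream, skips everything up to the start marker, and truncates at max-keys. The core mechanics, sketched on bare strings:

```python
def paginate(keys, max_keys, start_after=None):
    # sort, skip keys <= the start marker, slice, report truncation + next marker
    ordered = sorted(keys)
    if start_after:
        ordered = [k for k in ordered if k > start_after]
    truncated = len(ordered) > max_keys
    page = ordered[:max_keys]
    next_marker = page[-1] if truncated and page else None
    return page, truncated, next_marker

page1, more1, marker1 = paginate(["b", "a", "d", "c"], 2)
page2, more2, marker2 = paginate(["b", "a", "d", "c"], 2, start_after=marker1)
```

Feeding the previous page's last key back as the marker yields stable, gap-free pages as long as the key ordering is consistent between calls, which is why both files and prefixes are sorted into a single stream first.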
def _format_contents(entries: List[Tuple[str, Dict]]) -> str:
blocks = []
for key, meta in entries:
size = int(meta.get("size", 0))
mtime = meta.get("mtime")
if mtime is not None:
try:
mtime_val = int(mtime)
except Exception:
mtime_val = 0
else:
mtime_val = 0
last_modified = dt.datetime.utcfromtimestamp(mtime_val or dt.datetime.utcnow().timestamp()).strftime("%Y-%m-%dT%H:%M:%S.000Z")
etag = _etag(key, size, mtime_val)
blocks.append(
f"<Contents><Key>{key}</Key><LastModified>{last_modified}</LastModified><ETag>{etag}</ETag><Size>{size}</Size><StorageClass>STANDARD</StorageClass></Contents>"
)
return "".join(blocks)
def _format_common_prefixes(prefixes: List[str]) -> str:
return "".join(f"<CommonPrefixes><Prefix>{p}</Prefix></CommonPrefixes>" for p in prefixes)
def _resource_path(bucket: str, key: Optional[str] = None) -> str:
if key:
return f"/s3/{bucket}/{key}"
return f"/s3/{bucket}"
@router.get("")
async def list_buckets(request: Request):
if (resp := await _ensure_enabled()) is not None:
return resp
settings, err = await _get_settings()
if err:
return err
assert settings
if (auth := await _authorize_sigv4(request, settings)) is not None:
return auth
req_id, headers = _meta_headers()
xml = (
f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
f"<ListAllMyBucketsResult xmlns=\"{_XML_NS}\">"
f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
f"<Buckets><Bucket><Name>{settings['bucket']}</Name><CreationDate>{_now_iso()}</CreationDate></Bucket></Buckets>"
f"</ListAllMyBucketsResult>"
)
headers.update({"Content-Type": "application/xml"})
return Response(content=xml, media_type="application/xml", headers=headers)
@router.get("/{bucket}")
async def list_objects(request: Request, bucket: str):
if (resp := await _ensure_enabled()) is not None:
return resp
    settings, err = await _get_settings()
    if err:
        return err
    assert settings
    if bucket != settings["bucket"]:
        return _s3_error("NoSuchBucket", "The specified bucket does not exist.", _resource_path(bucket), status=404)
    if (auth := await _authorize_sigv4(request, settings)) is not None:
        return auth
    params = request.query_params
    if params.get("list-type", "2") != "2":
        return _s3_error("InvalidArgument", "Only ListObjectsV2 (list-type=2) is supported.", _resource_path(bucket), status=400)
    prefix = (params.get("prefix") or "").lstrip("/")
    delimiter = params.get("delimiter")
    recursive = not delimiter
    max_keys_raw = params.get("max-keys", "1000")
    try:
        max_keys = max(1, min(1000, int(max_keys_raw)))
    except ValueError:
        max_keys = 1000
    start_after = (params.get("start-after") or "").lstrip("/") or None
    continuation = params.get("continuation-token")
    # Exact file match if prefix is non-empty and does not end with '/'
    files: List[Tuple[str, Dict]] = []
    prefixes: List[str] = []
    if prefix and not prefix.endswith("/"):
        try:
            info = await VirtualFSService.stat_file(_virtual_path(settings, prefix))
            if not info.get("is_dir"):
                files = [(prefix, info)]
        except HTTPException as exc:
            if exc.status_code not in (400, 404):
                raise
    if files:
        contents, next_prefixes, is_truncated, next_token = _apply_pagination(files, [], max_keys, start_after, continuation)
        xml = _build_list_result(bucket, prefix, delimiter, contents, next_prefixes, max_keys, is_truncated, continuation, next_token, start_after)
        return xml
    dir_prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
    virtual_dir = _virtual_path(settings, dir_prefix)
    files, prefixes = await _collect_objects(virtual_dir, dir_prefix, recursive, bool(delimiter))
    contents, next_prefixes, is_truncated, next_token = _apply_pagination(files, prefixes if delimiter else [], max_keys, start_after, continuation)
    return _build_list_result(bucket, prefix, delimiter, contents, next_prefixes if delimiter else [], max_keys, is_truncated, continuation, next_token, start_after)


@router.get("/{bucket}/", include_in_schema=False)
async def list_objects_with_slash(request: Request, bucket: str):
    return await list_objects(request, bucket)
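The start-after / continuation-token pagination the handler relies on can be sketched standalone. `apply_pagination` below is a hypothetical stand-in for the router's `_apply_pagination`, assuming the continuation token is simply the last key of the previous page (the token, like in S3, takes precedence over `start-after`):

```python
from typing import List, Optional, Tuple


def apply_pagination(
    keys: List[str],
    max_keys: int,
    start_after: Optional[str] = None,
    continuation: Optional[str] = None,
) -> Tuple[List[str], bool, Optional[str]]:
    """Slice lexicographically sorted keys the way ListObjectsV2 does."""
    keys = sorted(keys)
    # The continuation token wins over start-after, matching S3 semantics.
    marker = continuation or start_after
    if marker:
        keys = [k for k in keys if k > marker]
    page = keys[:max_keys]
    truncated = len(keys) > max_keys
    next_token = page[-1] if truncated and page else None
    return page, truncated, next_token


page, truncated, token = apply_pagination(["a", "b", "c"], max_keys=2)
# page == ["a", "b"], truncated is True, token == "b"
```

Feeding the returned token back in yields the remaining keys with `IsTruncated` false.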
def _build_list_result(
    bucket: str,
    prefix: str,
    delimiter: Optional[str],
    contents: List[Tuple[str, Dict]],
    prefixes: List[str],
    max_keys: int,
    is_truncated: bool,
    continuation: Optional[str],
    next_token: Optional[str],
    start_after: Optional[str],
):
    req_id, headers = _meta_headers()
    body = ["<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListBucketResult xmlns=\"{_XML_NS}\">"]
    body.append(f"<Name>{bucket}</Name>")
    body.append(f"<Prefix>{prefix}</Prefix>")
    if delimiter:
        body.append(f"<Delimiter>{delimiter}</Delimiter>")
    if continuation:
        body.append(f"<ContinuationToken>{continuation}</ContinuationToken>")
    if start_after:
        body.append(f"<StartAfter>{start_after}</StartAfter>")
    body.append(f"<MaxKeys>{max_keys}</MaxKeys>")
    body.append(f"<KeyCount>{len(contents) + len(prefixes)}</KeyCount>")
    body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
    if next_token:
        body.append(f"<NextContinuationToken>{next_token}</NextContinuationToken>")
    body.append(_format_contents(contents))
    if prefixes:
        body.append(_format_common_prefixes(prefixes))
    body.append("</ListBucketResult>")
    xml = "".join(body)
    headers.update({"Content-Type": "application/xml"})
    return Response(content=xml, media_type="application/xml", headers=headers)
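A client reading the result document above needs namespace-qualified lookups; a minimal round-trip sketch (the namespace value is an assumption — `_XML_NS` is defined elsewhere in the file):

```python
import xml.etree.ElementTree as ET

# Assumed value of _XML_NS (the standard S3 document namespace).
NS = "http://s3.amazonaws.com/doc/2006-03-01/"

xml_doc = (
    '<?xml version="1.0" encoding="UTF-8"?>'
    f'<ListBucketResult xmlns="{NS}">'
    "<Name>demo</Name><Prefix>photos/</Prefix>"
    "<KeyCount>1</KeyCount><IsTruncated>false</IsTruncated>"
    "<Contents><Key>photos/cat.jpg</Key><Size>42</Size></Contents>"
    "</ListBucketResult>"
)

root = ET.fromstring(xml_doc)
# Namespaced documents need the {uri}tag form (or a prefix map) on lookup.
keys = [el.findtext(f"{{{NS}}}Key") for el in root.iter(f"{{{NS}}}Contents")]
# keys == ["photos/cat.jpg"]
```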
async def _ensure_bucket_and_auth(request: Request, bucket: str) -> Tuple[Optional[S3Settings], Optional[Response]]:
    if (resp := await _ensure_enabled()) is not None:
        return None, resp
    settings, err = await _get_settings()
    if err:
        return None, err
    assert settings
    if bucket != settings["bucket"]:
        return None, _s3_error("NoSuchBucket", "The specified bucket does not exist.", _resource_path(bucket), status=404)
    if (auth := await _authorize_sigv4(request, settings)) is not None:
        return None, auth
    return settings, None


def _object_headers(meta: Dict, key: str) -> Dict[str, str]:
    size = int(meta.get("size", 0))
    mtime = meta.get("mtime")
    if mtime is not None:
        try:
            mtime_val = int(mtime)
        except Exception:
            mtime_val = 0
    else:
        mtime_val = 0
    last_modified = dt.datetime.utcfromtimestamp(mtime_val or dt.datetime.utcnow().timestamp()).strftime("%a, %d %b %Y %H:%M:%S GMT")
    headers = {
        "Content-Length": str(size),
        "ETag": _etag(key, size, mtime_val),
        "Last-Modified": last_modified,
        "Accept-Ranges": "bytes",
        "x-amz-version-id": "null",
    }
    return headers
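The `Last-Modified` value above is an RFC 1123 HTTP date; the same formatting can be done with timezone-aware datetimes, which avoids `utcfromtimestamp` (deprecated since Python 3.12). A minimal sketch:

```python
import datetime as dt


def http_date(ts: float) -> str:
    """Format a Unix timestamp as an RFC 1123 HTTP date (e.g. for Last-Modified)."""
    return dt.datetime.fromtimestamp(ts, dt.timezone.utc).strftime(
        "%a, %d %b %Y %H:%M:%S GMT"
    )


stamp = http_date(0)
# stamp == "Thu, 01 Jan 1970 00:00:00 GMT"
```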
async def _stat_object(settings: S3Settings, key: str) -> Tuple[Optional[Dict], Optional[Response]]:
    try:
        info = await VirtualFSService.stat_file(_virtual_path(settings, key))
        if info.get("is_dir"):
            return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
        return info, None
    except HTTPException as exc:
        if exc.status_code == 404:
            return None, _s3_error("NoSuchKey", "The specified key does not exist.", _resource_path(settings["bucket"], key), status=404)
        raise
@router.api_route("/{bucket}/{object_path:path}", methods=["GET", "HEAD"])
async def object_get_head(request: Request, bucket: str, object_path: str):
    settings, error = await _ensure_bucket_and_auth(request, bucket)
    if error:
        return error
    assert settings
    key = object_path.lstrip("/")
    meta, err = await _stat_object(settings, key)
    if err:
        return err
    assert meta
    _, base_headers = _meta_headers()
    base_headers.update(_object_headers(meta, key))
    if request.method == "HEAD":
        return Response(status_code=200, headers=base_headers)
    resp = await VirtualFSService.stream_file(_virtual_path(settings, key), request.headers.get("range"))
    safe_merge_keys = {"ETag", "Last-Modified", "x-amz-version-id", "Accept-Ranges"}
    for hk, hv in base_headers.items():
        if hk in safe_merge_keys:
            resp.headers.setdefault(hk, hv)
    resp.headers.setdefault("Content-Type", meta.get("mime") or "application/octet-stream")
    return resp


@router.put("/{bucket}/{object_path:path}")
async def put_object(request: Request, bucket: str, object_path: str):
    settings, error = await _ensure_bucket_and_auth(request, bucket)
    if error:
        return error
    assert settings
    key = object_path.lstrip("/")
    await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
    meta, err = await _stat_object(settings, key)
    if err:
        return err
    assert meta
    headers = _object_headers(meta, key)
    headers.pop("Content-Length", None)
    headers.pop("Accept-Ranges", None)
    headers["Content-Length"] = "0"
    _, extra = _meta_headers()
    headers.update(extra)
    return Response(status_code=200, headers=headers)


@router.delete("/{bucket}/{object_path:path}")
async def delete_object(request: Request, bucket: str, object_path: str):
    settings, error = await _ensure_bucket_and_auth(request, bucket)
    if error:
        return error
    assert settings
    key = object_path.lstrip("/")
    try:
        await VirtualFSService.delete_path(_virtual_path(settings, key))
    except HTTPException as exc:
        if exc.status_code not in (400, 404):
            raise
    _, headers = _meta_headers()
    return Response(status_code=204, headers=headers)


@@ -9,17 +9,19 @@ from typing import Optional
from fastapi import APIRouter, Request, Response, HTTPException, Depends
import xml.etree.ElementTree as ET
-from services.auth import authenticate_user_db, User, UserInDB
-from services.virtual_fs import (
-    list_virtual_dir,
-    stat_file,
-    write_file_stream,
-    make_dir,
-    delete_path,
-    move_path,
-    copy_path,
-    stream_file,
-)
+from domain.auth.service import AuthService
+from domain.auth.types import User, UserInDB
+from domain.virtual_fs.service import VirtualFSService
+from domain.config.service import ConfigService
+_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
+async def _ensure_webdav_enabled() -> None:
+    enabled = await ConfigService.get(_WEBDAV_ENABLED_KEY, "1")
+    if str(enabled).strip().lower() in ("0", "false", "off", "no"):
+        raise HTTPException(503, detail="WebDAV mapping disabled")
router = APIRouter(prefix="/webdav", tags=["webdav"])
@@ -60,7 +62,7 @@ async def _get_basic_user(request: Request) -> User:
         username, _, password = decoded.partition(":")
     except Exception:
         raise HTTPException(401, detail="Invalid Basic auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
-    user_or_false: Optional[UserInDB] = await authenticate_user_db(username, password)
+    user_or_false: Optional[UserInDB] = await AuthService.authenticate_user_db(username, password)
     if not user_or_false:
         raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
     u: UserInDB = user_or_false
@@ -140,12 +142,17 @@ def _normalize_fs_path(path: str) -> str:
 @router.options("/{path:path}")
-async def options_root(path: str = ""):
+async def options_root(path: str = "", _enabled: None = Depends(_ensure_webdav_enabled)):
     return Response(status_code=200, headers=_dav_headers())
 @router.api_route("/{path:path}", methods=["PROPFIND"])
-async def propfind(request: Request, path: str, user: User = Depends(_get_basic_user)):
+async def propfind(
+    request: Request,
+    path: str,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_path = _normalize_fs_path(path)
     depth = request.headers.get("Depth", "1").lower()
     if depth not in ("0", "1", "infinity"):
@@ -155,7 +162,7 @@ async def propfind(request: Request, path: str, user: User = Depends(_get_basic_
     # Fetch info for the current path first
     try:
-        st = await stat_file(full_path)
+        st = await VirtualFSService.stat_file(full_path)
         is_dir = bool(st.get("is_dir"))
         name = st.get("name") or full_path.rsplit("/", 1)[-1] or "/"
         size = None if is_dir else int(st.get("size", 0))
@@ -167,7 +174,7 @@ async def propfind(request: Request, path: str, user: User = Depends(_get_basic_
     if depth in ("1", "infinity"):
         try:
-            listing = await list_virtual_dir(full_path, page_num=1, page_size=1000)
+            listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
             for ent in listing["items"]:
                 is_dir = bool(ent.get("is_dir"))
                 name = ent.get("name")
@@ -187,17 +194,26 @@ async def propfind(request: Request, path: str, user: User = Depends(_get_basic_
 @router.get("/{path:path}")
-async def dav_get(path: str, request: Request, user: User = Depends(_get_basic_user)):
+async def dav_get(
+    path: str,
+    request: Request,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_path = _normalize_fs_path(path)
     range_header = request.headers.get("Range")
-    return await stream_file(full_path, range_header)
+    return await VirtualFSService.stream_file(full_path, range_header)
 @router.head("/{path:path}")
-async def dav_head(path: str, user: User = Depends(_get_basic_user)):
+async def dav_head(
+    path: str,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_path = _normalize_fs_path(path)
     try:
-        st = await stat_file(full_path)
+        st = await VirtualFSService.stat_file(full_path)
     except FileNotFoundError:
         raise HTTPException(404, detail="Not found")
     is_dir = bool(st.get("is_dir"))
@@ -216,27 +232,40 @@ async def dav_head(path: str, user: User = Depends(_get_basic_user)):
 @router.api_route("/{path:path}", methods=["PUT"])
-async def dav_put(path: str, request: Request, user: User = Depends(_get_basic_user)):
+async def dav_put(
+    path: str,
+    request: Request,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_path = _normalize_fs_path(path)
     async def body_iter():
         async for chunk in request.stream():
             if chunk:
                 yield chunk
-    size = await write_file_stream(full_path, body_iter(), overwrite=True)
+    size = await VirtualFSService.write_file_stream(full_path, body_iter(), overwrite=True)
     return Response(status_code=201, headers=_dav_headers({"Content-Length": "0"}))
 @router.api_route("/{path:path}", methods=["DELETE"])
-async def dav_delete(path: str, user: User = Depends(_get_basic_user)):
+async def dav_delete(
+    path: str,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_path = _normalize_fs_path(path)
-    await delete_path(full_path)
+    await VirtualFSService.delete_path(full_path)
     return Response(status_code=204, headers=_dav_headers())
 @router.api_route("/{path:path}", methods=["MKCOL"])
-async def dav_mkcol(path: str, user: User = Depends(_get_basic_user)):
+async def dav_mkcol(
+    path: str,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_path = _normalize_fs_path(path)
-    await make_dir(full_path)
+    await VirtualFSService.make_dir(full_path)
     return Response(status_code=201, headers=_dav_headers())
@@ -258,7 +287,7 @@ async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_
     dest_header = request.headers.get("Destination")
     dst = _parse_destination(dest_header or "")
     overwrite = request.headers.get("Overwrite", "T").upper() != "F"
-    await move_path(full_src, dst, overwrite=overwrite)
+    await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
     return Response(status_code=204, headers=_dav_headers())
@@ -268,6 +297,5 @@ async def dav_copy(path: str, request: Request, user: User = Depends(_get_basic_
     dest_header = request.headers.get("Destination")
     dst = _parse_destination(dest_header or "")
     overwrite = request.headers.get("Overwrite", "T").upper() != "F"
-    await copy_path(full_src, dst, overwrite=overwrite)
+    await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
     return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())
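The `Overwrite` handling in `dav_move`/`dav_copy` follows RFC 4918: the header defaults to `T`, and only an explicit `F` (case-insensitive) disables overwriting. A standalone sketch of that one-liner:

```python
def overwrite_allowed(headers: dict) -> bool:
    """RFC 4918 Overwrite header: absent or 'T' permits overwrite; only 'F' forbids it."""
    return headers.get("Overwrite", "T").upper() != "F"


# Mirrors request.headers.get("Overwrite", "T").upper() != "F" in the routes above.
print(overwrite_allowed({}), overwrite_allowed({"Overwrite": "f"}))
```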


@@ -0,0 +1,106 @@
from __future__ import annotations
from typing import Any
from fastapi import HTTPException
from fastapi.responses import Response
from .transfer import VirtualFSTransferMixin


class VirtualFSProcessingMixin(VirtualFSTransferMixin):
    @classmethod
    async def process_file(
        cls,
        path: str,
        processor_type: str,
        config: dict,
        save_to: str | None = None,
        overwrite: bool = False,
    ) -> Any:
        from domain.processors.service import get_processor

        processor = get_processor(processor_type)
        if not processor:
            raise HTTPException(400, detail=f"Processor {processor_type} not found")
        actual_is_dir = await cls.path_is_directory(path)
        supported_exts = getattr(processor, "supported_exts", None) or []
        allowed_exts = {str(ext).lower().lstrip(".") for ext in supported_exts if isinstance(ext, str)}

        def matches_extension(rel_path: str) -> bool:
            if not allowed_exts:
                return True
            if "." not in rel_path:
                return "" in allowed_exts
            ext = rel_path.rsplit(".", 1)[-1].lower()
            return ext in allowed_exts or f".{ext}" in allowed_exts

        def coerce_result_bytes(result: Any) -> bytes:
            if isinstance(result, Response):
                return result.body
            if isinstance(result, (bytes, bytearray)):
                return bytes(result)
            if isinstance(result, str):
                return result.encode("utf-8")
            raise HTTPException(500, detail="Processor must return bytes/Response when produces_file=True")

        if actual_is_dir:
            if save_to:
                raise HTTPException(400, detail="Directory processing does not support custom save_to path")
            if not overwrite:
                raise HTTPException(400, detail="Directory processing requires overwrite")
            adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(path)
            rel = rel.rstrip("/")
            list_dir = await cls._ensure_method(adapter_instance, "list_dir")
            processed_count = 0
            stack: list[str] = [rel]
            page_size = 200
            while stack:
                current = stack.pop()
                page = 1
                while True:
                    entries, total = await list_dir(root, current, page, page_size, "name", "asc")
                    if not entries and (total or 0) == 0:
                        break
                    for entry in entries:
                        name = entry.get("name")
                        if not name:
                            continue
                        child_rel = f"{current}/{name}" if current else name
                        if entry.get("is_dir"):
                            stack.append(child_rel)
                            continue
                        if not matches_extension(child_rel):
                            continue
                        absolute_path = cls._build_absolute_path(adapter_model.path, child_rel)
                        data = await cls.read_file(absolute_path)
                        result = await processor.process(data, absolute_path, config)
                        if getattr(processor, "produces_file", False):
                            result_bytes = coerce_result_bytes(result)
                            await cls.write_file(absolute_path, result_bytes)
                        processed_count += 1
                    if total is None or page * page_size >= total:
                        break
                    page += 1
            return {"processed_files": processed_count}

        data = await cls.read_file(path)
        result = await processor.process(data, path, config)
        target_path = save_to
        if overwrite and not target_path:
            target_path = path
        if target_path and getattr(processor, "produces_file", False):
            result_bytes = coerce_result_bytes(result)
            await cls.write_file(target_path, result_bytes)
            return {"saved_to": target_path}
        return result
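The `matches_extension` closure above accepts extensions with or without a leading dot and treats an empty set as "match everything". A hypothetical standalone version of the same logic:

```python
def matches_extension(rel_path: str, supported: list[str]) -> bool:
    """True when the file's extension is in the supported set.

    Entries may be given as 'jpg' or '.jpg'; an empty set matches everything;
    extensionless paths match only when '' is in the set.
    """
    allowed = {e.lower().lstrip(".") for e in supported}
    if not allowed:
        return True
    if "." not in rel_path:
        return "" in allowed
    return rel_path.rsplit(".", 1)[-1].lower() in allowed


print(matches_extension("a/b.JPG", [".jpg"]))  # case-insensitive, dot-insensitive
```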


@@ -0,0 +1,66 @@
from __future__ import annotations
from typing import Tuple
from fastapi import HTTPException
from fastapi.responses import Response
from domain.adapters.registry import runtime_registry
from models import StorageAdapter
from .common import VirtualFSCommonMixin


class VirtualFSResolverMixin(VirtualFSCommonMixin):
    @classmethod
    async def resolve_adapter_by_path(cls, path: str) -> Tuple[StorageAdapter, str]:
        norm = cls._normalize_path(path)
        adapters = await StorageAdapter.filter(enabled=True)
        best = None
        for adapter in adapters:
            if norm == adapter.path or norm.startswith(adapter.path.rstrip("/") + "/"):
                if best is None or len(adapter.path) > len(best.path):
                    best = adapter
        if not best:
            raise HTTPException(404, detail="No storage adapter for path")
        rel = norm[len(best.path) :].lstrip("/")
        return best, rel

    @classmethod
    async def resolve_adapter_and_rel(cls, path: str):
        norm = cls._normalize_path(path)
        adapter_model, rel = await cls.resolve_adapter_by_path(norm)
        adapter_instance = runtime_registry.get(adapter_model.id)
        if not adapter_instance:
            await runtime_registry.refresh()
            adapter_instance = runtime_registry.get(adapter_model.id)
        if not adapter_instance:
            raise HTTPException(
                404, detail=f"Adapter instance for ID {adapter_model.id} not found or failed to load."
            )
        effective_root = adapter_instance.get_effective_root(adapter_model.sub_path)
        return adapter_instance, adapter_model, effective_root, rel

    @classmethod
    async def maybe_redirect_download(cls, adapter_instance, adapter_model, root: str, rel: str):
        if not rel or rel.endswith("/"):
            return None
        config = getattr(adapter_model, "config", {}) or {}
        if not config.get(cls.DIRECT_REDIRECT_CONFIG_KEY):
            return None
        handler = getattr(adapter_instance, "get_direct_download_response", None)
        if not callable(handler):
            return None
        try:
            response = await handler(root, rel)
        except FileNotFoundError:
            raise
        except Exception:
            return None
        if isinstance(response, Response):
            return response
        return None
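The longest-prefix matching in `resolve_adapter_by_path` can be sketched on plain strings (mount objects replaced by bare path strings for illustration):

```python
from typing import List, Optional, Tuple


def resolve_mount(path: str, mounts: List[str]) -> Optional[Tuple[str, str]]:
    """Pick the longest mount whose path is a prefix of `path`.

    Returns (mount, relative_path), or None when nothing matches.
    """
    best = None
    for mount in mounts:
        # A mount matches exactly, or as a directory prefix ("/a" matches "/a/b").
        if path == mount or path.startswith(mount.rstrip("/") + "/"):
            if best is None or len(mount) > len(best):
                best = mount
    if best is None:
        return None
    return best, path[len(best):].lstrip("/")


print(resolve_mount("/a/b/c.txt", ["/a", "/a/b"]))
# ('/a/b', 'c.txt') — the deeper mount wins
```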

domain/virtual_fs/routes.py

@@ -0,0 +1,250 @@
from __future__ import annotations
import mimetypes
import re
from fastapi import HTTPException, UploadFile
from fastapi.responses import Response
from domain.config.service import ConfigService
from domain.virtual_fs.thumbnail import get_or_create_thumb, is_image_filename, is_raw_filename, is_video_filename
from .temp_link import VirtualFSTempLinkMixin


class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
    @classmethod
    async def serve_file(cls, full_path: str, range_header: str | None) -> Response:
        full_path = cls._normalize_path(full_path)
        if is_raw_filename(full_path):
            import io
            import rawpy
            from PIL import Image
            try:
                raw_data = await cls.read_file(full_path)
                with rawpy.imread(io.BytesIO(raw_data)) as raw:
                    rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
                im = Image.fromarray(rgb)
                buf = io.BytesIO()
                im.save(buf, "JPEG", quality=90)
                content = buf.getvalue()
                return Response(content=content, media_type="image/jpeg")
            except FileNotFoundError:
                raise HTTPException(404, detail="File not found")
            except Exception as exc:
                raise HTTPException(500, detail=f"RAW file processing failed: {exc}")
        adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(full_path)
        redirect_response = await cls.maybe_redirect_download(adapter_instance, adapter_model, root, rel)
        if redirect_response is not None:
            return redirect_response
        try:
            content = await cls.read_file(full_path)
        except FileNotFoundError:
            raise HTTPException(404, detail="File not found")
        if not isinstance(content, (bytes, bytearray)):
            return Response(content=content, media_type="application/octet-stream")
        content_length = len(content)
        content_type = mimetypes.guess_type(full_path)[0] or "application/octet-stream"
        if range_header:
            range_match = re.match(r"bytes=(\d+)-(\d*)", range_header)
            if range_match:
                start = int(range_match.group(1))
                end = int(range_match.group(2)) if range_match.group(2) else content_length - 1
                start = max(0, min(start, content_length - 1))
                end = max(start, min(end, content_length - 1))
                chunk = content[start : end + 1]
                chunk_size = len(chunk)
                headers = {
                    "Content-Range": f"bytes {start}-{end}/{content_length}",
                    "Accept-Ranges": "bytes",
                    "Content-Length": str(chunk_size),
                    "Content-Type": content_type,
                }
                return Response(content=chunk, status_code=206, headers=headers)
        headers = {
            "Accept-Ranges": "bytes",
            "Content-Length": str(content_length),
            "Content-Type": content_type,
        }
        if content_type.startswith("video/"):
            headers["Cache-Control"] = "public, max-age=3600"
        return Response(content=content, headers=headers)
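The single-range parsing in `serve_file` clamps the requested window to the available content before slicing; the same logic as a minimal standalone sketch:

```python
import re


def slice_range(content: bytes, range_header: str):
    """Return (chunk, Content-Range value) for a 'bytes=start-end' header, or None."""
    m = re.match(r"bytes=(\d+)-(\d*)", range_header)
    if not m:
        return None
    length = len(content)
    start = int(m.group(1))
    # An open-ended range like "bytes=6-" runs to the last byte.
    end = int(m.group(2)) if m.group(2) else length - 1
    start = max(0, min(start, length - 1))
    end = max(start, min(end, length - 1))
    return content[start:end + 1], f"bytes {start}-{end}/{length}"


chunk, cr = slice_range(b"hello world", "bytes=0-4")
# chunk == b"hello", cr == "bytes 0-4/11"
```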
    @classmethod
    async def get_thumbnail(cls, full_path: str, w: int, h: int, fit: str) -> Response:
        full_path = cls._normalize_path(full_path)
        if fit not in ("cover", "contain"):
            raise HTTPException(400, detail="fit must be cover|contain")
        adapter, mount, root, rel = await cls.resolve_adapter_and_rel(full_path)
        if not rel or rel.endswith("/"):
            raise HTTPException(400, detail="Not a file")
        if not (is_image_filename(rel) or is_video_filename(rel)):
            raise HTTPException(404, detail="Not an image or video")
        data, mime, key = await get_or_create_thumb(adapter, mount.id, root, rel, w, h, fit)  # type: ignore
        headers = {
            "Cache-Control": "public, max-age=3600",
            "ETag": key,
        }
        return Response(content=data, media_type=mime, headers=headers)

    @classmethod
    async def stream_response(cls, full_path: str, range_header: str | None):
        full_path = cls._normalize_path(full_path)
        try:
            return await cls.stream_file(full_path, range_header)
        except HTTPException:
            raise
        except FileNotFoundError:
            raise HTTPException(404, detail="File not found")
        except Exception as exc:
            raise HTTPException(500, detail=f"Stream error: {exc}")
    @classmethod
    async def create_temp_link(cls, full_path: str, expires_in: int):
        full_path = cls._normalize_path(full_path)
        token = await cls.generate_temp_link_token(full_path, expires_in=expires_in)
        file_domain = await ConfigService.get("FILE_DOMAIN")
        if file_domain:
            file_domain = file_domain.rstrip("/")
            url = f"{file_domain}/api/fs/public/{token}"
        else:
            url = f"/api/fs/public/{token}"
        return {"token": token, "path": full_path, "url": url}

    @classmethod
    async def access_public_file(cls, token: str, range_header: str | None):
        try:
            path = await cls.verify_temp_link_token(token)
        except HTTPException as exc:
            raise exc
        try:
            return await cls.stream_file(path, range_header)
        except FileNotFoundError:
            raise HTTPException(404, detail="File not found via token")
        except Exception as exc:
            raise HTTPException(500, detail=f"File access error: {exc}")

    @classmethod
    async def stat(cls, full_path: str):
        full_path = cls._normalize_path(full_path)
        return await cls.stat_file(full_path)

    @classmethod
    async def write_uploaded_file(cls, full_path: str, data: bytes):
        full_path = cls._normalize_path(full_path)
        await cls.write_file(full_path, data)
        return {"written": True, "path": full_path, "size": len(data)}

    @classmethod
    async def mkdir(cls, path: str):
        path = cls._normalize_path(path)
        if not path or path == "/":
            raise HTTPException(400, detail="Invalid path")
        await cls.make_dir(path)
        return {"created": True, "path": path}
    @classmethod
    async def move(cls, src: str, dst: str, overwrite: bool):
        src = cls._normalize_path(src)
        dst = cls._normalize_path(dst)
        debug_info = await cls.move_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
        queued = bool(debug_info.get("queued"))
        response = {
            "moved": not queued,
            "queued": queued,
            "src": src,
            "dst": dst,
            "overwrite": overwrite,
        }
        if queued:
            response["task_id"] = debug_info.get("task_id")
            response["task_name"] = debug_info.get("task_name")
        return response

    @classmethod
    async def rename(cls, src: str, dst: str, overwrite: bool):
        src = cls._normalize_path(src)
        dst = cls._normalize_path(dst)
        await cls.rename_path(src, dst, overwrite=overwrite, return_debug=False)
        return {"renamed": True, "src": src, "dst": dst, "overwrite": overwrite}

    @classmethod
    async def copy(cls, src: str, dst: str, overwrite: bool):
        src = cls._normalize_path(src)
        dst = cls._normalize_path(dst)
        debug_info = await cls.copy_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
        queued = bool(debug_info.get("queued"))
        response = {
            "copied": not queued,
            "queued": queued,
            "src": src,
            "dst": dst,
            "overwrite": overwrite,
        }
        if queued:
            response["task_id"] = debug_info.get("task_id")
            response["task_name"] = debug_info.get("task_name")
        return response
    @classmethod
    async def upload_stream_from_upload_file(cls, full_path: str, file: UploadFile, chunk_size: int, overwrite: bool):
        full_path = cls._normalize_path(full_path)
        if full_path.endswith("/"):
            raise HTTPException(400, detail="Path must be a file")
        adapter, _m, root, rel = await cls.resolve_adapter_and_rel(full_path)
        exists_func = getattr(adapter, "exists", None)
        if not overwrite and callable(exists_func):
            try:
                if await exists_func(root, rel):
                    raise HTTPException(409, detail="Destination exists")
            except HTTPException:
                raise
            except Exception:
                pass

        async def gen():
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                yield chunk

        size = await cls.write_file_stream(full_path, gen(), overwrite=overwrite)
        return {"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite}

    @classmethod
    async def list_directory(cls, full_path: str, page_num: int, page_size: int, sort_by: str, sort_order: str):
        full_path = cls._normalize_path(full_path)
        result = await cls.list_virtual_dir(full_path, page_num, page_size, sort_by, sort_order)
        return {
            "path": full_path,
            "entries": result["items"],
            "pagination": {
                "total": result["total"],
                "page": result["page"],
                "page_size": result["page_size"],
                "pages": result["pages"],
            },
        }

    @classmethod
    async def delete(cls, full_path: str):
        full_path = cls._normalize_path(full_path)
        await cls.delete_path(full_path)
        return {"deleted": True, "path": full_path}


@@ -0,0 +1,26 @@
from fastapi import APIRouter, Depends, Query
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.search.search_service import VirtualFSSearchService

router = APIRouter(prefix="/api/fs/search", tags=["search"])


@router.get("")
async def search_files(
    q: str = Query(..., description="Search query"),
    top_k: int = Query(10, description="Number of results to return"),
    mode: str = Query("vector", description="Search mode: 'vector' or 'filename'"),
    page: int = Query(1, description="Page number; only used in filename search mode"),
    page_size: int = Query(10, description="Page size; only used in filename search mode"),
    user: User = Depends(get_current_active_user),
):
    if not q.strip():
        return {"items": [], "query": q}
    top_k = max(top_k, 1)
    page = max(page, 1)
    page_size = max(min(page_size, 100), 1)
    return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)


@@ -1,13 +1,8 @@
 from typing import Any, Dict, List, Tuple
 from fastapi import APIRouter, Depends, Query
-from schemas.fs import SearchResultItem
-from services.auth import get_current_active_user, User
-from services.ai import get_text_embedding
-from services.vector_db import VectorDBService
-router = APIRouter(prefix="/api/search", tags=["search"])
+from domain.virtual_fs.types import SearchResultItem
+from domain.ai.inference import get_text_embedding
+from domain.ai.service import VectorDBService
def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
@@ -101,37 +96,23 @@ async def _filename_search(query: str, page: int, page_size: int) -> Tuple[List[
     return page_items, has_more
-@router.get("")
-async def search_files(
-    q: str = Query(..., description="Search query"),
-    top_k: int = Query(10, description="Number of results to return"),
-    mode: str = Query("vector", description="Search mode: 'vector' or 'filename'"),
-    page: int = Query(1, description="Page number; only used in filename search mode"),
-    page_size: int = Query(10, description="Page size; only used in filename search mode"),
-    user: User = Depends(get_current_active_user),
-):
-    if not q.strip():
-        return {"items": [], "query": q}
-    top_k = max(top_k, 1)
-    page = max(page, 1)
-    page_size = max(min(page_size, 100), 1)
-    if mode == "vector":
-        items = (await _vector_search(q, top_k))[:top_k]
-    elif mode == "filename":
-        items, has_more = await _filename_search(q, page, page_size)
-        return {
-            "items": items,
-            "query": q,
-            "mode": mode,
-            "pagination": {
-                "page": page,
-                "page_size": page_size,
-                "has_more": has_more,
-            },
-        }
-    else:
-        items = (await _vector_search(q, top_k))[:top_k]
-    return {"items": items, "query": q, "mode": mode}
+class VirtualFSSearchService:
+    @staticmethod
+    async def search(query: str, top_k: int, mode: str, page: int, page_size: int):
+        if mode == "vector":
+            items = (await _vector_search(query, top_k))[:top_k]
+            return {"items": items, "query": query, "mode": mode}
+        if mode == "filename":
+            items, has_more = await _filename_search(query, page, page_size)
+            return {
+                "items": items,
+                "query": query,
+                "mode": mode,
+                "pagination": {
+                    "page": page,
+                    "page_size": page_size,
+                    "has_more": has_more,
+                },
+            }
+        items = (await _vector_search(query, top_k))[:top_k]
+        return {"items": items, "query": query, "mode": mode}


@@ -0,0 +1,23 @@
from __future__ import annotations
from .common import VirtualFSCommonMixin
from .resolver import VirtualFSResolverMixin
from .listing import VirtualFSListingMixin
from .file_ops import VirtualFSFileOpsMixin
from .transfer import VirtualFSTransferMixin
from .processing import VirtualFSProcessingMixin
from .temp_link import VirtualFSTempLinkMixin
from .routes import VirtualFSRouteMixin


class VirtualFSService(
    VirtualFSRouteMixin,
    VirtualFSTempLinkMixin,
    VirtualFSProcessingMixin,
    VirtualFSTransferMixin,
    VirtualFSFileOpsMixin,
    VirtualFSListingMixin,
    VirtualFSResolverMixin,
    VirtualFSCommonMixin,
):
    pass
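Composing the service from mixins like this relies on Python's C3 method resolution order: names are looked up left to right across the base list, so `VirtualFSRouteMixin` wins any collision. A toy illustration (hypothetical class names, not the repo's):

```python
class Common:
    def who(self):
        return "common"


class Routes(Common):
    def who(self):
        return "routes"


class Listing(Common):
    pass


class Service(Routes, Listing):
    pass


# C3 linearization: Service -> Routes -> Listing -> Common -> object
names = [c.__name__ for c in Service.__mro__]
winner = Service().who()
# winner == "routes": the leftmost base providing `who` is used
```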

Some files were not shown because too many files have changed in this diff.