Mirror of https://github.com/DrizzleTime/Foxel.git (synced 2026-05-08 21:03:18 +08:00)

Compare commits
60 Commits

| SHA1 |
|---|
| e7cf8dbdb8 |
| e7eafdee97 |
| 051b49d3f6 |
| b059b0eb44 |
| 59ad2cb622 |
| 6b2ada0b42 |
| a727e77341 |
| 4638356a45 |
| e51344b43e |
| b7685db0e8 |
| 4e16de973c |
| 4dd0a4b1d6 |
| 5703825c31 |
| 24255744df |
| 31d97b2968 |
| 35abd080be |
| 2fa93a1eeb |
| ff7eb13187 |
| ed9090c3d0 |
| d430254868 |
| a8870f80da |
| 14ef2a4ccc |
| dd41941b04 |
| 01a259bae0 |
| ef5ef2730c |
| 8b8772b064 |
| 5393a973eb |
| cc1f130099 |
| c8b3817805 |
| b1ea181f96 |
| 078709b871 |
| d788bde44f |
| 28ede26801 |
| 53130383c1 |
| 036eeb92c2 |
| 5701a13f4f |
| 184997deed |
| 1d5824d498 |
| 91ff1860b7 |
| 56f947d0bf |
| ad016baaf9 |
| ad2e2858da |
| a69d6c21a6 |
| 2a4a3c44b9 |
| cdb8543370 |
| 2dabe9255f |
| 239216e574 |
| 09c65bffb7 |
| ff1c06ad18 |
| d88e95a9af |
| ae80a751a8 |
| b40e700a64 |
| 040d8346b3 |
| 55d062f0a7 |
| cfaaff8a8c |
| d6d41333fd |
| a4efba94d5 |
| 00e6419b12 |
| bbe8465aa0 |
| baadaa70a7 |
.github/FUNDING.yml (vendored, 2 changes)

```diff
@@ -1 +1 @@
-custom: https://foxel.cc/sponsor.html
+custom: https://foxel.cc/sponsor
```
.github/dependabot.yml (vendored, 6 changes)

```diff
@@ -3,14 +3,14 @@ updates:
   - package-ecosystem: "github-actions"
     directory: "/"
     schedule:
-      interval: "weekly"
+      interval: "monthly"

   - package-ecosystem: "bun"
     directory: "/web"
     schedule:
-      interval: "weekly"
+      interval: "monthly"

   - package-ecosystem: "uv"
     directory: "/"
     schedule:
-      interval: "weekly"
+      interval: "monthly"
```
Python version file (file name not captured in this view, 2 changes)

```diff
@@ -1 +1 @@
-3.13
+3.14
```
Dockerfile (2 changes)

```diff
@@ -9,7 +9,7 @@ COPY web/ ./

 RUN bun run build

-FROM python:3.13-slim
+FROM python:3.14-slim

 WORKDIR /app
```
README.md (46 changes)

````diff
@@ -8,16 +8,17 @@

 **A highly extensible private cloud storage solution for individuals and teams, featuring AI-powered semantic search.**

 
 
 
 

+
 ---
 <blockquote>
 <em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
 </blockquote>
-<img src="https://foxel.cc/image/ad-min.png" alt="UI Screenshot">
+<img src="https://foxel.cc/image/ad-min-en.png" alt="UI Screenshot">
 </div>

 ## 👀 Online Demo
@@ -39,36 +40,37 @@

 Using Docker Compose is the most recommended way to start Foxel.

-1. **Create Data Directories**:
-   Create a `data` folder for persistent data:
+1. **Create Data Directories**
+
+   Create a `data` folder for persistent data:

    ```bash
    mkdir -p data/db
    mkdir -p data/mount
    chmod 777 data/db data/mount
    ```

-2. **Download Docker Compose File**:
+2. **Download Docker Compose File**
+
    ```bash
    curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
    ```

    After downloading, it is **strongly recommended** to modify the environment variables in the `compose.yaml` file to ensure security:

    - Modify `SECRET_KEY` and `TEMP_LINK_SECRET_KEY`: Replace the default keys with randomly generated strong keys.

-3. **Start the Services**:
+3. **Start the Services**
+
    ```bash
    docker-compose up -d
    ```

-4. **Access the Application**:
+4. **Access the Application**
+
    Once the services are running, open the page in your browser.

    > On the first launch, please follow the setup guide to initialize the administrator account.

 ## 🤝 How to Contribute
````
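The README asks for randomly generated strong keys but does not show a generator. A minimal sketch of one option (not a project-mandated method; any source of 32 random bytes works), matching the base64 shape of the placeholder shipped in compose.yaml:

```python
import base64
import secrets

# Print one random 32-byte key, base64-encoded like the shipped placeholder.
# Run once per key to produce replacements for SECRET_KEY and TEMP_LINK_SECRET_KEY.
print(base64.b64encode(secrets.token_bytes(32)).decode("ascii"))
```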
README_zh.md (47 changes)

````diff
@@ -8,17 +8,17 @@

 **一个面向个人和团队的、高度可扩展的私有云盘解决方案,支持 AI 语义搜索。**

 
 
 
 

 

 ---
 <blockquote>
 <em><strong>数据之洋浩瀚无涯,当以洞察之目引航,然其脉络深隐,非表象所能尽窥。</strong></em><br>
 <em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
 </blockquote>
-<img src="https://foxel.cc/image/ad-min.png" alt="UI Screenshot">
+<img src="https://foxel.cc/image/ad-min-zh.png" alt="UI Screenshot">
 </div>

 ## 👀 在线体验
@@ -40,36 +40,37 @@

 使用 Docker Compose 是启动 Foxel 最推荐的方式。

-1. **创建数据目录**:
-   新建 `data` 文件夹用于持久化数据:
+1. **创建数据目录**
+
+   新建 `data` 文件夹用于持久化数据:

    ```bash
    mkdir -p data/db
    mkdir -p data/mount
    chmod 777 data/db data/mount
    ```

-2. **下载 Docker Compose 文件**:
+2. **下载 Docker Compose 文件**
+
    ```bash
    curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
    ```

    下载完成后,**强烈建议**修改 `compose.yaml` 文件中的环境变量以确保安全:

    - 修改 `SECRET_KEY` 和 `TEMP_LINK_SECRET_KEY`:将默认的密钥替换为随机生成的强密钥

-3. **启动服务**:
+3. **启动服务**
+
    ```bash
    docker-compose up -d
    ```

-4. **访问应用**:
+4. **访问应用**
+
    服务启动后,在浏览器中打开页面。

    > 首次启动,请根据引导页面完成管理员账号的初始化设置。

 ## 🤝 如何贡献
````
Router registration module (file name not captured in this view)

```diff
@@ -11,16 +11,17 @@ from domain.processors import api as processors
 from domain.share import api as share
 from domain.tasks import api as tasks
 from domain.ai import api as ai
+from domain.agent import api as agent
 from domain.virtual_fs import api as virtual_fs
 from domain.virtual_fs.mapping import s3_api, webdav_api
 from domain.virtual_fs.search import search_api
-from domain.audit import router as audit
+from domain.audit import api as audit


 def include_routers(app: FastAPI):
     app.include_router(adapters.router)
     app.include_router(virtual_fs.router)
     app.include_router(search_api.router)
     app.include_router(auth.router)
     app.include_router(config.router)
     app.include_router(processors.router)
@@ -30,9 +31,10 @@ def include_routers(app: FastAPI):
     app.include_router(backup.router)
     app.include_router(ai.router_vector_db)
     app.include_router(ai.router_ai)
+    app.include_router(agent.router)
     app.include_router(plugins.router)
     app.include_router(webdav_api.router)
     app.include_router(s3_api.router)
     app.include_router(offline_downloads.router)
     app.include_router(email.router)
-    app.include_router(audit)
+    app.include_router(audit.router)
```
compose.yaml

```diff
@@ -5,9 +5,10 @@ services:
     container_name: foxel
     restart: unless-stopped
     ports:
-      - "8088:80"
+      - "${FOXEL_HOST_PORT:-8088}:${FOXEL_PORT:-80}"
     environment:
       - TZ=Asia/Shanghai
+      - FOXEL_PORT=${FOXEL_PORT:-80}
      - SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
       - TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
     volumes:
```
Database/ORM init module (file name not captured in this view)

```diff
@@ -1,6 +1,6 @@
 from tortoise import Tortoise

-from domain.adapters.registry import runtime_registry
+from domain.adapters import runtime_registry

 TORTOISE_ORM = {
     "connections": {"default": "sqlite://data/db/db.sqlite3"},
@@ -12,22 +12,7 @@ TORTOISE_ORM = {
     },
 }


-def patch_aiosqlite_for_tortoise() -> None:
-    import aiosqlite
-
-    if hasattr(aiosqlite.Connection, "start"):
-        return
-
-    def start(self) -> None:  # type: ignore[no-redef]
-        if not self._thread.is_alive():
-            self._thread.start()
-
-    aiosqlite.Connection.start = start  # type: ignore[attr-defined]
-
-
 async def init_db():
-    patch_aiosqlite_for_tortoise()
     await Tortoise.init(config=TORTOISE_ORM)
     await Tortoise.generate_schemas()
     await runtime_registry.refresh()
```
domain/__init__.py (new file, 7 lines)

```python
"""
domain: the business-domain layer.

Convention: cross-package code imports only the public API exported from each
subpackage's `__init__.py`.
"""

__all__: list[str] = []
```
domain/adapters/__init__.py

```diff
@@ -1 +1,24 @@
 from .providers import BaseAdapter
+from .registry import (
+    RuntimeRegistry,
+    discover_adapters,
+    get_config_schema,
+    get_config_schemas,
+    normalize_adapter_type,
+    runtime_registry,
+)
+from .service import AdapterService
+from .types import AdapterCreate, AdapterOut
+
+__all__ = [
+    "BaseAdapter",
+    "RuntimeRegistry",
+    "discover_adapters",
+    "get_config_schema",
+    "get_config_schemas",
+    "normalize_adapter_type",
+    "runtime_registry",
+    "AdapterService",
+    "AdapterCreate",
+    "AdapterOut",
+]
```
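With these re-exports in place, call sites can follow the convention stated in domain/__init__.py and import from the package root. A hypothetical call site:

```python
# Hypothetical call site: import the public API from the package root
# rather than reaching into domain.adapters.registry or .service directly.
from domain.adapters import AdapterService, get_config_schemas, runtime_registry
```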
domain/adapters/api.py

```diff
@@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, Request

 from api.response import success
 from domain.audit import AuditAction, audit
-from domain.adapters.service import AdapterService
-from domain.adapters.types import AdapterCreate
-from domain.auth.service import get_current_active_user
-from domain.auth.types import User
+from domain.auth import User, get_current_active_user
+from .service import AdapterService
+from .types import AdapterCreate

 router = APIRouter(prefix="/api/adapters", tags=["adapters"])
```
domain/adapters/providers/foxel.py (new file, 411 lines)

```python
import asyncio
import mimetypes
import re
import tempfile
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Tuple
from urllib.parse import quote

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse

from models import StorageAdapter


def _normalize_fs_path(path: str) -> str:
    path = (path or "").replace("\\", "/").strip()
    if not path or path == "/":
        return "/"
    if not path.startswith("/"):
        path = "/" + path
    path = re.sub(r"/{2,}", "/", path)
    if path != "/" and path.endswith("/"):
        path = path.rstrip("/")
    return path or "/"


def _join_fs_path(base: str, rel: str | None) -> str:
    base = _normalize_fs_path(base)
    rel_norm = (rel or "").replace("\\", "/").strip().lstrip("/")
    if not rel_norm:
        return base
    if base == "/":
        return "/" + rel_norm
    return f"{base}/{rel_norm}"


def _unwrap_success(payload: Any, *, context: str) -> Any:
    if not isinstance(payload, dict):
        return payload
    if "data" not in payload:
        return payload
    code = payload.get("code")
    if code not in (None, 0, 200):
        msg = payload.get("msg") or payload.get("message") or ""
        raise HTTPException(502, detail=f"Foxel 上游错误({context}): {msg}")
    return payload.get("data")


class FoxelAdapter:
    def __init__(self, record: StorageAdapter):
        self.record = record
        cfg = record.config or {}

        self.base_url: str = str(cfg.get("base_url", "")).rstrip("/")
        if not self.base_url.startswith("http"):
            raise ValueError("foxel requires base_url http/https")

        self.username: str = str(cfg.get("username") or "")
        self.password: str = str(cfg.get("password") or "")
        if not self.username or not self.password:
            raise ValueError("foxel requires username and password")

        self.timeout: float = float(cfg.get("timeout", 15))
        self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))

        self._token: str | None = None
        self._login_lock = asyncio.Lock()

    def get_effective_root(self, sub_path: str | None) -> str:
        base = _normalize_fs_path(self.root_path)
        if sub_path:
            return _join_fs_path(base, sub_path)
        return base

    async def _login(self) -> str:
        url = self.base_url + "/api/auth/login"
        body = {"username": self.username, "password": self.password}
        async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
            resp = await client.post(url, data=body)
            resp.raise_for_status()
            payload = resp.json()
        if not isinstance(payload, dict):
            raise HTTPException(502, detail="Foxel 登录响应异常")
        token = payload.get("access_token")
        if not token:
            raise HTTPException(502, detail="Foxel 登录失败: 缺少 access_token")
        return str(token)

    async def _ensure_token(self) -> str:
        if self._token:
            return self._token
        async with self._login_lock:
            if self._token:
                return self._token
            self._token = await self._login()
            return self._token

    async def _request_json(self, method: str, path: str, *, params: dict | None = None, json: Any = None) -> Any:
        url = self.base_url + path
        for attempt in range(2):
            token = await self._ensure_token()
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                resp = await client.request(method, url, headers=headers, params=params, json=json)
            if resp.status_code == 401 and attempt == 0:
                self._token = None
                continue
            resp.raise_for_status()
            return resp.json()
        raise HTTPException(502, detail="Foxel 上游请求失败")

    @staticmethod
    def _encode_path(full_path: str) -> str:
        return quote(full_path.lstrip("/"), safe="/")

    def _browse_path(self, full_path: str) -> str:
        full_path = _normalize_fs_path(full_path)
        if full_path == "/":
            return "/api/fs/"
        return "/api/fs/" + self._encode_path(full_path)

    def _stat_path(self, full_path: str) -> str:
        full_path = _normalize_fs_path(full_path)
        if full_path == "/":
            return "/api/fs/stat/"
        return "/api/fs/stat/" + self._encode_path(full_path)

    def _file_path(self, full_path: str) -> str:
        full_path = _normalize_fs_path(full_path)
        if full_path == "/":
            return "/api/fs/file/"
        return "/api/fs/file/" + self._encode_path(full_path)

    def _stream_path(self, full_path: str) -> str:
        full_path = _normalize_fs_path(full_path)
        if full_path == "/":
            return "/api/fs/stream/"
        return "/api/fs/stream/" + self._encode_path(full_path)

    async def list_dir(
        self,
        root: str,
        rel: str,
        page_num: int = 1,
        page_size: int = 50,
        sort_by: str = "name",
        sort_order: str = "asc",
    ) -> Tuple[List[Dict], int]:
        rel = (rel or "").strip("/")
        full_path = _join_fs_path(root, rel)
        payload = await self._request_json(
            "GET",
            self._browse_path(full_path),
            params={
                "page": page_num,
                "page_size": page_size,
                "sort_by": sort_by,
                "sort_order": sort_order,
            },
        )
        data = _unwrap_success(payload, context="list_dir")
        if not isinstance(data, dict):
            raise HTTPException(502, detail="Foxel 浏览响应异常")
        entries = data.get("entries") or []
        pagination = data.get("pagination") or {}
        total = pagination.get("total")
        try:
            total_int = int(total) if total is not None else len(entries)
        except Exception:
            total_int = len(entries)
        if not isinstance(entries, list):
            entries = []
        return entries, total_int

    async def stat_file(self, root: str, rel: str):
        rel = (rel or "").strip("/")
        full_path = _join_fs_path(root, rel)
        payload = await self._request_json("GET", self._stat_path(full_path))
        data = _unwrap_success(payload, context="stat_file")
        if not isinstance(data, dict):
            raise HTTPException(502, detail="Foxel stat 响应异常")
        return data

    async def exists(self, root: str, rel: str) -> bool:
        rel = (rel or "").strip("/")
        full_path = _join_fs_path(root, rel)
        url = self.base_url + self._stat_path(full_path)
        for attempt in range(2):
            token = await self._ensure_token()
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                resp = await client.get(url, headers=headers)
            if resp.status_code == 401 and attempt == 0:
                self._token = None
                continue
            return resp.status_code == 200
        return False

    async def read_file(self, root: str, rel: str) -> bytes:
        rel = (rel or "").lstrip("/")
        full_path = _join_fs_path(root, rel)
        url = self.base_url + self._file_path(full_path)
        for attempt in range(2):
            token = await self._ensure_token()
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                resp = await client.get(url, headers=headers)
            if resp.status_code == 401 and attempt == 0:
                self._token = None
                continue
            if resp.status_code == 404:
                raise FileNotFoundError(rel)
            resp.raise_for_status()
            return resp.content
        raise HTTPException(502, detail="Foxel 读取失败")

    async def _upload_file_path(self, full_path: str, file_path: Path) -> None:
        url = self.base_url + self._file_path(full_path)
        filename = Path(full_path).name or file_path.name
        for attempt in range(2):
            token = await self._ensure_token()
            headers = {"Authorization": f"Bearer {token}"}
            with file_path.open("rb") as f:
                files = {"file": (filename, f, "application/octet-stream")}
                async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                    resp = await client.post(url, headers=headers, files=files)
            if resp.status_code == 401 and attempt == 0:
                self._token = None
                continue
            resp.raise_for_status()
            return
        raise HTTPException(502, detail="Foxel 上传失败")

    async def write_file(self, root: str, rel: str, data: bytes):
        rel = (rel or "").lstrip("/")
        full_path = _join_fs_path(root, rel)
        url = self.base_url + self._file_path(full_path)
        filename = Path(rel).name or "file"
        for attempt in range(2):
            token = await self._ensure_token()
            headers = {"Authorization": f"Bearer {token}"}
            files = {"file": (filename, data, "application/octet-stream")}
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                resp = await client.post(url, headers=headers, files=files)
            if resp.status_code == 401 and attempt == 0:
                self._token = None
                continue
            resp.raise_for_status()
            return True
        raise HTTPException(502, detail="Foxel 写入失败")

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        rel = (rel or "").lstrip("/")
        full_path = _join_fs_path(root, rel)
        suffix = Path(rel).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
            tmp_path = Path(tf.name)

        size = 0
        try:
            with tmp_path.open("wb") as f:
                async for chunk in data_iter:
                    if not chunk:
                        continue
                    f.write(chunk)
                    size += len(chunk)
            await self._upload_file_path(full_path, tmp_path)
            return size
        finally:
            try:
                tmp_path.unlink(missing_ok=True)
            except Exception:
                pass

    async def mkdir(self, root: str, rel: str):
        rel = (rel or "").strip("/")
        full_path = _join_fs_path(root, rel)
        payload = await self._request_json("POST", "/api/fs/mkdir", json={"path": full_path})
        _unwrap_success(payload, context="mkdir")
        return True

    async def delete(self, root: str, rel: str):
        rel = (rel or "").strip("/")
        full_path = _join_fs_path(root, rel)
        url = self.base_url + self._browse_path(full_path)
        for attempt in range(2):
            token = await self._ensure_token()
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                resp = await client.delete(url, headers=headers)
            if resp.status_code == 401 and attempt == 0:
                self._token = None
                continue
            if resp.status_code == 404:
                return
            resp.raise_for_status()
            return
        raise HTTPException(502, detail="Foxel 删除失败")

    async def move(self, root: str, src_rel: str, dst_rel: str):
        src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
        dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
        payload = await self._request_json("POST", "/api/fs/move", json={"src": src_path, "dst": dst_path})
        _unwrap_success(payload, context="move")
        return True

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
        dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
        payload = await self._request_json("POST", "/api/fs/rename", json={"src": src_path, "dst": dst_path})
        _unwrap_success(payload, context="rename")
        return True

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
        dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
        payload = await self._request_json(
            "POST",
            "/api/fs/copy",
            json={"src": src_path, "dst": dst_path},
            params={"overwrite": overwrite},
        )
        _unwrap_success(payload, context="copy")
        return True

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        rel = (rel or "").lstrip("/")
        full_path = _join_fs_path(root, rel)
        url = self.base_url + self._stream_path(full_path)

        headers = {}
        if range_header:
            headers["Range"] = range_header

        for attempt in range(2):
            token = await self._ensure_token()
            headers["Authorization"] = f"Bearer {token}"
            client = httpx.AsyncClient(timeout=None, follow_redirects=True)
            stream_cm = client.stream("GET", url, headers=headers)
            try:
                resp = await stream_cm.__aenter__()
            except Exception:
                await client.aclose()
                raise

            if resp.status_code == 401 and attempt == 0:
                try:
                    await resp.aread()
                finally:
                    await stream_cm.__aexit__(None, None, None)
                    await client.aclose()
                self._token = None
                continue

            if resp.status_code == 404:
                try:
                    await resp.aread()
                finally:
                    await stream_cm.__aexit__(None, None, None)
                    await client.aclose()
                raise FileNotFoundError(rel)

            if resp.status_code >= 400:
                try:
                    await resp.aread()
                finally:
                    await stream_cm.__aexit__(None, None, None)
                    await client.aclose()
                resp.raise_for_status()

            content_type = resp.headers.get("Content-Type") or (
                mimetypes.guess_type(rel)[0] or "application/octet-stream"
            )
            out_headers = {}
            for key in ("Accept-Ranges", "Content-Range", "Content-Length"):
                value = resp.headers.get(key)
                if value:
                    out_headers[key] = value

            async def iterator():
                try:
                    async for chunk in resp.aiter_bytes():
                        if chunk:
                            yield chunk
                finally:
                    await stream_cm.__aexit__(None, None, None)
                    await client.aclose()

            return StreamingResponse(
                iterator(),
                status_code=resp.status_code,
                headers=out_headers,
                media_type=content_type,
            )

        raise HTTPException(502, detail="Foxel 流式读取失败")


ADAPTER_TYPE = "foxel"
CONFIG_SCHEMA = [
    {"key": "base_url", "label": "节点地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:8000"},
    {"key": "username", "label": "用户名", "type": "string", "required": True},
    {"key": "password", "label": "密码", "type": "password", "required": True},
    {"key": "root", "label": "远端根目录", "type": "string", "required": False, "default": "/", "placeholder": "/ 或 /drive"},
    {"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 60},
]


def ADAPTER_FACTORY(rec: StorageAdapter):
    return FoxelAdapter(rec)
```
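Every network method above repeats the same two-attempt shape: use the cached token, and on a single 401 drop it, re-login, and retry once. A standalone sketch of that pattern, for illustration only (`login` is a hypothetical coroutine standing in for FoxelAdapter._login):

```python
import httpx

async def request_with_relogin(login, method: str, url: str, timeout: float = 15.0):
    # Retry-once-on-401: drop the cached token exactly once and re-login.
    token: str | None = None
    for attempt in range(2):
        if token is None:
            token = await login()  # hypothetical coroutine returning a bearer token
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
            resp = await client.request(method, url, headers={"Authorization": f"Bearer {token}"})
        if resp.status_code == 401 and attempt == 0:
            token = None  # force a fresh login on the second attempt
            continue
        resp.raise_for_status()
        return resp.json()
    raise RuntimeError("request failed even after re-login")
```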
domain/adapters/providers/telegram.py

```diff
@@ -1,11 +1,26 @@
 from typing import List, Dict, Tuple, AsyncIterator
+import asyncio
+import base64
 import io
 import os
+import struct
 from models import StorageAdapter
 from telethon import TelegramClient
+from telethon.crypto import AuthKey
 from telethon.sessions import StringSession
 from telethon.tl import types
 import socks

+_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
+
+
+def _get_session_lock(session_string: str) -> asyncio.Lock:
+    lock = _SESSION_LOCKS.get(session_string)
+    if lock is None:
+        lock = asyncio.Lock()
+        _SESSION_LOCKS[session_string] = lock
+    return lock
+
 # Adapter type identifier
 ADAPTER_TYPE = "telegram"
@@ -54,9 +69,93 @@ class TelegramAdapter:
         if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
             raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")

+    @staticmethod
+    def _parse_legacy_session_string(value: str) -> StringSession:
+        """
+        Compatibility with the legacy session_string format:
+        - version (1-byte char) + base64(data)
+        - data: dc_id (1B) + ip_len (2B) + ip (ASCII, ip_len bytes) + port (2B) + auth_key (256B)
+        """
+        s = (value or "").strip()
+        if not s:
+            raise ValueError("session_string 为空")
+
+        body = s[1:] if s.startswith("1") else s
+        raw = base64.urlsafe_b64decode(body)
+        if len(raw) < 1 + 2 + 2 + 256:
+            raise ValueError("legacy session 数据长度不足")
+
+        dc_id = raw[0]
+        ip_len = struct.unpack(">H", raw[1:3])[0]
+        expected_len = 1 + 2 + ip_len + 2 + 256
+        if len(raw) != expected_len:
+            raise ValueError("legacy session 数据长度不匹配")
+
+        ip_start = 3
+        ip_end = ip_start + ip_len
+        ip = raw[ip_start:ip_end].decode("utf-8")
+        port = struct.unpack(">H", raw[ip_end : ip_end + 2])[0]
+        key = raw[ip_end + 2 : ip_end + 2 + 256]
+
+        sess = StringSession()
+        sess.set_dc(dc_id, ip, port)
+        sess.auth_key = AuthKey(key)
+        return sess
+
+    @staticmethod
+    def _pick_photo_thumb(thumbs: list | None):
+        if not thumbs:
+            return None
+
+        cached = []
+        others = []
+        for t in thumbs:
+            if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize)):
+                cached.append(t)
+            elif isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive)):
+                if not isinstance(t, types.PhotoSizeEmpty):
+                    others.append(t)
+
+        if cached:
+            cached.sort(key=lambda x: len(getattr(x, "bytes", b"") or b""))
+            return cached[-1]
+
+        if others:
+            def _sz(x):
+                if isinstance(x, types.PhotoSizeProgressive):
+                    return max(x.sizes or [0])
+                return int(getattr(x, "size", 0) or 0)
+
+            others.sort(key=_sz)
+            return others[-1]
+
+        return None
+
+    def _build_session(self) -> StringSession:
+        s = (self.session_string or "").strip()
+        if not s:
+            raise ValueError("Telegram 适配器 session_string 为空")
+
+        try:
+            return StringSession(s)
+        except Exception:
+            pass
+
+        # Some tools strip the version prefix; try once with it restored
+        if not s.startswith("1"):
+            try:
+                return StringSession("1" + s)
+            except Exception:
+                pass
+
+        try:
+            return self._parse_legacy_session_string(s)
+        except Exception as exc:
+            raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc
+
     def _get_client(self) -> TelegramClient:
         """Create a new TelegramClient instance."""
-        return TelegramClient(StringSession(self.session_string), self.api_id, self.api_hash, proxy=self.proxy)
+        return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)

     def get_effective_root(self, sub_path: str | None) -> str:
         return ""
@@ -198,6 +297,41 @@ class TelegramAdapter:
     async def mkdir(self, root: str, rel: str):
         raise NotImplementedError("Telegram 适配器不支持创建目录。")

+    async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
+        try:
+            message_id_str, _ = rel.split('_', 1)
+            message_id = int(message_id_str)
+        except (ValueError, IndexError):
+            return None
+
+        client = self._get_client()
+        try:
+            await client.connect()
+            message = await client.get_messages(self.chat_id, ids=message_id)
+            if not message:
+                return None
+
+            doc = message.document or message.video
+            thumbs = None
+            if doc and getattr(doc, "thumbs", None):
+                thumbs = list(doc.thumbs or [])
+            elif message.photo and getattr(message.photo, "sizes", None):
+                thumbs = list(message.photo.sizes or [])
+
+            thumb = self._pick_photo_thumb(thumbs)
+            if not thumb:
+                return None
+
+            result = await client.download_media(message, bytes, thumb=thumb)
+            if isinstance(result, (bytes, bytearray)):
+                return bytes(result)
+            return None
+        except Exception:
+            return None
+        finally:
+            if client.is_connected():
+                await client.disconnect()
+
     async def delete(self, root: str, rel: str):
         """Delete a file (i.e. a single message)."""
         try:
@@ -236,6 +370,8 @@ class TelegramAdapter:
             raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")

         client = self._get_client()
+        lock = _get_session_lock(self.session_string)
+        await lock.acquire()

         try:
             await client.connect()
@@ -273,7 +409,6 @@ class TelegramAdapter:
             headers = {
                 "Accept-Ranges": "bytes",
                 "Content-Type": mime_type,
                 "Content-Length": str(file_size),
             }

             if range_header:
@@ -285,7 +420,6 @@ class TelegramAdapter:
                 if start >= file_size or end >= file_size or start > end:
                     raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
                 status = 206
                 headers["Content-Length"] = str(end - start + 1)
                 headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
             except ValueError:
                 raise HTTPException(status_code=400, detail="Invalid Range header")
@@ -304,18 +438,28 @@ class TelegramAdapter:
                         if downloaded >= limit:
                             break
                 finally:
-                    if client.is_connected():
-                        await client.disconnect()
+                    try:
+                        if client.is_connected():
+                            await client.disconnect()
+                    finally:
+                        lock.release()

             return StreamingResponse(iterator(), status_code=status, headers=headers)

         except HTTPException:
             if client.is_connected():
                 await client.disconnect()
+            lock.release()
             raise
         except FileNotFoundError as e:
             if client.is_connected():
                 await client.disconnect()
+            lock.release()
             raise HTTPException(status_code=404, detail=str(e))
         except Exception as e:
             if client.is_connected():
                 await client.disconnect()
+            lock.release()
             raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")

     async def stat_file(self, root: str, rel: str):
```
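For reference, the legacy layout that `_parse_legacy_session_string` accepts can be reproduced in a few lines. A sketch with placeholder values (the DC id, IP, port, and zeroed auth key below are made up purely for illustration):

```python
import base64
import struct

# Legacy layout: version "1" + urlsafe_b64(dc_id(1B) + ip_len(2B) + ip + port(2B) + auth_key(256B))
dc_id, ip, port = 2, b"149.154.167.51", 443   # illustrative values
auth_key = bytes(256)                          # placeholder; real keys are 256 opaque bytes

data = struct.pack(">B", dc_id) + struct.pack(">H", len(ip)) + ip + struct.pack(">H", port) + auth_key
legacy_session = "1" + base64.urlsafe_b64encode(data).decode("ascii")
# TelegramAdapter._parse_legacy_session_string(legacy_session) would accept this shape.
```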
domain/adapters/registry.py

```diff
@@ -4,7 +4,7 @@ from importlib import import_module
 from typing import Callable, Dict

 from models import StorageAdapter
-from domain.adapters.providers.base import BaseAdapter
+from .providers.base import BaseAdapter

 AdapterFactory = Callable[[StorageAdapter], BaseAdapter]

@@ -21,7 +21,7 @@ def normalize_adapter_type(value: str | None) -> str | None:

 def discover_adapters():
     """Scan the domain.adapters.providers package and auto-register adapter types, factories, and config schemas."""
-    from domain.adapters import providers as adapters_pkg
+    from . import providers as adapters_pkg

     TYPE_MAP.clear()
     CONFIG_SCHEMAS.clear()
```
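The body of `discover_adapters` is not part of this hunk. A minimal sketch of the scan its docstring describes, assuming each provider module exposes the ADAPTER_TYPE / ADAPTER_FACTORY contract that foxel.py above follows (this is an illustration, not the project's exact implementation):

```python
import pkgutil
from importlib import import_module

def discover(providers_pkg) -> dict[str, object]:
    # Walk the providers package and register every module exposing
    # ADAPTER_TYPE + ADAPTER_FACTORY (the contract foxel.py follows).
    registry: dict[str, object] = {}
    for info in pkgutil.iter_modules(providers_pkg.__path__):
        module = import_module(f"{providers_pkg.__name__}.{info.name}")
        adapter_type = getattr(module, "ADAPTER_TYPE", None)
        factory = getattr(module, "ADAPTER_FACTORY", None)
        if adapter_type and factory:
            registry[adapter_type] = factory
    return registry
```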
domain/adapters/service.py

```diff
@@ -2,13 +2,13 @@ from typing import Optional

 from fastapi import HTTPException

-from domain.adapters.registry import (
+from domain.auth import User
+from .registry import (
     get_config_schemas,
     normalize_adapter_type,
     runtime_registry,
 )
-from domain.adapters.types import AdapterCreate, AdapterOut
-from domain.auth.types import User
+from .types import AdapterCreate, AdapterOut
 from models import StorageAdapter
```
domain/agent/__init__.py (new file, 9 lines)

```python
from .service import AgentService
from .types import AgentChatContext, AgentChatRequest, PendingToolCall

__all__ = [
    "AgentService",
    "AgentChatContext",
    "AgentChatRequest",
    "PendingToolCall",
]
```
domain/agent/api.py (new file, 38 lines)

```python
from typing import Annotated

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth import User, get_current_active_user
from .service import AgentService
from .types import AgentChatRequest


router = APIRouter(prefix="/api/agent", tags=["agent"])


@router.post("/chat")
@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"])
async def chat(
    request: Request,
    payload: AgentChatRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = await AgentService.chat(payload, current_user)
    return success(data)


@router.post("/chat/stream")
@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"])
async def chat_stream(
    request: Request,
    payload: AgentChatRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    return StreamingResponse(
        AgentService.chat_stream(payload, current_user),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )
```
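A sketch of a client for the streaming endpoint above. The Bearer header is an assumption about how `get_current_active_user` authenticates; the event names (`assistant_delta`, `tool_start`, `done`, ...) come from the `_sse` calls in service.py below:

```python
import json
import httpx

async def consume_agent_stream(base_url: str, token: str, payload: dict) -> None:
    # Read event/data line pairs from the text/event-stream response.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            f"{base_url}/api/agent/chat/stream",
            headers={"Authorization": f"Bearer {token}"},  # assumed auth scheme
            json=payload,
        ) as resp:
            event = None
            async for line in resp.aiter_lines():
                if line.startswith("event: "):
                    event = line[len("event: "):]
                elif line.startswith("data: "):
                    print(event, json.loads(line[len("data: "):]))
```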
domain/agent/service.py (new file, 470 lines)

```python
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple

import httpx
from fastapi import HTTPException

from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
from domain.auth import User
from .tools import get_tool, openai_tools, tool_result_to_content
from .types import AgentChatRequest, PendingToolCall


def _normalize_path(p: Optional[str]) -> Optional[str]:
    if not p:
        return None
    s = str(p).strip()
    if not s:
        return None
    s = s.replace("\\", "/")
    if not s.startswith("/"):
        s = "/" + s
    s = s.rstrip("/") or "/"
    return s


def _build_system_prompt(current_path: Optional[str]) -> str:
    lines = [
        "你是 Foxel 的 AI 助手。",
        "你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。",
        "",
        "可用工具:",
        "- vfs_list_dir:浏览目录(列出 entries + pagination)。",
        "- vfs_stat:查看文件/目录信息。",
        "- vfs_read_text:读取文本文件内容(不支持二进制)。",
        "- vfs_search:搜索文件(vector/filename)。",
        "- vfs_write_text:写入文本文件内容(覆盖)。",
        "- vfs_mkdir:创建目录。",
        "- vfs_delete:删除文件或目录。",
        "- vfs_move:移动路径。",
        "- vfs_copy:复制路径。",
        "- vfs_rename:重命名路径。",
        "- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。",
        "- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。",
        "",
        "规则:",
        "1) 读操作(vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。",
        "2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。",
        "3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
        "4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。",
        "5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
        "6) 回答保持简洁中文。",
    ]
    if current_path:
        lines.append("")
        lines.append(f"当前文件管理目录:{current_path}")
    return "\n".join(lines)


def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
    tool_calls = message.get("tool_calls")
    if not isinstance(tool_calls, list):
        return message

    changed = False
    for idx, call in enumerate(tool_calls):
        if not isinstance(call, dict):
            continue
        call_id = call.get("id")
        if isinstance(call_id, str) and call_id.strip():
            continue
        call["id"] = f"call_{idx}"
        changed = True

    if changed:
        message["tool_calls"] = tool_calls
    return message


def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall:
    call_id = str(tool_call.get("id") or "")
    fn = tool_call.get("function") or {}
    name = str((fn.get("name") if isinstance(fn, dict) else None) or "")
    raw_args = fn.get("arguments") if isinstance(fn, dict) else None
    arguments: Dict[str, Any] = {}
    if isinstance(raw_args, str) and raw_args.strip():
        try:
            parsed = json.loads(raw_args)
            if isinstance(parsed, dict):
                arguments = parsed
        except json.JSONDecodeError:
            arguments = {}
    return PendingToolCall(
        id=call_id,
        name=name,
        arguments=arguments,
        requires_confirmation=requires_confirmation,
    )


def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
    for idx in range(len(messages) - 1, -1, -1):
        msg = messages[idx]
        if not isinstance(msg, dict):
            continue
        if msg.get("role") != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if isinstance(tool_calls, list) and tool_calls:
            return idx, msg
    raise HTTPException(status_code=400, detail="没有可确认的待执行操作")


def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
    ids: set[str] = set()
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        if msg.get("role") != "tool":
            continue
        tool_call_id = msg.get("tool_call_id")
        if isinstance(tool_call_id, str) and tool_call_id.strip():
            ids.add(tool_call_id)
    return ids


async def _choose_chat_ability() -> str:
    tools_model = await AIProviderService.get_default_model("tools")
    return "tools" if tools_model else "chat"


def _sse(event: str, data: Any) -> bytes:
    payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
    return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")


def _format_exc(exc: BaseException) -> str:
    text = str(exc)
    return text if text else exc.__class__.__name__


class AgentService:
    @classmethod
    async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)

        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history

        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingToolCall] = []

        approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
        rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}

        if approved_ids or rejected_ids:
            _, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
            last_call_msg = _ensure_tool_call_ids(last_call_msg)
            tool_calls = last_call_msg.get("tool_calls") or []
            call_map: Dict[str, Dict[str, Any]] = {
                str(c.get("id")): c
                for c in tool_calls
                if isinstance(c, dict) and isinstance(c.get("id"), str)
            }

            existing_ids = _existing_tool_result_ids(internal_messages)
            for call_id in approved_ids | rejected_ids:
                if call_id in existing_ids:
                    continue
                tool_call = call_map.get(call_id)
                if not tool_call:
                    continue
                fn = tool_call.get("function") or {}
                name = fn.get("name") if isinstance(fn, dict) else None
                args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                args: Dict[str, Any] = {}
                if isinstance(args_raw, str) and args_raw.strip():
                    try:
                        parsed = json.loads(args_raw)
                        if isinstance(parsed, dict):
                            args = parsed
                    except json.JSONDecodeError:
                        args = {}

                spec = get_tool(str(name or ""))
                if call_id in rejected_ids:
                    content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    continue

                if not spec:
                    content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    continue

                try:
                    result = await spec.handler(args)
                    content = tool_result_to_content(result)
                except Exception as exc:  # noqa: BLE001
                    content = tool_result_to_content({"error": str(exc)})
                tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                internal_messages.append(tool_msg)
                new_messages.append(tool_msg)

        tools_schema = openai_tools()
        ability = await _choose_chat_ability()
        max_loops = 4

        for _ in range(max_loops):
            try:
                assistant = await chat_completion(
                    internal_messages,
                    ability=ability,
                    tools=tools_schema,
                    tool_choice="auto",
                    timeout=60.0,
                )
            except MissingModelError as exc:
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            except httpx.HTTPStatusError as exc:
                raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
            except httpx.RequestError as exc:
                raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc

            assistant = _ensure_tool_call_ids(assistant)
            internal_messages.append(assistant)
            new_messages.append(assistant)

            tool_calls = assistant.get("tool_calls")
            if not isinstance(tool_calls, list) or not tool_calls:
                break

            pending = []
            for call in tool_calls:
                if not isinstance(call, dict):
                    continue
                call_id = str(call.get("id") or "")
                fn = call.get("function") or {}
                name = fn.get("name") if isinstance(fn, dict) else None
                args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                args: Dict[str, Any] = {}
                if isinstance(args_raw, str) and args_raw.strip():
                    try:
                        parsed = json.loads(args_raw)
                        if isinstance(parsed, dict):
                            args = parsed
                    except json.JSONDecodeError:
                        args = {}

                spec = get_tool(str(name or ""))
                if not spec:
                    content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    continue

                if spec.requires_confirmation and not req.auto_execute:
                    pending.append(_extract_pending(call, True))
                    continue

                try:
                    result = await spec.handler(args)
                    content = tool_result_to_content(result)
                except Exception as exc:  # noqa: BLE001
                    content = tool_result_to_content({"error": str(exc)})
                tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                internal_messages.append(tool_msg)
                new_messages.append(tool_msg)

            if pending:
                break

        payload: Dict[str, Any] = {"messages": new_messages}
        if pending:
            payload["pending_tool_calls"] = [p.model_dump() for p in pending]
        return payload

    @classmethod
    async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
        history: List[Dict[str, Any]] = list(req.messages or [])
        current_path = _normalize_path(req.context.current_path if req.context else None)

        system_prompt = _build_system_prompt(current_path)
        internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history

        new_messages: List[Dict[str, Any]] = []
        pending: List[PendingToolCall] = []

        approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
        rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}

        try:
            if approved_ids or rejected_ids:
                _, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
                last_call_msg = _ensure_tool_call_ids(last_call_msg)
                tool_calls = last_call_msg.get("tool_calls") or []
                call_map: Dict[str, Dict[str, Any]] = {
                    str(c.get("id")): c
                    for c in tool_calls
                    if isinstance(c, dict) and isinstance(c.get("id"), str)
                }

                existing_ids = _existing_tool_result_ids(internal_messages)
                for call_id in approved_ids | rejected_ids:
                    if call_id in existing_ids:
                        continue
                    tool_call = call_map.get(call_id)
                    if not tool_call:
                        continue
                    fn = tool_call.get("function") or {}
                    name = fn.get("name") if isinstance(fn, dict) else None
                    args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                    args: Dict[str, Any] = {}
                    if isinstance(args_raw, str) and args_raw.strip():
                        try:
                            parsed = json.loads(args_raw)
                            if isinstance(parsed, dict):
                                args = parsed
                        except json.JSONDecodeError:
                            args = {}

                    spec = get_tool(str(name or ""))
                    if call_id in rejected_ids:
                        content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
                        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
                        continue

                    if not spec:
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
                        continue

                    yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
                    try:
                        result = await spec.handler(args)
                        content = tool_result_to_content(result)
                    except Exception as exc:  # noqa: BLE001
                        content = tool_result_to_content({"error": str(exc)})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})

            tools_schema = openai_tools()
            ability = await _choose_chat_ability()
            max_loops = 4

            for _ in range(max_loops):
                assistant_event_id = uuid.uuid4().hex
                yield _sse("assistant_start", {"id": assistant_event_id})

                assistant_message: Dict[str, Any] | None = None
                try:
                    async for event in chat_completion_stream(
                        internal_messages,
                        ability=ability,
                        tools=tools_schema,
                        tool_choice="auto",
                        timeout=60.0,
                    ):
                        if event.get("type") == "delta":
                            delta = event.get("delta")
                            if isinstance(delta, str) and delta:
                                yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
                        elif event.get("type") == "message":
                            msg = event.get("message")
                            if isinstance(msg, dict):
                                assistant_message = msg
                except MissingModelError as exc:
                    raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
                except httpx.HTTPStatusError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
                except httpx.RequestError as exc:
                    raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc

                if not assistant_message:
                    assistant_message = {"role": "assistant", "content": ""}

                assistant_message = _ensure_tool_call_ids(assistant_message)
                internal_messages.append(assistant_message)
                new_messages.append(assistant_message)
                yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})

                tool_calls = assistant_message.get("tool_calls")
                if not isinstance(tool_calls, list) or not tool_calls:
                    break

                pending = []
                for call in tool_calls:
                    if not isinstance(call, dict):
                        continue
                    call_id = str(call.get("id") or "")
                    fn = call.get("function") or {}
                    name = fn.get("name") if isinstance(fn, dict) else None
                    args_raw = fn.get("arguments") if isinstance(fn, dict) else None
                    args: Dict[str, Any] = {}
                    if isinstance(args_raw, str) and args_raw.strip():
                        try:
                            parsed = json.loads(args_raw)
                            if isinstance(parsed, dict):
                                args = parsed
                        except json.JSONDecodeError:
                            args = {}

                    spec = get_tool(str(name or ""))
                    if not spec:
                        content = tool_result_to_content({"error": f"unknown_tool: {name}"})
                        tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                        internal_messages.append(tool_msg)
                        new_messages.append(tool_msg)
                        yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
                        continue

                    if spec.requires_confirmation and not req.auto_execute:
                        pending.append(_extract_pending(call, True))
                        continue

                    yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
                    try:
                        result = await spec.handler(args)
                        content = tool_result_to_content(result)
                    except Exception as exc:  # noqa: BLE001
                        content = tool_result_to_content({"error": str(exc)})
                    tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
                    internal_messages.append(tool_msg)
                    new_messages.append(tool_msg)
                    yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})

                if pending:
                    yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]})
                    break

            payload: Dict[str, Any] = {"messages": new_messages}
            if pending:
                payload["pending_tool_calls"] = [p.model_dump() for p in pending]
            yield _sse("done", payload)

        except asyncio.CancelledError:
            return
        except HTTPException as exc:
            detail = exc.detail
            content = detail if isinstance(detail, str) else str(detail)
            if not content.strip():
                content = f"请求失败({exc.status_code})"
            new_messages.append({"role": "assistant", "content": content})
            payload: Dict[str, Any] = {"messages": new_messages}
            if pending:
                payload["pending_tool_calls"] = [p.model_dump() for p in pending]
            yield _sse("done", payload)
            return
        except Exception as exc:  # noqa: BLE001
            new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
            payload: Dict[str, Any] = {"messages": new_messages}
            if pending:
                payload["pending_tool_calls"] = [p.model_dump() for p in pending]
            yield _sse("done", payload)
            return
```
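The confirmation flow is easiest to see as two requests. A hypothetical round-trip, with field names inferred from how AgentChatRequest is read above (the IDs and arguments shown are made-up illustrations):

```python
# First call: the model proposes a write; the server replies with pending_tool_calls.
first_request = {
    "messages": [{"role": "user", "content": "清空 /notes/todo.md"}],
    "context": {"current_path": "/notes"},
}
# Hypothetical response shape, per AgentService.chat above:
# {"messages": [...assistant + tool messages...],
#  "pending_tool_calls": [{"id": "call_0", "name": "vfs_write_text",
#                          "arguments": {"path": "/notes/todo.md", "content": ""},
#                          "requires_confirmation": True}]}

# Second call: resend the history plus the returned messages, approving by id;
# the server then executes the tool and appends the matching "tool" message.
second_request = {
    "messages": first_request["messages"],  # + the messages from the first response
    "approved_tool_call_ids": ["call_0"],
}
```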
412
domain/agent/tools.py
Normal file
412
domain/agent/tools.py
Normal file
@@ -0,0 +1,412 @@
import json
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional

from domain.processors import ProcessDirectoryRequest, ProcessRequest, ProcessorService
from domain.virtual_fs import VirtualFSService
from domain.virtual_fs.search import VirtualFSSearchService


@dataclass(frozen=True)
class ToolSpec:
    name: str
    description: str
    parameters: Dict[str, Any]
    requires_confirmation: bool
    handler: Callable[[Dict[str, Any]], Awaitable[Any]]


async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]:
    return {"processors": ProcessorService.list_processors()}


async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]:
    path = str(args.get("path") or "")
    processor_type = str(args.get("processor_type") or "")
    config = args.get("config")
    if not isinstance(config, dict):
        config = {}

    save_to = args.get("save_to")
    save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None

    max_depth = args.get("max_depth")
    max_depth_value: Optional[int] = None
    if max_depth is not None:
        try:
            max_depth_value = int(max_depth)
        except (TypeError, ValueError):
            max_depth_value = None

    suffix = args.get("suffix")
    suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None

    overwrite_value = args.get("overwrite")
    overwrite = bool(overwrite_value) if overwrite_value is not None else None

    is_dir = await VirtualFSService.path_is_directory(path)
    if is_dir and (max_depth_value is not None or suffix_value is not None):
        req = ProcessDirectoryRequest(
            path=path,
            processor_type=processor_type,
            config=config,
            overwrite=True if overwrite is None else overwrite,
            max_depth=max_depth_value,
            suffix=suffix_value,
        )
        result = await ProcessorService.process_directory(req)
        return {"mode": "directory", **result}

    req = ProcessRequest(
        path=path,
        processor_type=processor_type,
        config=config,
        save_to=save_to,
        overwrite=False if overwrite is None else overwrite,
    )
    result = await ProcessorService.process_file(req)
    return {"mode": "file", **result}

def _normalize_vfs_path(value: Any) -> str:
    s = str(value or "").strip().replace("\\", "/")
    if not s:
        return ""
    if not s.startswith("/"):
        s = "/" + s
    s = s.rstrip("/") or "/"
    return s


def _require_vfs_path(value: Any, field: str) -> str:
    path = _normalize_vfs_path(value)
    if not path:
        raise ValueError(f"missing_{field}")
    return path

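As a quick illustration of the normalization rules just above (backslashes folded to slashes, a leading slash enforced, trailing slashes stripped), here is what the helper returns for a few inputs; this is a reading aid derived from the code, not part of the diff:

```python
# Expected behavior of _normalize_vfs_path (illustrative, based on the code above):
assert _normalize_vfs_path("foo\\bar") == "/foo/bar"   # backslashes become slashes
assert _normalize_vfs_path("/foo/bar/") == "/foo/bar"  # trailing slash stripped
assert _normalize_vfs_path("/") == "/"                 # root survives the rstrip
assert _normalize_vfs_path("") == ""                   # empty stays empty (rejected by _require_vfs_path)
```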
async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]:
    path = _normalize_vfs_path(args.get("path") or "/") or "/"
    page = int(args.get("page") or 1)
    page_size = int(args.get("page_size") or 50)
    sort_by = str(args.get("sort_by") or "name")
    sort_order = str(args.get("sort_order") or "asc")
    return await VirtualFSService.list_directory(path, page, page_size, sort_by, sort_order)


async def _vfs_stat(args: Dict[str, Any]) -> Any:
    path = _require_vfs_path(args.get("path"), "path")
    return await VirtualFSService.stat(path)


async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]:
    path = _require_vfs_path(args.get("path"), "path")
    encoding = str(args.get("encoding") or "utf-8")
    max_chars = int(args.get("max_chars") or 8000)

    data = await VirtualFSService.read_file(path)
    if isinstance(data, (bytes, bytearray)):
        try:
            text = bytes(data).decode(encoding)
        except UnicodeDecodeError:
            return {"error": "binary_or_invalid_text", "path": path}
    elif isinstance(data, str):
        text = data
    else:
        text = str(data)

    original_len = len(text)
    truncated = original_len > max_chars
    if truncated:
        text = text[:max_chars]
    return {
        "path": path,
        "encoding": encoding,
        "content": text,
        "truncated": truncated,
        "length": original_len,
    }


async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]:
    path = _require_vfs_path(args.get("path"), "path")
    if path == "/":
        raise ValueError("invalid_path")
    encoding = str(args.get("encoding") or "utf-8")
    content = str(args.get("content") or "")
    data = content.encode(encoding)
    await VirtualFSService.write_file(path, data)
    return {"written": True, "path": path, "encoding": encoding, "bytes": len(data)}


async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]:
    path = _require_vfs_path(args.get("path"), "path")
    return await VirtualFSService.mkdir(path)


async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]:
    path = _require_vfs_path(args.get("path"), "path")
    return await VirtualFSService.delete(path)


async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]:
    src = _require_vfs_path(args.get("src"), "src")
    dst = _require_vfs_path(args.get("dst"), "dst")
    if src == "/" or dst == "/":
        raise ValueError("invalid_path")
    overwrite = bool(args.get("overwrite") or False)
    return await VirtualFSService.move(src, dst, overwrite)


async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]:
    src = _require_vfs_path(args.get("src"), "src")
    dst = _require_vfs_path(args.get("dst"), "dst")
    if src == "/" or dst == "/":
        raise ValueError("invalid_path")
    overwrite = bool(args.get("overwrite") or False)
    return await VirtualFSService.copy(src, dst, overwrite)


async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]:
    src = _require_vfs_path(args.get("src"), "src")
    dst = _require_vfs_path(args.get("dst"), "dst")
    if src == "/" or dst == "/":
        raise ValueError("invalid_path")
    overwrite = bool(args.get("overwrite") or False)
    return await VirtualFSService.rename(src, dst, overwrite)


async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]:
    q = str(args.get("q") or "").strip()
    if not q:
        raise ValueError("missing_q")
    mode = str(args.get("mode") or "vector")
    top_k = int(args.get("top_k") or 10)
    page = int(args.get("page") or 1)
    page_size = int(args.get("page_size") or 10)
    return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)

TOOLS: Dict[str, ToolSpec] = {
    "processors_list": ToolSpec(
        name="processors_list",
        description="List available processors (type/name/config_schema, etc.).",
        parameters={
            "type": "object",
            "properties": {},
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_processors_list,
    ),
    "processors_run": ToolSpec(
        name="processors_run",
        description=(
            "Run a processor on a file or directory."
            " For directories, max_depth/suffix are optional; for files, overwrite/save_to are optional."
            " Returns a task id (check the task queue for progress)."
        ),
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File or directory path (absolute, e.g. /foo/bar)"},
                "processor_type": {"type": "string", "description": "Processor type (e.g. image_watermark)"},
                "config": {"type": "object", "description": "Processor config, following the config_schema returned by processors_list"},
                "overwrite": {"type": "boolean", "description": "Whether to overwrite the original file / files inside the directory"},
                "save_to": {"type": "string", "description": "Save to the given path (file mode only, used when overwrite=false)"},
                "max_depth": {"type": "integer", "description": "Directory traversal depth (directory mode only)"},
                "suffix": {"type": "string", "description": "Output suffix for directory batches (only when produces_file and overwrite=false)"},
            },
            "required": ["path", "processor_type"],
        },
        requires_confirmation=True,
        handler=_processors_run,
    ),
    "vfs_list_dir": ToolSpec(
        name="vfs_list_dir",
        description="Browse a directory (returns entries + pagination).",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Directory path (absolute, e.g. /foo/bar)"},
                "page": {"type": "integer", "description": "Page number (starting at 1)"},
                "page_size": {"type": "integer", "description": "Items per page"},
                "sort_by": {"type": "string", "description": "Sort field: name/size/mtime"},
                "sort_order": {"type": "string", "description": "Sort order: asc/desc"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_list_dir,
    ),
    "vfs_stat": ToolSpec(
        name="vfs_stat",
        description="Inspect file/directory info (size/mtime/is_dir/has_thumbnail/vector_index, etc.).",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Path (absolute, e.g. /foo/bar.txt)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_stat,
    ),
    "vfs_read_text": ToolSpec(
        name="vfs_read_text",
        description="Read a text file (decode failures are treated as binary and return an error).",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path (absolute, e.g. /foo/bar.md)"},
                "encoding": {"type": "string", "description": "Text encoding (default utf-8)"},
                "max_chars": {"type": "integer", "description": "Maximum number of characters to return (default 8000)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_read_text,
    ),
    "vfs_write_text": ToolSpec(
        name="vfs_write_text",
        description="Write text content to a file (overwrites the target file).",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path (absolute, e.g. /foo/bar.md)"},
                "content": {"type": "string", "description": "Text content to write"},
                "encoding": {"type": "string", "description": "Text encoding (default utf-8)"},
            },
            "required": ["path", "content"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_write_text,
    ),
    "vfs_mkdir": ToolSpec(
        name="vfs_mkdir",
        description="Create a directory.",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Directory path (absolute, e.g. /foo/bar)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_mkdir,
    ),
    "vfs_delete": ToolSpec(
        name="vfs_delete",
        description="Delete a file or directory (whether it recurses is up to the underlying adapter).",
        parameters={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Path (absolute, e.g. /foo/bar or /foo/bar.txt)"},
            },
            "required": ["path"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_delete,
    ),
    "vfs_move": ToolSpec(
        name="vfs_move",
        description="Move a path (may be queued as a task).",
        parameters={
            "type": "object",
            "properties": {
                "src": {"type": "string", "description": "Source path (absolute)"},
                "dst": {"type": "string", "description": "Destination path (absolute)"},
                "overwrite": {"type": "boolean", "description": "Whether an existing destination may be overwritten (default false)"},
            },
            "required": ["src", "dst"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_move,
    ),
    "vfs_copy": ToolSpec(
        name="vfs_copy",
        description="Copy a path (may be queued as a task).",
        parameters={
            "type": "object",
            "properties": {
                "src": {"type": "string", "description": "Source path (absolute)"},
                "dst": {"type": "string", "description": "Destination path (absolute)"},
                "overwrite": {"type": "boolean", "description": "Whether to overwrite an existing destination (default false)"},
            },
            "required": ["src", "dst"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_copy,
    ),
    "vfs_rename": ToolSpec(
        name="vfs_rename",
        description="Rename a path (essentially a same-directory move).",
        parameters={
            "type": "object",
            "properties": {
                "src": {"type": "string", "description": "Source path (absolute)"},
                "dst": {"type": "string", "description": "Destination path (absolute)"},
                "overwrite": {"type": "boolean", "description": "Whether an existing destination may be overwritten (default false)"},
            },
            "required": ["src", "dst"],
            "additionalProperties": False,
        },
        requires_confirmation=True,
        handler=_vfs_rename,
    ),
    "vfs_search": ToolSpec(
        name="vfs_search",
        description="Search files (mode=vector or filename).",
        parameters={
            "type": "object",
            "properties": {
                "q": {"type": "string", "description": "Search keywords"},
                "mode": {"type": "string", "description": "Search mode: vector/filename (default vector)"},
                "top_k": {"type": "integer", "description": "Number of results (vector mode, default 10)"},
                "page": {"type": "integer", "description": "Page number (filename mode, default 1)"},
                "page_size": {"type": "integer", "description": "Page size (filename mode, default 10)"},
            },
            "required": ["q"],
            "additionalProperties": False,
        },
        requires_confirmation=False,
        handler=_vfs_search,
    ),
}

def get_tool(name: str) -> Optional[ToolSpec]:
    return TOOLS.get(name)


def openai_tools() -> List[Dict[str, Any]]:
    out: List[Dict[str, Any]] = []
    for spec in TOOLS.values():
        out.append({
            "type": "function",
            "function": {
                "name": spec.name,
                "description": spec.description,
                "parameters": spec.parameters,
            },
        })
    return out


def tool_result_to_content(result: Any) -> str:
    if result is None:
        return ""
    if isinstance(result, str):
        return result
    try:
        return json.dumps(result, ensure_ascii=False)
    except TypeError:
        return json.dumps({"result": str(result)}, ensure_ascii=False)
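Taken together, the three helpers above are the whole integration surface of the registry: `openai_tools()` advertises the tools to the model, `get_tool()` resolves a call, and `tool_result_to_content()` serializes the result. A minimal sketch of the dispatch loop an agent endpoint might run on top of them; the `model_reply` shape follows the OpenAI tool-calling convention, and everything else here is illustrative:

```python
# Illustrative dispatch loop over the TOOLS registry (not part of the diff).
import json


async def run_tool_calls(model_reply: dict) -> list[dict]:
    tool_messages = []
    for call in model_reply.get("tool_calls", []):
        spec = get_tool(call["function"]["name"])
        if spec is None:
            content = tool_result_to_content({"error": "unknown_tool"})
        else:
            args = json.loads(call["function"]["arguments"] or "{}")
            content = tool_result_to_content(await spec.handler(args))
        tool_messages.append(
            {"role": "tool", "tool_call_id": call["id"], "content": content}
        )
    return tool_messages
```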
domain/agent/types.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class AgentChatContext(BaseModel):
    current_path: Optional[str] = None


class AgentChatRequest(BaseModel):
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    auto_execute: bool = False
    approved_tool_call_ids: List[str] = Field(default_factory=list)
    rejected_tool_call_ids: List[str] = Field(default_factory=list)
    context: Optional[AgentChatContext] = None


class PendingToolCall(BaseModel):
    id: str
    name: str
    arguments: Dict[str, Any] = Field(default_factory=dict)
    requires_confirmation: bool = True
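For reference, a request that validates against `AgentChatRequest` might look like this; the values are hand-written for illustration, not taken from the repository:

```python
# Illustrative payload that validates against AgentChatRequest (values invented).
req = AgentChatRequest(
    messages=[{"role": "user", "content": "watermark everything under /photos"}],
    auto_execute=False,
    approved_tool_call_ids=["call_0"],
    context=AgentChatContext(current_path="/photos"),
)
print(req.model_dump())
```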
domain/ai/__init__.py
@@ -1,28 +1,61 @@
from .api import router_ai, router_vector_db
from .inference import (
    MissingModelError,
    chat_completion,
    chat_completion_stream,
    describe_image_base64,
    get_text_embedding,
    provider_service,
    rerank_texts,
)
from .service import (
    AIProviderService,
    FILE_COLLECTION_NAME,
    VECTOR_COLLECTION_NAME,
    DEFAULT_VECTOR_DIMENSION,
    VectorDBConfigManager,
    VectorDBService,
    DEFAULT_VECTOR_DIMENSION,
    ABILITIES,
    normalize_capabilities,
)
from .types import (
    ABILITIES,
    AIDefaultsUpdate,
    AIModelCreate,
    AIModelUpdate,
    AIProviderCreate,
    AIProviderUpdate,
    VectorDBConfigPayload,
    normalize_capabilities,
)
from .vector_providers import (
    BaseVectorProvider,
    MilvusLiteProvider,
    MilvusServerProvider,
    QdrantProvider,
    get_provider_class,
    get_provider_entry,
    list_providers,
)

__all__ = [
    "router_ai",
    "router_vector_db",
    "MissingModelError",
    "chat_completion",
    "chat_completion_stream",
    "describe_image_base64",
    "get_text_embedding",
    "provider_service",
    "rerank_texts",
    "AIProviderService",
    "VectorDBService",
    "VectorDBConfigManager",
    "DEFAULT_VECTOR_DIMENSION",
    "VECTOR_COLLECTION_NAME",
    "FILE_COLLECTION_NAME",
    "BaseVectorProvider",
    "MilvusLiteProvider",
    "MilvusServerProvider",
    "QdrantProvider",
    "list_providers",
    "get_provider_entry",
    "get_provider_class",
    "ABILITIES",
    "normalize_capabilities",
    "AIDefaultsUpdate",
domain/ai/api.py
@@ -5,8 +5,9 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
from domain.ai.types import (
from domain.auth import User, get_current_active_user
from .service import AIProviderService, VectorDBConfigManager, VectorDBService
from .types import (
    AIDefaultsUpdate,
    AIModelCreate,
    AIModelUpdate,
@@ -14,9 +15,7 @@ from domain.ai.types import (
    AIProviderUpdate,
    VectorDBConfigPayload,
)
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from .vector_providers import get_provider_class, get_provider_entry, list_providers

router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
domain/ai/inference.py
@@ -1,8 +1,10 @@
import json

import httpx
from typing import List, Sequence, Tuple
from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple

from models.database import AIModel, AIProvider
from domain.ai.service import AIProviderService
from .service import AIProviderService


provider_service = AIProviderService

@@ -243,3 +245,195 @@ async def _rerank_with_gemini(
        except (TypeError, ValueError):
            scores.append(0.0)
    return scores

async def chat_completion(
    messages: List[Dict[str, Any]],
    *,
    ability: str = "chat",
    tools: List[Dict[str, Any]] | None = None,
    tool_choice: Any | None = None,
    temperature: float | None = None,
    timeout: float = 60.0,
) -> Dict[str, Any]:
    model, provider = await _require_model(ability)
    if provider.api_format != "openai":
        raise MissingModelError("Only OpenAI-compatible chat model APIs are supported for now.")
    return await _chat_with_openai(
        provider,
        model,
        messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        timeout=timeout,
    )


async def _chat_with_openai(
    provider: AIProvider,
    model: AIModel,
    messages: List[Dict[str, Any]],
    *,
    tools: List[Dict[str, Any]] | None,
    tool_choice: Any | None,
    temperature: float | None,
    timeout: float,
) -> Dict[str, Any]:
    url = _openai_endpoint(provider, "/chat/completions")
    payload: Dict[str, Any] = {
        "model": model.name,
        "messages": messages,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = tool_choice or "auto"
    if temperature is not None:
        payload["temperature"] = float(temperature)

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(url, headers=_openai_headers(provider), json=payload)
        response.raise_for_status()
        body = response.json()

    choices = body.get("choices") or []
    if not choices:
        raise RuntimeError("Chat API returned an empty response")
    message = choices[0].get("message")
    if not isinstance(message, dict):
        raise RuntimeError("Chat API returned a malformed response")
    return message

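A hedged usage sketch for the non-streaming path above; the message shape follows the OpenAI chat-completions convention, and running this assumes a default "chat" model has been configured:

```python
# Illustrative call (assumes a default "chat" model is configured).
import asyncio


async def demo():
    reply = await chat_completion(
        [{"role": "user", "content": "Say hi"}],
        temperature=0.2,
    )
    # reply is the raw assistant message dict, e.g. {"role": "assistant", "content": "..."}
    print(reply.get("content"))


asyncio.run(demo())
```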
async def chat_completion_stream(
    messages: List[Dict[str, Any]],
    *,
    ability: str = "chat",
    tools: List[Dict[str, Any]] | None = None,
    tool_choice: Any | None = None,
    temperature: float | None = None,
    timeout: float = 60.0,
) -> AsyncIterator[Dict[str, Any]]:
    model, provider = await _require_model(ability)
    if provider.api_format != "openai":
        raise MissingModelError("Only OpenAI-compatible chat model APIs are supported for now.")
    async for event in _chat_stream_with_openai(
        provider,
        model,
        messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        timeout=timeout,
    ):
        yield event


async def _chat_stream_with_openai(
    provider: AIProvider,
    model: AIModel,
    messages: List[Dict[str, Any]],
    *,
    tools: List[Dict[str, Any]] | None,
    tool_choice: Any | None,
    temperature: float | None,
    timeout: float,
) -> AsyncIterator[Dict[str, Any]]:
    url = _openai_endpoint(provider, "/chat/completions")
    payload: Dict[str, Any] = {
        "model": model.name,
        "messages": messages,
        "stream": True,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = tool_choice or "auto"
    if temperature is not None:
        payload["temperature"] = float(temperature)

    content_parts: List[str] = []
    tool_call_map: Dict[int, Dict[str, Any]] = {}
    role = "assistant"
    finish_reason: str | None = None

    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", url, headers=_openai_headers(provider), json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line:
                    continue
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if not data:
                    continue
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue

                choices = chunk.get("choices") or []
                if not choices:
                    continue
                choice = choices[0] if isinstance(choices[0], dict) else {}
                delta = choice.get("delta") if isinstance(choice, dict) else None
                delta = delta if isinstance(delta, dict) else {}

                if isinstance(delta.get("role"), str):
                    role = delta["role"]

                delta_content = delta.get("content")
                if isinstance(delta_content, str) and delta_content:
                    content_parts.append(delta_content)
                    yield {"type": "delta", "delta": delta_content}

                delta_tool_calls = delta.get("tool_calls")
                if isinstance(delta_tool_calls, list):
                    for item in delta_tool_calls:
                        if not isinstance(item, dict):
                            continue
                        idx = item.get("index")
                        if not isinstance(idx, int):
                            continue
                        entry = tool_call_map.setdefault(
                            idx,
                            {"id": None, "type": None, "function": {"name": None, "arguments": ""}},
                        )
                        if isinstance(item.get("id"), str) and item["id"].strip():
                            entry["id"] = item["id"]
                        if isinstance(item.get("type"), str) and item["type"].strip():
                            entry["type"] = item["type"]
                        fn = item.get("function")
                        if isinstance(fn, dict):
                            if isinstance(fn.get("name"), str) and fn["name"].strip():
                                entry["function"]["name"] = fn["name"]
                            args_part = fn.get("arguments")
                            if isinstance(args_part, str) and args_part:
                                entry["function"]["arguments"] += args_part

                fr = choice.get("finish_reason") if isinstance(choice, dict) else None
                if isinstance(fr, str) and fr:
                    finish_reason = fr

    content = "".join(content_parts)
    message: Dict[str, Any] = {"role": role, "content": content}
    if tool_call_map:
        tool_calls: List[Dict[str, Any]] = []
        for idx in sorted(tool_call_map.keys()):
            item = tool_call_map[idx]
            fn = item.get("function") if isinstance(item.get("function"), dict) else {}
            call_id = item.get("id") if isinstance(item.get("id"), str) and item.get("id") else f"call_{idx}"
            call_type = item.get("type") if isinstance(item.get("type"), str) and item.get("type") else "function"
            tool_calls.append({
                "id": call_id,
                "type": call_type,
                "function": {
                    "name": fn.get("name") or "",
                    "arguments": fn.get("arguments") or "",
                },
            })
        message["tool_calls"] = tool_calls

    yield {"type": "message", "message": message, "finish_reason": finish_reason}
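The generator above yields incremental `delta` events and a single final `message` event carrying the assembled content plus any accumulated tool calls. A minimal consumer sketch, again assuming a configured default chat model:

```python
# Illustrative consumer of chat_completion_stream (assumes a configured chat model).
import asyncio


async def demo_stream():
    async for event in chat_completion_stream([{"role": "user", "content": "Hello"}]):
        if event["type"] == "delta":
            print(event["delta"], end="", flush=True)  # incremental text
        elif event["type"] == "message":
            final = event["message"]                    # assembled message + tool_calls
            print("\nfinish_reason:", event["finish_reason"])


asyncio.run(demo_stream())
```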
domain/ai/service.py
@@ -7,7 +7,7 @@ import httpx
from tortoise.exceptions import DoesNotExist
from tortoise.transactions import in_transaction

from domain.config.service import ConfigService
from domain.config import ConfigService
from models.database import AIDefaultModel, AIModel, AIProvider

from .types import ABILITIES, normalize_capabilities

@@ -19,6 +19,8 @@ from .vector_providers import (
)

DEFAULT_VECTOR_DIMENSION = 4096
VECTOR_COLLECTION_NAME = "vector_collection"
FILE_COLLECTION_NAME = "file_collection"

OPENAI_EMBEDDING_DIMS = {
    "text-embedding-3-large": 3072,

@@ -138,7 +140,7 @@ def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
        "provider_type": provider.provider_type,
        "api_format": provider.api_format,
        "base_url": provider.base_url,
        "api_key": provider.api_key,
        "has_api_key": bool(provider.api_key),
        "logo_url": provider.logo_url,
        "extra_config": provider.extra_config or {},
        "created_at": provider.created_at,
domain/audit/__init__.py
@@ -1,5 +1,4 @@
from domain.audit.decorator import audit
from domain.audit.types import AuditAction
from domain.audit.api import router
from .decorator import audit
from .types import AuditAction

__all__ = ["audit", "AuditAction", "router"]
__all__ = ["audit", "AuditAction"]
domain/audit/api.py
@@ -4,10 +4,9 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query

from api import response
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.auth import User, get_current_active_user
from .service import AuditService
from .types import AuditAction

CurrentUser = Annotated[User, Depends(get_current_active_user)]

@@ -62,7 +61,5 @@ async def clear_audit_logs(
):
    start_dt = _parse_iso(start_time, "start_time")
    end_dt = _parse_iso(end_time, "end_time")
    if start_dt is None and end_dt is None:
        raise HTTPException(status_code=400, detail="At least one of start_time or end_time must be provided")
    deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
    return response.success({"deleted_count": deleted_count})
domain/audit/decorator.py
@@ -7,11 +7,11 @@ import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError

from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import ALGORITHM
from domain.config.service import ConfigService
from domain.auth import ALGORITHM
from domain.config import ConfigService
from models.database import UserAccount
from .service import AuditService
from .types import AuditAction


def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:

@@ -98,6 +98,11 @@ def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
def _get_client_ip(request: Request | None) -> str | None:
    if not request:
        return None
    cf_connecting_ip = request.headers.get("cf-connecting-ip") or request.headers.get("CF-Connecting-IP")
    if cf_connecting_ip:
        ip = cf_connecting_ip.strip()
        if ip:
            return ip
    x_real_ip = request.headers.get("x-real-ip") or request.headers.get("X-Real-IP")
    if x_real_ip:
        ip = x_real_ip.strip()
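The hunk is cut off here in the compare view, but the visible change puts Cloudflare's CF-Connecting-IP header ahead of X-Real-IP in the lookup order, so deployments behind Cloudflare report the real visitor address instead of the proxy's. A condensed sketch of that precedence; the final fallback to `request.client` is an assumption about the elided tail of the function:

```python
# Header precedence for client-IP resolution (condensed; the tail is assumed).
def client_ip(request) -> str | None:
    for header in ("cf-connecting-ip", "x-real-ip"):
        value = (request.headers.get(header) or "").strip()
        if value:
            return value
    return request.client.host if request.client else None  # assumed fallback
```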
domain/audit/service.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional

from models.database import AuditLog

from domain.audit.types import AuditAction
from .types import AuditAction


class AuditService:
domain/auth/__init__.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from .service import (
    ALGORITHM,
    AuthService,
    authenticate_user_db,
    create_access_token,
    get_current_active_user,
    get_current_user,
    get_password_hash,
    has_users,
    register_user,
    request_password_reset,
    reset_password_with_token,
    verify_password,
    verify_password_reset_token,
)
from .types import (
    PasswordResetConfirm,
    PasswordResetRequest,
    RegisterRequest,
    Token,
    TokenData,
    UpdateMeRequest,
    User,
    UserInDB,
)

__all__ = [
    "ALGORITHM",
    "AuthService",
    "authenticate_user_db",
    "create_access_token",
    "get_current_active_user",
    "get_current_user",
    "get_password_hash",
    "has_users",
    "register_user",
    "request_password_reset",
    "reset_password_with_token",
    "verify_password",
    "verify_password_reset_token",
    "PasswordResetConfirm",
    "PasswordResetRequest",
    "RegisterRequest",
    "Token",
    "TokenData",
    "UpdateMeRequest",
    "User",
    "UserInDB",
]
domain/auth/api.py
@@ -5,8 +5,8 @@ from fastapi.security import OAuth2PasswordRequestForm

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import AuthService, get_current_active_user
from domain.auth.types import (
from .service import AuthService, get_current_active_user
from .types import (
    PasswordResetConfirm,
    PasswordResetRequest,
    RegisterRequest,
domain/auth/service.py
@@ -5,13 +5,15 @@ from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Annotated

import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext

from domain.auth.types import (
from domain.config import ConfigService
from models.database import UserAccount
from .types import (
    PasswordResetConfirm,
    PasswordResetRequest,
    RegisterRequest,

@@ -21,8 +23,6 @@ from domain.auth.types import (
    User,
    UserInDB,
)
from models.database import UserAccount
from domain.config.service import ConfigService

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365

@@ -97,12 +97,15 @@ class PasswordResetStore:


class AuthService:
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
    algorithm = ALGORITHM
    access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
    password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES

    @staticmethod
    def _to_bytes(value: str) -> bytes:
        return value.encode("utf-8")

    @classmethod
    async def get_secret_key(cls) -> str:
        return await ConfigService.get_secret_key("SECRET_KEY", None)

@@ -113,11 +116,17 @@ class AuthService:

    @classmethod
    def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
        return cls.pwd_context.verify(plain_password, hashed_password)
        try:
            return bcrypt.checkpw(cls._to_bytes(plain_password), hashed_password.encode("utf-8"))
        except (ValueError, TypeError):
            return False

    @classmethod
    def get_password_hash(cls, password: str) -> str:
        return cls.pwd_context.hash(password)
        encoded = cls._to_bytes(password)
        if len(encoded) > 72:
            raise HTTPException(status_code=400, detail="Password too long")
        return bcrypt.hashpw(encoded, bcrypt.gensalt()).decode("utf-8")

    @classmethod
    async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
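Context for the new 72-byte guard above: the bcrypt algorithm only hashes the first 72 bytes of its input, and the pyca `bcrypt` package truncates silently, so two long passwords sharing a 72-byte prefix would collide. Rejecting over-long passwords up front is the safer choice. A small standalone illustration of the truncation:

```python
# Demonstrates why the diff rejects passwords over 72 bytes: bcrypt ignores the rest.
import bcrypt

salt = bcrypt.gensalt()
digest = bcrypt.hashpw(b"x" * 72, salt)

# A different password that shares the first 72 bytes still verifies:
assert bcrypt.checkpw(b"x" * 72 + b"extra", digest)
```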
@@ -315,7 +324,7 @@

    @classmethod
    async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
        from domain.email.service import EmailService
        from domain.email import EmailService

        app_domain = await ConfigService.get("APP_DOMAIN", None)
        base_url = (app_domain or "http://localhost:5173").rstrip("/")
domain/backup/__init__.py
@@ -1 +1,7 @@
from .service import BackupService
from .types import BackupData

__all__ = [
    "BackupService",
    "BackupData",
]
domain/backup/api.py
@@ -4,8 +4,8 @@ from fastapi import APIRouter, Depends, File, Request, UploadFile
from fastapi.responses import JSONResponse

from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.backup.service import BackupService
from domain.auth import get_current_active_user
from .service import BackupService

router = APIRouter(
    prefix="/api/backup",
domain/backup/service.py
@@ -4,8 +4,8 @@ from datetime import datetime
from fastapi import HTTPException
from tortoise.transactions import in_transaction

from domain.backup.types import BackupData
from domain.config.service import VERSION
from domain.config import VERSION
from .types import BackupData
from models.database import (
    AIDefaultModel,
    AIModel,
domain/config/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from .service import ConfigService, VERSION
from .types import ConfigItem, LatestVersionInfo, SystemStatus

__all__ = [
    "ConfigService",
    "VERSION",
    "ConfigItem",
    "LatestVersionInfo",
    "SystemStatus",
]
domain/config/api.py
@@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, Form, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.config.types import ConfigItem
from domain.auth import User, get_current_active_user
from .service import ConfigService
from .types import ConfigItem

router = APIRouter(prefix="/api/config", tags=["config"])

@@ -29,7 +28,7 @@ async def set_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str = Form(...),
    value: str = Form(...),
    value: str = Form(""),
):
    await ConfigService.set(key, value)
    return success(ConfigItem(key=key, value=value).model_dump())
domain/config/service.py
@@ -5,12 +5,12 @@ from typing import Any, Dict, Optional
import httpx
from dotenv import load_dotenv

from domain.config.types import LatestVersionInfo, SystemStatus
from .types import LatestVersionInfo, SystemStatus
from models.database import Configuration, UserAccount

load_dotenv(dotenv_path=".env")

VERSION = "v1.5.0"
VERSION = "v1.7.0"


class ConfigService:
domain/email/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from .service import EmailService, EmailTemplateRenderer
from .types import (
    EmailConfig,
    EmailSecurity,
    EmailSendPayload,
    EmailTemplatePreviewPayload,
    EmailTemplateUpdate,
    EmailTestRequest,
)

__all__ = [
    "EmailService",
    "EmailTemplateRenderer",
    "EmailConfig",
    "EmailSecurity",
    "EmailSendPayload",
    "EmailTemplatePreviewPayload",
    "EmailTemplateUpdate",
    "EmailTestRequest",
]
domain/email/api.py
@@ -2,10 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.email.service import EmailService, EmailTemplateRenderer
from domain.email.types import (
from domain.auth import User, get_current_active_user
from .service import EmailService, EmailTemplateRenderer
from .types import (
    EmailTemplatePreviewPayload,
    EmailTemplateUpdate,
    EmailTestRequest,
domain/email/service.py
@@ -7,8 +7,8 @@ from pathlib import Path
from string import Template
from typing import Any, Dict, List, Optional

from domain.config.service import ConfigService
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload
from domain.config import ConfigService
from .types import EmailConfig, EmailSecurity, EmailSendPayload


class EmailTemplateRenderer:

@@ -104,7 +104,7 @@ class EmailService:
        template: str,
        context: Optional[Dict[str, Any]] = None,
    ):
        from domain.tasks.task_queue import TaskProgress, task_queue_service
        from domain.tasks import TaskProgress, task_queue_service

        payload = EmailSendPayload(
            recipients=recipients,

@@ -126,7 +126,7 @@ class EmailService:

    @classmethod
    async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
        from domain.tasks.task_queue import TaskProgress, task_queue_service
        from domain.tasks import TaskProgress, task_queue_service

        payload = EmailSendPayload(**data)
domain/offline_downloads/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .service import OfflineDownloadService
from .types import OfflineDownloadCreate

__all__ = [
    "OfflineDownloadService",
    "OfflineDownloadCreate",
]
domain/offline_downloads/api.py
@@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.service import OfflineDownloadService
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.auth import User, get_current_active_user
from .service import OfflineDownloadService
from .types import OfflineDownloadCreate

CurrentUser = Annotated[User, Depends(get_current_active_user)]

domain/offline_downloads/service.py
@@ -7,11 +7,10 @@ import aiofiles
import aiohttp
from fastapi import Depends, HTTPException

from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service
from domain.auth import User, get_current_active_user
from domain.tasks import Task, TaskProgress, task_queue_service
from domain.virtual_fs import VirtualFSService
from .types import OfflineDownloadCreate


class OfflineDownloadService:
domain/plugins/__init__.py
@@ -1 +1,17 @@
"""
Foxel plugin system.

Provides installation, management, and runtime loading of .foxpkg plugin packages.
"""

from .loader import PluginLoadError, PluginLoader
from .service import PluginService
from .startup import init_plugins, load_installed_plugins

__all__ = [
    "PluginLoader",
    "PluginLoadError",
    "PluginService",
    "init_plugins",
    "load_installed_plugins",
]
domain/plugins/api.py
@@ -1,76 +1,111 @@
"""
Plugin management API routes.
"""

from typing import List

from fastapi import APIRouter, Body, Request
from fastapi import APIRouter, File, Request, UploadFile
from fastapi.responses import FileResponse

from domain.audit import AuditAction, audit
from domain.plugins.service import PluginService
from domain.plugins.routes import video_player as video_player_routes
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from .service import PluginService
from .types import (
    PluginInstallResult,
    PluginOut,
)

router = APIRouter(prefix="/api/plugins", tags=["plugins"])
router.include_router(video_player_routes.router)


@router.post("", response_model=PluginOut)
@audit(
    action=AuditAction.CREATE,
    description="Create plugin",
    body_fields=["url", "enabled"],
)
async def create_plugin(request: Request, payload: PluginCreate):
    return await PluginService.create(payload)
# ========== Install ==========


@router.post("/install", response_model=PluginInstallResult)
@audit(action=AuditAction.CREATE, description="Install plugin package")
async def install_plugin(request: Request, file: UploadFile = File(...)):
    """
    Install a .foxpkg plugin package.

    Upload a .foxpkg file to install it.
    """
    content = await file.read()
    return await PluginService.install_package(content, file.filename or "plugin.foxpkg")


# ========== Plugin list and details ==========


@router.get("", response_model=List[PluginOut])
@audit(action=AuditAction.READ, description="List plugins")
async def list_plugins(request: Request):
    """Return the list of installed plugins."""
    return await PluginService.list_plugins()


@router.delete("/{plugin_id}")
@audit(action=AuditAction.DELETE, description="Delete plugin")
async def delete_plugin(request: Request, plugin_id: int):
    await PluginService.delete(plugin_id)
@router.get("/{key_or_id}", response_model=PluginOut)
@audit(action=AuditAction.READ, description="Get plugin details")
async def get_plugin(request: Request, key_or_id: str):
    """Return details for a single plugin."""
    return await PluginService.get_plugin(key_or_id)


# ========== Plugin management ==========


@router.delete("/{key_or_id}")
@audit(action=AuditAction.DELETE, description="Uninstall plugin")
async def delete_plugin(request: Request, key_or_id: str):
    """Uninstall a plugin."""
    await PluginService.delete(key_or_id)
    return {"code": 0, "msg": "ok"}


@router.put("/{plugin_id}", response_model=PluginOut)
@audit(
    action=AuditAction.UPDATE,
    description="Update plugin",
    body_fields=["url", "enabled"],
)
async def update_plugin(request: Request, plugin_id: int, payload: PluginCreate):
    return await PluginService.update(plugin_id, payload)
# ========== Plugin assets ==========


@router.post("/{plugin_id}/metadata", response_model=PluginOut)
@audit(
    action=AuditAction.UPDATE,
    description="Update plugin manifest",
    body_fields=[
        "key",
        "name",
        "version",
        "open_app",
        "supported_exts",
        "default_bounds",
        "default_maximized",
        "icon",
        "description",
        "author",
        "website",
        "github",
    ],
)
async def update_manifest(
    request: Request, plugin_id: int, manifest: PluginManifestUpdate = Body(...)
):
    return await PluginService.update_manifest(plugin_id, manifest)
@router.get("/{key_or_id}/bundle.js")
async def get_bundle(request: Request, key_or_id: str):
    """Serve the plugin's frontend bundle."""
    path = await PluginService.get_bundle_path(key_or_id)
    v = (request.query_params.get("v") or "").strip()
    cache_control = "public, max-age=31536000, immutable" if v else "no-cache"
    return FileResponse(
        path,
        media_type="application/javascript",
        headers={"Cache-Control": cache_control},
    )


@router.get("/{plugin_id}/bundle.js")
async def get_bundle(request: Request, plugin_id: int):
    path = await PluginService.get_bundle_path(plugin_id)
    return FileResponse(path, media_type="application/javascript", headers={"Cache-Control": "no-store"})
@router.get("/{key}/assets/{asset_path:path}")
async def get_asset(request: Request, key: str, asset_path: str):
    """Serve a plugin static asset."""
    path = await PluginService.get_asset_path(key, asset_path)

    # Determine the MIME type from the file extension
    ext = path.suffix.lower()
    media_types = {
        ".js": "application/javascript",
        ".css": "text/css",
        ".json": "application/json",
        ".svg": "image/svg+xml",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".ico": "image/x-icon",
        ".woff": "font/woff",
        ".woff2": "font/woff2",
        ".ttf": "font/ttf",
        ".eot": "application/vnd.ms-fontobject",
        ".html": "text/html",
        ".txt": "text/plain",
        ".md": "text/markdown",
    }
    media_type = media_types.get(ext, "application/octet-stream")

    return FileResponse(
        path,
        media_type=media_type,
        headers={"Cache-Control": "public, max-age=3600"},
    )
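A hedged sketch of exercising the new install endpoint from a script; the host, the auth header, and the file name are assumptions, only the route itself comes from the diff:

```python
# Illustrative client for POST /api/plugins/install (host/token are assumptions).
import httpx

with open("my-plugin.foxpkg", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/api/plugins/install",
        files={"file": ("my-plugin.foxpkg", f, "application/zip")},
        headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
    )
resp.raise_for_status()
print(resp.json())  # PluginInstallResult
```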
domain/plugins/loader.py (new file, 449 lines)
@@ -0,0 +1,449 @@
"""
Plugin loader module.

Responsible for:
1. Unpacking and validating .foxpkg packages
2. Deploying plugin files
3. Dynamically loading backend routes
4. Dynamically registering processors
"""

import io
import json
import shutil
import sys
import zipfile
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter

from .types import (
    ManifestProcessorConfig,
    ManifestRouteConfig,
    PluginManifest,
)


class PluginLoadError(Exception):
    """Raised when a plugin fails to load."""

    pass


class PluginLoader:
    """Plugin loader."""

    PLUGINS_ROOT = Path("data/plugins")

    # Cache of loaded plugin modules
    _loaded_modules: Dict[str, ModuleType] = {}
    # Tracks routers that have been mounted
    _mounted_routers: Dict[str, List[APIRouter]] = {}

    @classmethod
    def get_plugin_dir(cls, plugin_key: str) -> Path:
        """Return the plugin directory."""
        return cls.PLUGINS_ROOT / plugin_key

    @classmethod
    def get_manifest_path(cls, plugin_key: str) -> Path:
        """Return the path to the plugin's manifest.json."""
        return cls.get_plugin_dir(plugin_key) / "manifest.json"

    @classmethod
    def get_frontend_bundle_path(cls, plugin_key: str, entry: Optional[str] = None) -> Path:
        """Return the frontend bundle path."""
        plugin_dir = cls.get_plugin_dir(plugin_key)
        if entry:
            return plugin_dir / entry
        # Default location
        return plugin_dir / "frontend" / "index.js"

    @classmethod
    def get_asset_path(cls, plugin_key: str, asset_path: str) -> Path:
        """Return the path of a static asset."""
        return cls.get_plugin_dir(plugin_key) / asset_path

    # ========== Unpacking and validation ==========

    @classmethod
    def validate_manifest(cls, manifest_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate manifest data."""
        errors: List[str] = []

        # Required fields
        if not manifest_data.get("key"):
            errors.append("manifest is missing required field: key")
        if not manifest_data.get("name"):
            errors.append("manifest is missing required field: name")

        # Key format check (Java-style namespace)
        key = manifest_data.get("key", "")
        if key:
            import re

            # Format: com.example.plugin (at least two segments; each starts with a
            # lowercase letter and may contain lowercase letters and digits)
            if not re.match(r"^[a-z][a-z0-9]*(\.[a-z][a-z0-9]*)+$", key):
                errors.append(
                    "invalid key format: must be a namespaced identifier (e.g. com.example.plugin), "
                    "each segment starting with a lowercase letter and containing only lowercase "
                    "letters and digits, with at least two segments"
                )

        # Version format check (lightweight)
        version = manifest_data.get("version", "")
        if version and not isinstance(version, str):
            errors.append("version must be a string")

        # Validate the frontend config
        frontend = manifest_data.get("frontend")
        if frontend and isinstance(frontend, dict):
            if frontend.get("entry") and not isinstance(frontend["entry"], str):
                errors.append("frontend.entry must be a string")
            if frontend.get("styles") is not None:
                if not isinstance(frontend["styles"], list) or not all(
                    isinstance(x, str) for x in frontend["styles"]
                ):
                    errors.append("frontend.styles must be an array of strings")
            supported_exts = frontend.get("supportedExts") or frontend.get("supported_exts")
            if supported_exts and not isinstance(supported_exts, list):
                errors.append("frontend.supportedExts must be an array")
            use_system_window = frontend.get("useSystemWindow") or frontend.get("use_system_window")
            if use_system_window is not None and not isinstance(use_system_window, bool):
                errors.append("frontend.useSystemWindow must be a boolean")

        # Validate the backend config
        backend = manifest_data.get("backend")
        if backend and isinstance(backend, dict):
            routes = backend.get("routes", [])
            if routes:
                for i, route in enumerate(routes):
                    if not route.get("module"):
                        errors.append(f"backend.routes[{i}] is missing module")
                    if not route.get("prefix"):
                        errors.append(f"backend.routes[{i}] is missing prefix")

            processors = backend.get("processors", [])
            if processors:
                for i, proc in enumerate(processors):
                    if not proc.get("module"):
                        errors.append(f"backend.processors[{i}] is missing module")
                    if not proc.get("type"):
                        errors.append(f"backend.processors[{i}] is missing type")

        return len(errors) == 0, errors

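For reference, a manifest.json that passes the checks above might look like the following; the field values are invented for illustration and only validated fields are shown:

```python
# A hypothetical manifest that satisfies validate_manifest (values are invented).
sample_manifest = {
    "key": "com.example.hello",          # namespaced, lowercase, two+ segments
    "name": "Hello Plugin",
    "version": "0.1.0",
    "frontend": {
        "entry": "frontend/index.js",
        "supportedExts": ["txt", "md"],
    },
    "backend": {
        "routes": [{"module": "backend/routes.py", "prefix": "/hello"}],
        "processors": [{"module": "backend/processor.py", "type": "hello_processor"}],
    },
}

ok, errors = PluginLoader.validate_manifest(sample_manifest)
assert ok, errors
```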
    @classmethod
    def unpack_foxpkg(
        cls, file_content: bytes, target_key: Optional[str] = None
    ) -> Tuple[PluginManifest, Path]:
        """
        Unpack a .foxpkg file.

        Args:
            file_content: raw .foxpkg bytes
            target_key: optional plugin key to install under (overrides the manifest key)

        Returns:
            A (manifest, plugin_dir) tuple.

        Raises:
            PluginLoadError: unpacking or validation failed
        """
        try:
            with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
                # Read manifest.json
                try:
                    manifest_bytes = zf.read("manifest.json")
                except KeyError:
                    raise PluginLoadError("plugin package is missing manifest.json")

                try:
                    manifest_data = json.loads(manifest_bytes.decode("utf-8"))
                except json.JSONDecodeError as e:
                    raise PluginLoadError(f"failed to parse manifest.json: {e}")

                # Validate the manifest
                valid, errors = cls.validate_manifest(manifest_data)
                if not valid:
                    raise PluginLoadError(f"manifest validation failed: {'; '.join(errors)}")

                # Parse the manifest
                try:
                    manifest = PluginManifest.model_validate(manifest_data)
                except Exception as e:
                    raise PluginLoadError(f"failed to parse manifest: {e}")

                # Determine the plugin key
                plugin_key = target_key or manifest.key

                # Verify the files inside the package
                cls._validate_package_files(zf, manifest)

                # Deploy the files
                target_dir = cls.PLUGINS_ROOT / plugin_key
                if target_dir.exists():
                    # Back up the old version
                    backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
                    if backup_dir.exists():
                        shutil.rmtree(backup_dir)
                    shutil.move(str(target_dir), str(backup_dir))

                target_dir.mkdir(parents=True, exist_ok=True)

                try:
                    zf.extractall(target_dir)
                except Exception as e:
                    # Restore the backup
                    if (cls.PLUGINS_ROOT / f"{plugin_key}.backup").exists():
                        shutil.rmtree(target_dir, ignore_errors=True)
                        shutil.move(str(cls.PLUGINS_ROOT / f"{plugin_key}.backup"), str(target_dir))
                    raise PluginLoadError(f"failed to extract files: {e}")

                # Clean up the backup
                backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
                if backup_dir.exists():
                    shutil.rmtree(backup_dir, ignore_errors=True)

                return manifest, target_dir

        except zipfile.BadZipFile:
            raise PluginLoadError("invalid plugin package format (not a ZIP file)")

    @classmethod
    def _validate_package_files(cls, zf: zipfile.ZipFile, manifest: PluginManifest) -> None:
        """Verify that the files declared in the manifest exist in the package."""
        file_list = zf.namelist()

        # Check the frontend entry
        if manifest.frontend and manifest.frontend.entry:
            if manifest.frontend.entry not in file_list:
                raise PluginLoadError(f"frontend entry file does not exist: {manifest.frontend.entry}")

        # Check backend modules
        if manifest.backend:
            if manifest.backend.routes:
                for route in manifest.backend.routes:
                    if route.module not in file_list:
                        raise PluginLoadError(f"route module does not exist: {route.module}")

            if manifest.backend.processors:
                for proc in manifest.backend.processors:
                    if proc.module not in file_list:
                        raise PluginLoadError(f"processor module does not exist: {proc.module}")

    # ========== Dynamic route loading ==========

    @classmethod
    def load_route_module(cls, plugin_key: str, route_config: ManifestRouteConfig) -> APIRouter:
        """
        Dynamically load a plugin route module.

        Args:
            plugin_key: plugin identifier
            route_config: route configuration

        Returns:
            The loaded APIRouter.
        """
        module_path = cls.get_plugin_dir(plugin_key) / route_config.module

        if not module_path.exists():
            raise PluginLoadError(f"route module does not exist: {module_path}")

        module_name = f"foxel_plugin_{plugin_key}_route_{module_path.stem}"

        try:
            spec = spec_from_file_location(module_name, module_path)
            if spec is None or spec.loader is None:
                raise PluginLoadError(f"cannot load route module: {module_path}")

            module = module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            # Cache the module
            cls._loaded_modules[f"{plugin_key}:route:{route_config.module}"] = module

            # Fetch the router
            router = getattr(module, "router", None)
            if router is None:
                raise PluginLoadError(f"route module has no 'router' object: {module_path}")

            if not isinstance(router, APIRouter):
                raise PluginLoadError(f"'router' is not a valid APIRouter instance: {module_path}")

            # Wrap in a new router to apply the prefix
            wrapper = APIRouter(prefix=route_config.prefix, tags=route_config.tags or [])
            wrapper.include_router(router)

            return wrapper

        except PluginLoadError:
            raise
        except Exception as e:
            raise PluginLoadError(f"failed to load route module [{module_path}]: {e}")

    @classmethod
    def load_all_routes(cls, plugin_key: str, manifest: PluginManifest) -> List[APIRouter]:
        """Load all routes of a plugin."""
        routers: List[APIRouter] = []

        if not manifest.backend or not manifest.backend.routes:
            return routers

        for route_config in manifest.backend.routes:
            router = cls.load_route_module(plugin_key, route_config)
            routers.append(router)

        cls._mounted_routers[plugin_key] = routers
        return routers

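The loader above expects each route module to expose a module-level `router`. A minimal module that would satisfy it, sketched for illustration; the ping endpoint is invented:

```python
# backend/routes.py inside a plugin package (hypothetical example).
from fastapi import APIRouter

router = APIRouter()  # the loader looks up this exact name


@router.get("/ping")
async def ping():
    # Mounted under the manifest's prefix, e.g. GET /hello/ping
    return {"ok": True}
```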

    # ========== Dynamic processor registration ==========

    @classmethod
    def load_processor_module(
        cls, plugin_key: str, processor_config: ManifestProcessorConfig
    ) -> None:
        """
        Dynamically load and register a processor module.

        Args:
            plugin_key: plugin identifier
            processor_config: processor configuration
        """
        module_path = cls.get_plugin_dir(plugin_key) / processor_config.module

        if not module_path.exists():
            raise PluginLoadError(f"Processor module not found: {module_path}")

        module_name = f"foxel_plugin_{plugin_key}_processor_{module_path.stem}"

        try:
            spec = spec_from_file_location(module_name, module_path)
            if spec is None or spec.loader is None:
                raise PluginLoadError(f"Cannot load processor module: {module_path}")

            module = module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            # Cache the module
            cls._loaded_modules[f"{plugin_key}:processor:{processor_config.module}"] = module

            # Fetch the processor factory
            factory = getattr(module, "PROCESSOR_FACTORY", None)
            if factory is None:
                raise PluginLoadError(f"Processor module is missing 'PROCESSOR_FACTORY': {module_path}")

            # Fetch the config schema
            config_schema = getattr(module, "CONFIG_SCHEMA", [])
            processor_name = getattr(module, "PROCESSOR_NAME", processor_config.name or processor_config.type)
            supported_exts = getattr(module, "SUPPORTED_EXTS", [])

            # Register into the processor registry
            from domain.processors import CONFIG_SCHEMAS, TYPE_MAP

            processor_type = processor_config.type
            TYPE_MAP[processor_type] = factory

            # Instantiate once to read attributes
            try:
                sample = factory()
                produces_file = getattr(sample, "produces_file", False)
                supports_directory = getattr(sample, "supports_directory", False)
            except Exception:
                produces_file = False
                supports_directory = False

            CONFIG_SCHEMAS[processor_type] = {
                "type": processor_type,
                "name": processor_name,
                "supported_exts": supported_exts,
                "config_schema": config_schema,
                "produces_file": produces_file,
                "supports_directory": supports_directory,
                "plugin": plugin_key,  # mark the source plugin
                "module_path": str(module_path),
            }

        except PluginLoadError:
            raise
        except Exception as e:
            raise PluginLoadError(f"Failed to load processor module [{module_path}]: {e}")

    @classmethod
    def load_all_processors(cls, plugin_key: str, manifest: PluginManifest) -> List[str]:
        """Load all processors of a plugin; returns the list of processor types."""
        processor_types: List[str] = []

        if not manifest.backend or not manifest.backend.processors:
            return processor_types

        for proc_config in manifest.backend.processors:
            cls.load_processor_module(plugin_key, proc_config)
            processor_types.append(proc_config.type)

        return processor_types
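
As a sketch of the contract `load_processor_module` reads via `getattr()`, a plugin processor module might look like the following. The `UppercaseProcessor` class and its type key are hypothetical, assuming the `process(input_bytes, path, config)` signature used by processors elsewhere in this diff:

```python
# processor.py inside a plugin package (hypothetical example)
from typing import Any, Dict


class UppercaseProcessor:
    produces_file = True
    supports_directory = False

    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> bytes:
        # Trivial demo transform: upper-case the content of a text file.
        return input_bytes.decode("utf-8").upper().encode("utf-8")


# Module-level names the loader reads:
PROCESSOR_NAME = "Uppercase text"
SUPPORTED_EXTS = ["txt", "md"]
CONFIG_SCHEMA: list = []
PROCESSOR_FACTORY = UppercaseProcessor
```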

    # ========== Unloading ==========

    @classmethod
    def unload_plugin(cls, plugin_key: str, manifest: Optional[PluginManifest] = None) -> None:
        """
        Unload a plugin's backend components.

        Args:
            plugin_key: plugin identifier
            manifest: optional manifest used to determine which components to unload
        """
        # Unregister processors
        if manifest and manifest.backend and manifest.backend.processors:
            from domain.processors import CONFIG_SCHEMAS, TYPE_MAP

            for proc_config in manifest.backend.processors:
                proc_type = proc_config.type
                if proc_type in TYPE_MAP:
                    del TYPE_MAP[proc_type]
                if proc_type in CONFIG_SCHEMAS:
                    del CONFIG_SCHEMAS[proc_type]

        # Drop cached modules
        keys_to_remove = [k for k in cls._loaded_modules if k.startswith(f"{plugin_key}:")]
        for key in keys_to_remove:
            module = cls._loaded_modules.pop(key, None)
            if module and module.__name__ in sys.modules:
                del sys.modules[module.__name__]

        # Clear route tracking (note: FastAPI cannot remove routes dynamically; an app restart is required)
        cls._mounted_routers.pop(plugin_key, None)

    @classmethod
    def delete_plugin_files(cls, plugin_key: str) -> None:
        """Delete the plugin's files."""
        plugin_dir = cls.get_plugin_dir(plugin_key)
        if plugin_dir.exists():
            shutil.rmtree(plugin_dir)

        # Also remove the backup
        backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
        if backup_dir.exists():
            shutil.rmtree(backup_dir)

    # ========== Reading the manifest ==========

    @classmethod
    def read_manifest(cls, plugin_key: str) -> Optional[PluginManifest]:
        """Read the plugin manifest from the file system."""
        manifest_path = cls.get_manifest_path(plugin_key)
        if not manifest_path.exists():
            return None

        try:
            with open(manifest_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return PluginManifest.model_validate(data)
        except Exception:
            return None
@@ -1,2 +0,0 @@
"""Collection of plugin-specific server-side routes."""
@@ -1,142 +0,0 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query

from api.response import success
from domain.auth.service import get_current_active_user


router = APIRouter(
    prefix="/video-player",
    tags=["plugins"],
    dependencies=[Depends(get_current_active_user)],
)

DATA_ROOT = Path("data/.video")


def _read_json(path: Path) -> Dict[str, Any]:
    return json.loads(path.read_text(encoding="utf-8"))


def _file_mtime_iso(path: Path) -> str:
    try:
        ts = path.stat().st_mtime
    except FileNotFoundError:
        return ""
    return datetime.fromtimestamp(ts, tz=UTC).isoformat()


def _extract_title(payload: Dict[str, Any]) -> str:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    if payload.get("type") == "tv":
        return str(detail.get("name") or detail.get("original_name") or "")
    return str(detail.get("title") or detail.get("original_title") or "")


def _extract_year(payload: Dict[str, Any]) -> Optional[str]:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    value = detail.get("first_air_date") if payload.get("type") == "tv" else detail.get("release_date")
    if not value or not isinstance(value, str):
        return None
    return value[:4] if len(value) >= 4 else value


def _extract_genres(payload: Dict[str, Any]) -> List[str]:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    genres = detail.get("genres") or []
    out: List[str] = []
    if isinstance(genres, list):
        for g in genres:
            if isinstance(g, dict) and g.get("name"):
                out.append(str(g["name"]))
    return out


def _summarize(item_id: str, payload: Dict[str, Any], mtime_iso: str) -> Dict[str, Any]:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    media_type = payload.get("type") or "unknown"
    episodes = payload.get("episodes") or []
    seasons = {e.get("season") for e in episodes if isinstance(e, dict) and e.get("season") is not None}

    return {
        "id": item_id,
        "type": media_type,
        "title": _extract_title(payload),
        "year": _extract_year(payload),
        "overview": detail.get("overview"),
        "poster_path": detail.get("poster_path"),
        "backdrop_path": detail.get("backdrop_path"),
        "genres": _extract_genres(payload),
        "tmdb_id": (payload.get("tmdb") or {}).get("id"),
        "source_path": payload.get("source_path"),
        "scraped_at": payload.get("scraped_at"),
        "updated_at": mtime_iso,
        "episodes_count": len(episodes) if isinstance(episodes, list) else 0,
        "seasons_count": len(seasons),
        "vote_average": detail.get("vote_average"),
        "vote_count": detail.get("vote_count"),
    }


def _iter_library_files() -> List[tuple[str, Path]]:
    files: List[tuple[str, Path]] = []
    for sub in ("tv", "movie"):
        folder = DATA_ROOT / sub
        if not folder.exists():
            continue
        for p in folder.glob("*.json"):
            if not p.is_file():
                continue
            files.append((sub, p))
    return files


@router.get("/library")
async def list_library(
    q: str | None = Query(None, description="Search keyword (title/overview)"),
    media_type: str | None = Query(None, alias="type", description="tv or movie"),
):
    items: List[Dict[str, Any]] = []
    keyword = (q or "").strip().lower()
    type_filter = (media_type or "").strip().lower()
    if type_filter and type_filter not in {"tv", "movie"}:
        raise HTTPException(status_code=400, detail="type must be tv or movie")

    for _sub, path in _iter_library_files():
        item_id = path.stem
        try:
            payload = _read_json(path)
        except Exception:
            continue
        if type_filter and str(payload.get("type") or "").lower() != type_filter:
            continue
        summary = _summarize(item_id, payload, _file_mtime_iso(path))
        if keyword:
            haystack = f"{summary.get('title') or ''} {summary.get('overview') or ''}".lower()
            if keyword not in haystack:
                continue
        items.append(summary)

    items.sort(key=lambda x: x.get("updated_at") or "", reverse=True)
    return success(items)


@router.get("/library/{item_id}")
async def get_library_item(item_id: str):
    candidates = [
        DATA_ROOT / "tv" / f"{item_id}.json",
        DATA_ROOT / "movie" / f"{item_id}.json",
    ]
    path = next((p for p in candidates if p.exists()), None)
    if not path:
        raise HTTPException(status_code=404, detail="Item not found")

    payload = _read_json(path)
    payload["id"] = item_id
    payload["updated_at"] = _file_mtime_iso(path)
    return success(payload)
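
Before its removal this file exposed a small library API. A hedged usage sketch follows; the base URL, token, and the shape of the `success()` envelope (assumed to wrap items in a `data` field) are assumptions, since the router is mounted behind an auth dependency:

```python
import httpx

# Hypothetical base URL and token; the actual mount point and auth
# scheme depend on how the app exposes plugin routes.
BASE = "http://localhost:8000/video-player"
headers = {"Authorization": "Bearer <token>"}

with httpx.Client(headers=headers) as client:
    shows = client.get(f"{BASE}/library", params={"type": "tv", "q": "office"}).json()
    first_id = shows["data"][0]["id"]  # assumes success() wraps items in "data"
    detail = client.get(f"{BASE}/library/{first_id}").json()
```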
@@ -1,138 +1,273 @@
"""
Plugin service module.

Handles plugin installation, uninstallation, and other management operations.
"""

import contextlib
import re
import logging
import shutil
from pathlib import Path
from typing import List, Optional, Union

import aiofiles
import httpx
from fastapi import HTTPException

from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from .loader import PluginLoadError, PluginLoader
from .types import (
    PluginInstallResult,
    PluginManifest,
    PluginOut,
)
from models.database import Plugin

logger = logging.getLogger(__name__)


class PluginService:
    """Plugin service."""

    _plugins_root = Path("data/plugins")

    @classmethod
    def _folder_name(cls, rec: Plugin) -> str:
        if rec.key:
            safe = re.sub(r"[^A-Za-z0-9_.-]", "_", rec.key)
            return safe or str(rec.id)
        return str(rec.id)
    # ========== Utility methods ==========

    @classmethod
    def _bundle_dir_from_rec(cls, rec: Plugin) -> Path:
        return cls._plugins_root / cls._folder_name(rec) / "current"
    def _get_plugin_dir(cls, plugin_key: str) -> Path:
        """Get the plugin directory."""
        return cls._plugins_root / plugin_key

    @classmethod
    def _bundle_path_from_rec(cls, rec: Plugin) -> Path:
        return cls._bundle_dir_from_rec(rec) / "index.js"
    def _get_bundle_path(cls, rec: Plugin) -> Path:
        """Get the frontend bundle path."""
        plugin_dir = cls._get_plugin_dir(rec.key)
        # Read from the manifest
        if rec.manifest:
            frontend = rec.manifest.get("frontend", {})
            entry = frontend.get("entry")
            if entry:
                return plugin_dir / entry
        # Default location
        return plugin_dir / "frontend" / "index.js"

    @classmethod
    async def _download_bundle(cls, rec: Plugin, url: str) -> None:
        dest_dir = cls._bundle_dir_from_rec(rec)
        dest_dir.mkdir(parents=True, exist_ok=True)
        dest_path = cls._bundle_path_from_rec(rec)
        tmp_path = dest_path.with_suffix(".tmp")
        try:
            async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
                async with client.stream("GET", url) as resp:
                    resp.raise_for_status()
                    async with aiofiles.open(tmp_path, "wb") as f:
                        async for chunk in resp.aiter_bytes(chunk_size=65536):
                            if not chunk:
                                continue
                            await f.write(chunk)
            tmp_path.replace(dest_path)
        except Exception:
            with contextlib.suppress(Exception):
                if tmp_path.exists():
                    tmp_path.unlink()
            raise

    @classmethod
    async def _ensure_bundle(cls, plugin_id: int) -> Path:
        rec = await cls._get_or_404(plugin_id)
        bundle_path = cls._bundle_path_from_rec(rec)
        if bundle_path.exists():
            return bundle_path

        legacy = cls._plugins_root / str(rec.id) / "current" / "index.js"
        if legacy.exists():
            return legacy

        raise HTTPException(status_code=404, detail="Plugin bundle not found")

    @classmethod
    async def get_bundle_path(cls, plugin_id: int) -> Path:
        return await cls._ensure_bundle(plugin_id)

    @classmethod
    async def create(cls, payload: PluginCreate) -> PluginOut:
        rec = await Plugin.create(**payload.model_dump())
        try:
            await cls._download_bundle(rec, rec.url)
        except Exception as exc:
            with contextlib.suppress(Exception):
                await rec.delete()
            raise HTTPException(status_code=400, detail=f"Failed to fetch plugin: {exc}")
        return PluginOut.model_validate(rec)

    @classmethod
    async def list_plugins(cls) -> list[PluginOut]:
        rows = await Plugin.all().order_by("-id")
        return [PluginOut.model_validate(r) for r in rows]

    @classmethod
    async def _get_or_404(cls, plugin_id: int) -> Plugin:
        rec = await Plugin.get_or_none(id=plugin_id)
    async def _get_by_key_or_404(cls, key: str) -> Plugin:
        """Get a plugin by key; 404 if it does not exist."""
        rec = await Plugin.get_or_none(key=key)
        if not rec:
            raise HTTPException(status_code=404, detail="Plugin not found")
        return rec

    @classmethod
    async def delete(cls, plugin_id: int) -> None:
        rec = await cls._get_or_404(plugin_id)
        await rec.delete()
        with contextlib.suppress(Exception):
            dirs = {cls._bundle_dir_from_rec(rec).parent, cls._plugins_root / str(rec.id)}
            for plugin_dir in dirs:
                if plugin_dir.exists():
                    shutil.rmtree(plugin_dir)
    async def _get_by_key_or_id(cls, key_or_id: Union[str, int]) -> Plugin:
        """Get a plugin by key or ID."""
        # Try as an ID
        if isinstance(key_or_id, int) or (isinstance(key_or_id, str) and key_or_id.isdigit()):
            plugin_id = int(key_or_id)
            rec = await Plugin.get_or_none(id=plugin_id)
            if rec:
                return rec
        # Try as a key
        if isinstance(key_or_id, str):
            rec = await Plugin.get_or_none(key=key_or_id)
            if rec:
                return rec
        raise HTTPException(status_code=404, detail="Plugin not found")

    # ========== Installation ==========

    @classmethod
    async def update(cls, plugin_id: int, payload: PluginCreate) -> PluginOut:
        rec = await cls._get_or_404(plugin_id)
        url_changed = rec.url != payload.url
        if url_changed:
            try:
                await cls._download_bundle(rec, payload.url)
            except Exception as exc:
                raise HTTPException(status_code=400, detail=f"Failed to fetch plugin: {exc}")
        rec.url = payload.url
        rec.enabled = payload.enabled
        await rec.save()
        return PluginOut.model_validate(rec)
    async def install_package(cls, file_content: bytes, filename: str) -> PluginInstallResult:
        """
        Install a .foxpkg plugin package.

        Args:
            file_content: package content
            filename: file name

        Returns:
            The installation result.
        """
        errors: List[str] = []

        try:
            # Unpack
            manifest, plugin_dir = PluginLoader.unpack_foxpkg(file_content)
            plugin_key = manifest.key

            # Check whether it already exists
            existing = await Plugin.get_or_none(key=plugin_key)
            if existing:
                # Update the existing plugin
                logger.info(f"Updating plugin: {plugin_key}")
                rec = existing
            else:
                # Create a new plugin
                logger.info(f"Installing new plugin: {plugin_key}")
                rec = Plugin(key=plugin_key)

            # Update fields
            rec.name = manifest.name
            rec.version = manifest.version
            rec.description = manifest.description
            rec.author = manifest.author
            rec.website = manifest.website
            rec.github = manifest.github
            rec.license = manifest.license
            rec.manifest = manifest.model_dump(mode="json")

            # Extract frontend settings from manifest.frontend
            if manifest.frontend:
                rec.open_app = manifest.frontend.open_app or False
                rec.supported_exts = manifest.frontend.supported_exts
                rec.default_bounds = manifest.frontend.default_bounds
                rec.default_maximized = manifest.frontend.default_maximized
                rec.icon = manifest.frontend.icon

    @classmethod
    async def update_manifest(
        cls, plugin_id: int, manifest: PluginManifestUpdate
    ) -> PluginOut:
        rec = await cls._get_or_404(plugin_id)
        old_dir = cls._bundle_dir_from_rec(rec).parent
        updates = manifest.model_dump(exclude_none=True)
        if updates:
            for key, value in updates.items():
                setattr(rec, key, value)
            await rec.save()
        new_dir = cls._bundle_dir_from_rec(rec).parent
        if rec.key and new_dir != old_dir:
            candidate_dir = old_dir if old_dir.exists() else (cls._plugins_root / str(rec.id))
            if candidate_dir.exists():
                new_dir.parent.mkdir(parents=True, exist_ok=True)
                with contextlib.suppress(Exception):
                    if new_dir.exists():
                        shutil.rmtree(new_dir)
                    shutil.move(str(candidate_dir), str(new_dir))

            # Load backend components (if any)
            loaded_routes: List[str] = []
            loaded_processors: List[str] = []

            if manifest.backend:
                # Load routes
                if manifest.backend.routes:
                    try:
                        from main import app
                        routers = PluginLoader.load_all_routes(plugin_key, manifest)
                        for router in routers:
                            app.include_router(router)
                            loaded_routes.append(router.prefix)
                    except PluginLoadError as e:
                        errors.append(f"Route loading failed: {e}")
                        logger.error(f"Plugin {plugin_key} route loading failed: {e}")
                    except Exception as e:
                        errors.append(f"Route loading failed: {e}")
                        logger.exception(f"Plugin {plugin_key} route loading raised")

                # Load processors
                if manifest.backend.processors:
                    try:
                        processor_types = PluginLoader.load_all_processors(plugin_key, manifest)
                        loaded_processors = processor_types
                    except PluginLoadError as e:
                        errors.append(f"Processor loading failed: {e}")
                        logger.error(f"Plugin {plugin_key} processor loading failed: {e}")
                    except Exception as e:
                        errors.append(f"Processor loading failed: {e}")
                        logger.exception(f"Plugin {plugin_key} processor loading raised")

            # Update load state
            rec.loaded_routes = loaded_routes if loaded_routes else None
            rec.loaded_processors = loaded_processors if loaded_processors else None
            await rec.save()

            return PluginInstallResult(
                success=True,
                plugin=PluginOut.model_validate(rec),
                message="Installed successfully" if not errors else "Installed, but some components failed to load",
                errors=errors if errors else None,
            )

        except PluginLoadError as e:
            logger.error(f"Plugin installation failed: {e}")
            return PluginInstallResult(
                success=False,
                message=str(e),
                errors=[str(e)],
            )
        except Exception as e:
            logger.exception("Plugin installation raised")
            return PluginInstallResult(
                success=False,
                message=f"Installation failed: {e}",
                errors=[str(e)],
            )

    # ========== Queries ==========

    @classmethod
    async def list_plugins(cls) -> List[PluginOut]:
        """List all plugins."""
        rows = await Plugin.all().order_by("-id")
        for rec in rows:
            try:
                manifest = PluginLoader.read_manifest(rec.key)
                if manifest:
                    rec.manifest = manifest.model_dump(mode="json")
            except Exception:
                continue
        return [PluginOut.model_validate(r) for r in rows]

    @classmethod
    async def get_plugin(cls, key_or_id: Union[str, int]) -> PluginOut:
        """Get a single plugin's details."""
        rec = await cls._get_by_key_or_id(key_or_id)
        try:
            manifest = PluginLoader.read_manifest(rec.key)
            if manifest:
                rec.manifest = manifest.model_dump(mode="json")
        except Exception:
            pass
        return PluginOut.model_validate(rec)

    @classmethod
    async def get_bundle_path(cls, key_or_id: Union[str, int]) -> Path:
        """Get the plugin's frontend bundle path."""
        rec = await cls._get_by_key_or_id(key_or_id)
        bundle_path = cls._get_bundle_path(rec)
        if not bundle_path.exists():
            raise HTTPException(status_code=404, detail="Plugin bundle not found")
        return bundle_path

    @classmethod
    async def get_asset_path(cls, key: str, asset_path: str) -> Path:
        """Get a plugin static asset path."""
        rec = await cls._get_by_key_or_404(key)
        plugin_dir = cls._get_plugin_dir(rec.key)

        # Security check: prevent path traversal
        asset_path = asset_path.lstrip("/")
        if ".." in asset_path:
            raise HTTPException(status_code=400, detail="Invalid asset path")

        full_path = plugin_dir / asset_path
        if not full_path.exists():
            raise HTTPException(status_code=404, detail="Asset not found")

        # Ensure the path stays inside the plugin directory
        try:
            full_path.resolve().relative_to(plugin_dir.resolve())
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid asset path")

        return full_path

    # ========== Management operations ==========

    @classmethod
    async def delete(cls, key_or_id: Union[str, int]) -> None:
        """Delete/uninstall a plugin."""
        rec = await cls._get_by_key_or_id(key_or_id)

        # Fetch the manifest to unload components
        manifest: Optional[PluginManifest] = None
        if rec.manifest:
            try:
                manifest = PluginManifest.model_validate(rec.manifest)
            except Exception:
                pass

        # Unload backend components
        if manifest:
            PluginLoader.unload_plugin(rec.key, manifest)

        # Delete the database record
        await rec.delete()

        # Delete files
        with contextlib.suppress(Exception):
            plugin_dir = cls._get_plugin_dir(rec.key)
            if plugin_dir.exists():
                shutil.rmtree(plugin_dir)

        logger.info(f"Plugin {rec.key} has been uninstalled")
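
A hedged sketch of how `install_package` might be driven from an upload endpoint; the route path and wiring are assumptions, not part of this diff:

```python
from fastapi import APIRouter, UploadFile

from domain.plugins.service import PluginService

router = APIRouter(prefix="/api/plugins", tags=["plugins"])


@router.post("/install")  # hypothetical route; the real API surface may differ
async def install_plugin(file: UploadFile):
    # Read the uploaded .foxpkg bytes and hand them to the service.
    content = await file.read()
    return await PluginService.install_package(content, file.filename or "plugin.foxpkg")
```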
115
domain/plugins/startup.py
Normal file
@@ -0,0 +1,115 @@
"""
Plugin startup loading module.

Loads all installed plugins when the application starts.
"""

import logging
from typing import TYPE_CHECKING, List, Tuple

from .loader import PluginLoadError, PluginLoader
from .types import PluginManifest

if TYPE_CHECKING:
    from fastapi import FastAPI

logger = logging.getLogger(__name__)


async def load_installed_plugins(app: "FastAPI") -> Tuple[int, List[str]]:
    """
    Load all installed plugins.

    Args:
        app: the FastAPI application instance

    Returns:
        (number loaded successfully, list of errors)
    """
    from models.database import Plugin

    errors: List[str] = []
    loaded_count = 0

    try:
        plugins = await Plugin.all()
    except Exception as e:
        logger.error(f"Failed to query the plugin list: {e}")
        return 0, [f"Failed to query the plugin list: {e}"]

    for plugin in plugins:
        if not plugin.key:
            continue

        try:
            # Fetch the manifest
            manifest = None
            if plugin.manifest:
                try:
                    manifest = PluginManifest.model_validate(plugin.manifest)
                except Exception:
                    # Fall back to reading from the file system
                    manifest = PluginLoader.read_manifest(plugin.key)
            else:
                manifest = PluginLoader.read_manifest(plugin.key)

            if not manifest:
                logger.warning(f"Plugin {plugin.key} has no manifest; skipping")
                continue

            # Load backend routes
            loaded_routes: List[str] = []
            if manifest.backend and manifest.backend.routes:
                try:
                    routers = PluginLoader.load_all_routes(plugin.key, manifest)
                    for router in routers:
                        app.include_router(router)
                        loaded_routes.append(router.prefix)
                    logger.info(f"Plugin {plugin.key} loaded {len(routers)} route(s)")
                except PluginLoadError as e:
                    errors.append(f"Plugin {plugin.key} route loading failed: {e}")
                    logger.error(f"Plugin {plugin.key} route loading failed: {e}")

            # Load processors
            loaded_processors: List[str] = []
            if manifest.backend and manifest.backend.processors:
                try:
                    processor_types = PluginLoader.load_all_processors(plugin.key, manifest)
                    loaded_processors = processor_types
                    logger.info(f"Plugin {plugin.key} registered {len(processor_types)} processor(s)")
                except PluginLoadError as e:
                    errors.append(f"Plugin {plugin.key} processor loading failed: {e}")
                    logger.error(f"Plugin {plugin.key} processor loading failed: {e}")

            # Update the database record
            plugin.loaded_routes = loaded_routes if loaded_routes else None
            plugin.loaded_processors = loaded_processors if loaded_processors else None
            await plugin.save()

            loaded_count += 1
            logger.info(f"Plugin {plugin.key} loaded")

        except Exception as e:
            error_msg = f"Plugin {plugin.key} failed to load: {e}"
            errors.append(error_msg)
            logger.exception(error_msg)

    return loaded_count, errors


async def init_plugins(app: "FastAPI") -> None:
    """
    Initialize the plugin system.

    Called on application startup.
    """
    logger.info("Loading installed plugins...")

    loaded_count, errors = await load_installed_plugins(app)

    if errors:
        logger.warning(f"Plugin loading finished: {loaded_count} succeeded, {len(errors)} error(s)")
        for error in errors:
            logger.warning(f"  - {error}")
    else:
        logger.info(f"Plugin loading finished: {loaded_count} plugin(s)")
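
A hedged sketch of wiring `init_plugins` into application startup through a FastAPI lifespan handler; the app module layout is an assumption:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from domain.plugins.startup import init_plugins


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Assumes the database is already initialized at this point,
    # since load_installed_plugins queries the Plugin table.
    await init_plugins(app)
    yield


app = FastAPI(lifespan=lifespan)
```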
@@ -1,48 +1,119 @@
from typing import Any, Dict, List, Optional

from pydantic import AliasChoices, BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field


class PluginCreate(BaseModel):
    url: str = Field(min_length=1)
    enabled: bool = True
# ========== Manifest-related types ==========


class PluginManifestUpdate(BaseModel):
class ManifestFrontend(BaseModel):
    """The frontend configuration in manifest.json."""

    model_config = ConfigDict(populate_by_name=True, extra="ignore")

    key: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None
    entry: Optional[str] = Field(default=None, description="Frontend entry file path")
    styles: Optional[List[str]] = Field(default=None, description="List of frontend style file paths (relative to the plugin root)")
    open_app: Optional[bool] = Field(
        default=None,
        validation_alias=AliasChoices("open_app", "openApp"),
        alias="openApp",
        description="Whether the plugin can be opened standalone",
    )
    supported_exts: Optional[List[str]] = Field(
        default=None,
        validation_alias=AliasChoices("supported_exts", "supportedExts"),
        alias="supportedExts",
        description="List of supported file extensions",
    )
    default_bounds: Optional[Dict[str, Any]] = Field(
        default=None,
        validation_alias=AliasChoices("default_bounds", "defaultBounds"),
        alias="defaultBounds",
        description="Default window bounds",
    )
    default_maximized: Optional[bool] = Field(
        default=None,
        validation_alias=AliasChoices("default_maximized", "defaultMaximized"),
        alias="defaultMaximized",
        description="Whether maximized by default",
    )
    icon: Optional[str] = None
    description: Optional[str] = None
    author: Optional[str] = None
    website: Optional[str] = None
    github: Optional[str] = None
    icon: Optional[str] = Field(default=None, description="Icon path")
    use_system_window: Optional[bool] = Field(
        default=None,
        alias="useSystemWindow",
        description="Whether to use a system window",
    )


class ManifestRouteConfig(BaseModel):
    """A route configuration in manifest.json."""

    model_config = ConfigDict(extra="ignore")

    module: str = Field(..., description="Route module path")
    prefix: str = Field(..., description="Route prefix")
    tags: Optional[List[str]] = Field(default=None, description="API tags")


class ManifestProcessorConfig(BaseModel):
    """A processor configuration in manifest.json."""

    model_config = ConfigDict(extra="ignore")

    module: str = Field(..., description="Processor module path")
    type: str = Field(..., description="Processor type identifier")
    name: Optional[str] = Field(default=None, description="Processor display name")


class ManifestBackend(BaseModel):
    """The backend configuration in manifest.json."""

    model_config = ConfigDict(extra="ignore")

    routes: Optional[List[ManifestRouteConfig]] = Field(default=None, description="Route list")
    processors: Optional[List[ManifestProcessorConfig]] = Field(
        default=None, description="Processor list"
    )


class ManifestDependencies(BaseModel):
    """The dependency configuration in manifest.json."""

    model_config = ConfigDict(extra="ignore")

    python: Optional[str] = Field(default=None, description="Required Python version")
    packages: Optional[List[str]] = Field(default=None, description="List of Python package dependencies")


class PluginManifest(BaseModel):
    """The full manifest.json structure."""

    model_config = ConfigDict(populate_by_name=True, extra="ignore")

    foxpkg: str = Field(default="1.0", description="foxpkg format version")
    key: str = Field(..., min_length=1, description="Unique plugin identifier")
    name: str = Field(..., min_length=1, description="Plugin name")
    version: str = Field(default="1.0.0", description="Plugin version")
    description: Optional[str] = Field(default=None, description="Plugin description")
    i18n: Optional[Dict[str, Dict[str, str]]] = Field(
        default=None,
        description="Localized info (name/description), e.g. {'en': {'name': '...', 'description': '...'}}",
    )
    author: Optional[str] = Field(default=None, description="Author")
    website: Optional[str] = Field(default=None, description="Website")
    github: Optional[str] = Field(default=None, description="GitHub URL")
    license: Optional[str] = Field(default=None, description="License")

    frontend: Optional[ManifestFrontend] = Field(default=None, description="Frontend configuration")
    backend: Optional[ManifestBackend] = Field(default=None, description="Backend configuration")
    dependencies: Optional[ManifestDependencies] = Field(default=None, description="Dependency configuration")
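
To make the schema concrete, a hedged sketch of a manifest that `PluginManifest` would accept; every value is a made-up illustration, not a real plugin:

```python
from domain.plugins.types import PluginManifest

# Hypothetical manifest.json content for a plugin keyed "my-plugin".
manifest = PluginManifest.model_validate({
    "foxpkg": "1.0",
    "key": "my-plugin",
    "name": "My Plugin",
    "version": "1.0.0",
    "frontend": {
        "entry": "frontend/index.js",
        "openApp": True,              # camelCase aliases are accepted
        "supportedExts": ["txt"],
    },
    "backend": {
        "routes": [{"module": "backend/routes.py", "prefix": "/api/plugins/my-plugin"}],
        "processors": [{"module": "backend/processor.py", "type": "my_processor"}],
    },
})
print(manifest.backend.routes[0].prefix)
```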

# ========== API request/response types ==========


class PluginOut(BaseModel):
    """Plugin output model."""

    id: int
    url: str
    enabled: bool
    key: str
    open_app: bool = False
    key: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None
    supported_exts: Optional[List[str]] = None
@@ -53,5 +124,20 @@ class PluginOut(BaseModel):
    author: Optional[str] = None
    website: Optional[str] = None
    github: Optional[str] = None
    license: Optional[str] = None

    # Newly added fields
    manifest: Optional[Dict[str, Any]] = None
    loaded_routes: Optional[List[str]] = None
    loaded_processors: Optional[List[str]] = None

    model_config = ConfigDict(from_attributes=True)


class PluginInstallResult(BaseModel):
    """Installation result."""

    success: bool
    plugin: Optional[PluginOut] = None
    message: Optional[str] = None
    errors: Optional[List[str]] = None
35
domain/processors/__init__.py
Normal file
@@ -0,0 +1,35 @@
from .base import BaseProcessor
from .registry import (
    CONFIG_SCHEMAS,
    TYPE_MAP,
    get_config_schema,
    get_config_schemas,
    get_last_discovery_errors,
    get_module_path,
    reload_processors,
)
from .service import (
    ProcessorService,
    get_processor,
    list_processors,
    reload_processor_modules,
)
from .types import ProcessDirectoryRequest, ProcessRequest, UpdateSourceRequest

__all__ = [
    "BaseProcessor",
    "CONFIG_SCHEMAS",
    "TYPE_MAP",
    "get_config_schema",
    "get_config_schemas",
    "get_last_discovery_errors",
    "get_module_path",
    "reload_processors",
    "ProcessorService",
    "get_processor",
    "list_processors",
    "reload_processor_modules",
    "ProcessDirectoryRequest",
    "ProcessRequest",
    "UpdateSourceRequest",
]
@@ -4,10 +4,9 @@ from fastapi import APIRouter, Body, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.processors.service import ProcessorService
from domain.processors.types import (
from domain.auth import User, get_current_active_user
from .service import ProcessorService
from .types import (
    ProcessDirectoryRequest,
    ProcessRequest,
    UpdateSourceRequest,
@@ -8,8 +8,15 @@ from fastapi.responses import Response
from PIL import Image

from ..base import BaseProcessor
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
from domain.ai.service import VectorDBService, DEFAULT_VECTOR_DIMENSION
from domain.ai import (
    DEFAULT_VECTOR_DIMENSION,
    FILE_COLLECTION_NAME,
    VECTOR_COLLECTION_NAME,
    VectorDBService,
    describe_image_base64,
    get_text_embedding,
    provider_service,
)


CHUNK_SIZE = 800
@@ -112,18 +119,20 @@ class VectorIndexProcessor:
        action = config.get("action", "create")
        index_type = config.get("index_type", "vector")
        vector_db = VectorDBService()
        collection_name = "vector_collection"
        vector_collection = VECTOR_COLLECTION_NAME
        file_collection = FILE_COLLECTION_NAME

        if action == "destroy":
            await vector_db.delete_vector(collection_name, path)
            target_collection = file_collection if index_type == "simple" else vector_collection
            await vector_db.delete_vector(target_collection, path)
            return Response(content=f"The {index_type} index for file {path} has been destroyed", media_type="text/plain")

        mime_type = _guess_mime(path)

        if index_type == "simple":
            await vector_db.ensure_collection(collection_name, vector=False)
            await vector_db.delete_vector(collection_name, path)
            await vector_db.upsert_vector(collection_name, {
            await vector_db.ensure_collection(file_collection, vector=False)
            await vector_db.delete_vector(file_collection, path)
            await vector_db.upsert_vector(file_collection, {
                "path": path,
                "source_path": path,
                "chunk_id": "filename",
@@ -146,8 +155,8 @@ class VectorIndexProcessor:
        if vector_dim <= 0:
            vector_dim = DEFAULT_VECTOR_DIMENSION

        await vector_db.ensure_collection(collection_name, vector=True, dim=vector_dim)
        await vector_db.delete_vector(collection_name, path)
        await vector_db.ensure_collection(vector_collection, vector=True, dim=vector_dim)
        await vector_db.delete_vector(vector_collection, path)

        if file_ext in ["jpg", "jpeg", "png", "bmp"]:
            processed_bytes, compression = _compress_image_for_embedding(input_bytes)
@@ -155,7 +164,7 @@ class VectorIndexProcessor:
            description = await describe_image_base64(base64_image)
            embedding = await get_text_embedding(description)
            image_mime = "image/jpeg" if compression else mime_type
            await vector_db.upsert_vector(collection_name, {
            await vector_db.upsert_vector(vector_collection, {
                "path": _chunk_key(path, "image"),
                "source_path": path,
                "chunk_id": "image",
@@ -177,7 +186,7 @@ class VectorIndexProcessor:

        chunks = _chunk_text(text)
        if not chunks:
            await vector_db.upsert_vector(collection_name, {
            await vector_db.upsert_vector(vector_collection, {
                "path": _chunk_key(path, "0"),
                "source_path": path,
                "chunk_id": "0",
@@ -194,7 +203,7 @@ class VectorIndexProcessor:
        chunk_count = 0
        for chunk_id, chunk_text, start, end in chunks:
            embedding = await get_text_embedding(chunk_text)
            await vector_db.upsert_vector(collection_name, {
            await vector_db.upsert_vector(vector_collection, {
                "path": _chunk_key(path, str(chunk_id)),
                "source_path": path,
                "chunk_id": str(chunk_id),
@@ -213,15 +222,15 @@ class VectorIndexProcessor:
        return Response(content="Text file indexed", media_type="text/plain")

        # Vector indexing is not yet supported for other types; fall back to a filename index
        await vector_db.delete_vector(collection_name, path)
        await vector_db.upsert_vector(collection_name, {
            "path": _chunk_key(path, "fallback"),
        await vector_db.ensure_collection(file_collection, vector=False)
        await vector_db.delete_vector(file_collection, path)
        await vector_db.upsert_vector(file_collection, {
            "path": path,
            "source_path": path,
            "chunk_id": "filename",
            "mime": mime_type,
            "type": "filename",
            "name": os.path.basename(path),
            "embedding": [0.0] * vector_dim,
        })
        return Response(content="Vector indexing is not supported for this type; a filename index was created instead", media_type="text/plain")
@@ -1,396 +0,0 @@
import hashlib
import json
import os
import re
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import httpx

from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.thumbnail import VIDEO_EXT, is_video_filename


DATA_ROOT = Path("data/.video")
TMDB_BASE_URL = "https://api.themoviedb.org/3"


def _sha1(text: str) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()


def _store_path(media_type: str, source_path: str) -> Path:
    subdir = "tv" if media_type == "tv" else "movie"
    return DATA_ROOT / subdir / f"{_sha1(source_path)}.json"


def _write_json(path: Path, payload: dict) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")


_CLEAN_TAGS_RE = re.compile(
    r"\b("
    r"2160p|1080p|720p|480p|4k|hdr|dv|dolby|atmos|"
    r"x264|x265|h264|h265|hevc|av1|aac|dts|flac|"
    r"bluray|bdrip|web[- ]?dl|webrip|dvdrip|remux|proper|repack"
    r")\b",
    re.IGNORECASE,
)


def _clean_query_name(raw: str) -> str:
    name = raw
    name = name.replace(".", " ").replace("_", " ")
    name = re.sub(r"\[[^\]]*\]", " ", name)
    name = re.sub(r"\([^\)]*\)", " ", name)
    name = _CLEAN_TAGS_RE.sub(" ", name)
    name = re.sub(r"\s+", " ", name).strip()
    return name


def _guess_name_from_path(path: str, is_dir: bool) -> str:
    norm = path.rstrip("/") if is_dir else path
    p = Path(norm)
    raw = p.name if is_dir else p.stem
    return _clean_query_name(raw)
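
For intuition, a hedged illustration of what these helpers do to release-style names; the outputs in the comments are my expectation from the regexes above, not captured program output:

```python
# Illustrative inputs; the expected values follow from the regexes above.
print(_clean_query_name("Some.Show.2019.S01.1080p.WEB-DL.x265"))
# -> "Some Show 2019 S01"  (quality/codec tags stripped, dots become spaces)

print(_guess_name_from_path("/mnt/movies/Inception (2010) [Remux].mkv", is_dir=False))
# -> "Inception"  (bracketed and parenthesized chunks removed from the stem)
```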
def _as_bool(value: Any, default: bool) -> bool:
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value != 0
    if isinstance(value, str):
        v = value.strip().lower()
        if v in {"1", "true", "yes", "y", "on"}:
            return True
        if v in {"0", "false", "no", "n", "off"}:
            return False
    return default


_SXXEYY_RE = re.compile(r"[Ss](\d{1,2})\s*[.\-_ ]*\s*[Ee](\d{1,3})")
_X_RE = re.compile(r"(\d{1,2})x(\d{1,3})", re.IGNORECASE)
_CN_EP_RE = re.compile(r"第\s*(\d{1,3})\s*[集话]")
_CN_SEASON_RE = re.compile(r"第\s*(\d{1,2})\s*季")
_SEASON_WORD_RE = re.compile(r"Season\s*(\d{1,2})", re.IGNORECASE)
_S_RE = re.compile(r"[Ss](\d{1,2})")


def _parse_season_episode(rel_path: str) -> Tuple[Optional[int], Optional[int]]:
    stem = Path(rel_path).stem

    m = _SXXEYY_RE.search(stem) or _SXXEYY_RE.search(rel_path)
    if m:
        return int(m.group(1)), int(m.group(2))

    m = _X_RE.search(stem)
    if m:
        return int(m.group(1)), int(m.group(2))

    m = _CN_EP_RE.search(stem)
    if m:
        episode = int(m.group(1))
        season = None
        for part in reversed(Path(rel_path).parts[:-1]):
            sm = _CN_SEASON_RE.search(part) or _SEASON_WORD_RE.search(part) or _S_RE.search(part)
            if sm:
                season = int(sm.group(1))
                break
        return season or 1, episode

    m = re.match(r"^(\d{1,3})(?!\d)", stem)
    if m:
        episode = int(m.group(1))
        season = None
        for part in reversed(Path(rel_path).parts[:-1]):
            sm = _CN_SEASON_RE.search(part) or _SEASON_WORD_RE.search(part) or _S_RE.search(part)
            if sm:
                season = int(sm.group(1))
                break
        return season or 1, episode

    return None, None
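
A hedged sample of the fallback order (SxxEyy first, then NxM, then the Chinese 第N集 form with the season taken from a parent folder, then a bare leading number); the expected results are derived from the regexes, not captured output:

```python
print(_parse_season_episode("Show.S02E05.mkv"))          # -> (2, 5)
print(_parse_season_episode("Show 3x12.mkv"))            # -> (3, 12)
print(_parse_season_episode("第2季/第7集.mp4"))            # -> (2, 7), season from the parent folder
print(_parse_season_episode("Season 1/04 - pilot.mkv"))  # -> (1, 4), bare leading number fallback
```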
class TMDBClient:
    def __init__(self, access_token: str | None, api_key: str | None):
        self._access_token = access_token
        self._api_key = api_key

    @classmethod
    def from_env(cls) -> "TMDBClient":
        access_token = os.getenv("TMDB_ACCESS_TOKEN")
        api_key = os.getenv("TMDB_API_KEY")
        if not access_token and not api_key:
            raise RuntimeError("Missing TMDB_ACCESS_TOKEN or TMDB_API_KEY")
        return cls(access_token=access_token, api_key=api_key)

    def _headers(self) -> dict:
        headers = {"Accept": "application/json"}
        if self._access_token:
            headers["Authorization"] = f"Bearer {self._access_token}"
        return headers

    def _merge_params(self, params: dict) -> dict:
        merged = dict(params or {})
        if self._api_key:
            merged.setdefault("api_key", self._api_key)
        return merged

    async def get(self, path: str, params: dict) -> dict:
        url = f"{TMDB_BASE_URL}{path}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url, headers=self._headers(), params=self._merge_params(params))
            resp.raise_for_status()
            return resp.json()
class VideoLibraryProcessor:
    name = "Video library intake"
    supported_exts = sorted(VIDEO_EXT)
    config_schema = [
        {
            "key": "name",
            "label": "Manual name (optional)",
            "type": "string",
            "required": False,
            "placeholder": "Leave empty to extract from the path",
        },
        {
            "key": "language",
            "label": "Language",
            "type": "string",
            "required": False,
            "default": "zh-CN",
        },
        {
            "key": "include_episodes",
            "label": "TV: save every episode",
            "type": "select",
            "required": False,
            "default": 1,
            "options": [
                {"label": "Yes", "value": 1},
                {"label": "No", "value": 0},
            ],
        },
    ]
    produces_file = False
    supports_directory = True
    requires_input_bytes = False

    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Dict[str, Any]:
        tmdb = TMDBClient.from_env()
        is_dir = await VirtualFSService.path_is_directory(path)
        language = str(config.get("language") or "zh-CN")
        manual_name = str(config.get("name") or "").strip()
        query_name = manual_name or _guess_name_from_path(path, is_dir=is_dir)
        scraped_at = datetime.now(UTC).isoformat()

        if is_dir:
            payload, saved_to = await self._process_tv_dir(tmdb, path, query_name, language, scraped_at, config)
            return {
                "ok": True,
                "type": "tv",
                "path": path,
                "tmdb_id": payload.get("tmdb", {}).get("id"),
                "saved_to": str(saved_to),
            }

        payload, saved_to = await self._process_movie_file(tmdb, path, query_name, language, scraped_at)
        return {
            "ok": True,
            "type": "movie",
            "path": path,
            "tmdb_id": payload.get("tmdb", {}).get("id"),
            "saved_to": str(saved_to),
        }

    async def _process_movie_file(
        self,
        tmdb: TMDBClient,
        path: str,
        query_name: str,
        language: str,
        scraped_at: str,
    ) -> Tuple[dict, Path]:
        search = await tmdb.get("/search/movie", {"query": query_name, "language": language})
        results = search.get("results") or []
        if not results:
            raise RuntimeError(f"No movie entry found: {query_name}")

        chosen = results[0] or {}
        movie_id = chosen.get("id")
        if not movie_id:
            raise RuntimeError("TMDB search result is missing an id")

        detail = await tmdb.get(
            f"/movie/{movie_id}",
            {
                "language": language,
                "append_to_response": "credits,images,external_ids,videos",
            },
        )

        payload = {
            "type": "movie",
            "source_path": path,
            "query": {"name": query_name, "language": language},
            "scraped_at": scraped_at,
            "tmdb": {
                "id": movie_id,
                "search": {"page": search.get("page"), "total_results": search.get("total_results"), "results": results[:5]},
                "detail": detail,
            },
        }
        saved_to = _store_path("movie", path)
        _write_json(saved_to, payload)
        return payload, saved_to

    async def _process_tv_dir(
        self,
        tmdb: TMDBClient,
        path: str,
        query_name: str,
        language: str,
        scraped_at: str,
        config: Dict[str, Any],
    ) -> Tuple[dict, Path]:
        search = await tmdb.get("/search/tv", {"query": query_name, "language": language})
        results = search.get("results") or []
        if not results:
            raise RuntimeError(f"No TV entry found: {query_name}")

        chosen = results[0] or {}
        tv_id = chosen.get("id")
        if not tv_id:
            raise RuntimeError("TMDB search result is missing an id")

        detail = await tmdb.get(
            f"/tv/{tv_id}",
            {
                "language": language,
                "append_to_response": "credits,images,external_ids,videos",
            },
        )

        include_episodes = _as_bool(config.get("include_episodes"), True)
        episodes: List[dict] = []
        seasons_detail: Dict[str, Any] = {}
        if include_episodes:
            episodes = await self._collect_episode_files(path)
            seasons = sorted({ep["season"] for ep in episodes if ep.get("season") is not None})
            for season in seasons:
                seasons_detail[str(season)] = await tmdb.get(
                    f"/tv/{tv_id}/season/{int(season)}",
                    {"language": language},
                )
            self._attach_tmdb_episode_detail(episodes, seasons_detail)

        payload = {
            "type": "tv",
            "source_path": path,
            "query": {"name": query_name, "language": language},
            "scraped_at": scraped_at,
            "tmdb": {
                "id": tv_id,
                "search": {"page": search.get("page"), "total_results": search.get("total_results"), "results": results[:5]},
                "detail": detail,
                "seasons": seasons_detail,
            },
            "episodes": episodes,
        }

        saved_to = _store_path("tv", path)
        _write_json(saved_to, payload)
        return payload, saved_to

    async def _collect_episode_files(self, dir_path: str) -> List[dict]:
        adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(dir_path)
        rel = rel.rstrip("/")
        list_dir = await VirtualFSService._ensure_method(adapter_instance, "list_dir")

        stack: List[str] = [rel]
        page_size = 200
        out: List[dict] = []

        while stack:
            current_rel = stack.pop()
            page = 1
            while True:
                entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
                entries = entries or []
                if not entries and (total or 0) == 0:
                    break

                for entry in entries:
                    name = entry.get("name")
                    if not name:
                        continue
                    child_rel = VirtualFSService._join_rel(current_rel, name)
                    if entry.get("is_dir"):
                        stack.append(child_rel.rstrip("/"))
                        continue
                    if not is_video_filename(name):
                        continue

                    absolute_path = VirtualFSService._build_absolute_path(adapter_model.path, child_rel)
                    rel_in_show = child_rel
                    if rel and child_rel.startswith(rel.rstrip("/") + "/"):
                        rel_in_show = child_rel[len(rel.rstrip("/")) + 1 :]

                    season, episode = _parse_season_episode(rel_in_show)
                    out.append(
                        {
                            "path": absolute_path,
                            "rel": rel_in_show,
                            "name": name,
                            "size": entry.get("size"),
                            "mtime": entry.get("mtime"),
                            "season": season,
                            "episode": episode,
                        }
                    )

                if total is None or page * page_size >= total:
                    break
                page += 1

        return out

    def _attach_tmdb_episode_detail(self, episodes: List[dict], seasons_detail: Dict[str, Any]) -> None:
        episode_maps: Dict[str, Dict[int, Any]] = {}
        for season_str, season_payload in (seasons_detail or {}).items():
            items = (season_payload or {}).get("episodes") or []
            m: Dict[int, Any] = {}
            for item in items:
                try:
                    number = int(item.get("episode_number"))
                except Exception:
                    continue
                m[number] = item
            episode_maps[season_str] = m

        for ep in episodes:
            season = ep.get("season")
            episode = ep.get("episode")
            if season is None or episode is None:
                continue
            m = episode_maps.get(str(season))
            if not m:
                continue
            detail = m.get(int(episode))
            if detail:
                ep["tmdb_episode"] = detail


PROCESSOR_TYPE = "video_library"
PROCESSOR_NAME = VideoLibraryProcessor.name
SUPPORTED_EXTS = VideoLibraryProcessor.supported_exts
CONFIG_SCHEMA = VideoLibraryProcessor.config_schema
PROCESSOR_FACTORY = lambda: VideoLibraryProcessor()
@@ -5,7 +5,7 @@ from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, Optional

from domain.processors.base import BaseProcessor
from .base import BaseProcessor

ProcessorFactory = Callable[[], BaseProcessor]
TYPE_MAP: Dict[str, ProcessorFactory] = {}
@@ -16,7 +16,7 @@ LAST_DISCOVERY_ERRORS: list[str] = []

def discover_processors(force_reload: bool = False) -> list[str]:
    """Scan and cache the available processor modules."""
    from domain.processors import builtin as processors_pkg
    from . import builtin as processors_pkg

    TYPE_MAP.clear()
    CONFIG_SCHEMAS.clear()
@@ -3,20 +3,20 @@ from typing import List, Tuple

from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from domain.processors.registry import (
from domain.tasks import task_queue_service
from domain.virtual_fs import VirtualFSService
from .registry import (
    get,
    get_config_schema,
    get_config_schemas,
    get_module_path,
    reload_processors,
)
from domain.processors.types import (
from .types import (
    ProcessDirectoryRequest,
    ProcessRequest,
    UpdateSourceRequest,
)
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import task_queue_service


class ProcessorService:
1
domain/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
__all__: list[str] = []
10
domain/share/__init__.py
Normal file
@@ -0,0 +1,10 @@
from .service import ShareService
from .types import ShareCreate, ShareInfo, ShareInfoWithPassword, SharePassword

__all__ = [
    "ShareService",
    "ShareCreate",
    "ShareInfo",
    "ShareInfoWithPassword",
    "SharePassword",
]
@@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.share.service import ShareService
from domain.share.types import (
from domain.auth import User, get_current_active_user
from .service import ShareService
from .types import (
    ShareCreate,
    ShareInfo,
    ShareInfoWithPassword,
@@ -7,7 +7,7 @@ import bcrypt
from fastapi import HTTPException, status
from fastapi.responses import Response

from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs import VirtualFSService
from models.database import ShareLink, UserAccount
24
domain/tasks/__init__.py
Normal file
@@ -0,0 +1,24 @@
from .service import TaskService
from .task_queue import Task, TaskProgress, TaskStatus, task_queue_service
from .types import (
    AutomationTaskBase,
    AutomationTaskCreate,
    AutomationTaskRead,
    AutomationTaskUpdate,
    TaskQueueSettings,
    TaskQueueSettingsResponse,
)

__all__ = [
    "TaskService",
    "Task",
    "TaskProgress",
    "TaskStatus",
    "task_queue_service",
    "AutomationTaskBase",
    "AutomationTaskCreate",
    "AutomationTaskRead",
    "AutomationTaskUpdate",
    "TaskQueueSettings",
    "TaskQueueSettingsResponse",
]
@@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.tasks.service import TaskService
from domain.tasks.types import (
from domain.auth import get_current_active_user
from .service import TaskService
from .types import (
    AutomationTaskCreate,
    AutomationTaskUpdate,
    TaskQueueSettings,
@@ -3,17 +3,16 @@ from typing import Annotated, Any, Dict, Optional

from fastapi import Depends, HTTPException

from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.tasks.types import (
from domain.auth import User, get_current_active_user
from domain.config import ConfigService
from .task_queue import task_queue_service
from .types import (
    AutomationTaskCreate,
    AutomationTaskUpdate,
    TaskQueueSettings,
    TaskQueueSettingsResponse,
)
from models.database import AutomationTask
from domain.tasks.task_queue import task_queue_service


class TaskService:
@@ -74,7 +74,7 @@ class TaskQueueService:
 
         try:
             # Local import to avoid circular dependency during module load.
-            from domain.virtual_fs.service import VirtualFSService
+            from domain.virtual_fs import VirtualFSService
 
             if task.name == "process_file":
                 params = task.task_info
@@ -88,7 +88,7 @@ class TaskQueueService:
                 task.result = result
             elif task.name == "automation_task" or self._is_processor_task(task.name):
                 from models.database import AutomationTask
-                from domain.processors.service import get_processor
+                from domain.processors import get_processor
 
                 params = task.task_info
                 auto_task = await AutomationTask.get(id=params["task_id"])
@@ -116,7 +116,7 @@ class TaskQueueService:
                     await VirtualFSService.write_file(save_to, result)
                 task.result = "Automation task completed"
             elif task.name == "offline_http_download":
-                from domain.offline_downloads.service import OfflineDownloadService
+                from domain.offline_downloads import OfflineDownloadService
 
                 result_path = await OfflineDownloadService.run_http_download(task)
                 task.result = {"path": result_path}
@@ -124,7 +124,7 @@ class TaskQueueService:
                 result = await VirtualFSService.run_cross_mount_transfer_task(task)
                 task.result = result
             elif task.name == "send_email":
-                from domain.email.service import EmailService
+                from domain.email import EmailService
 
                 await EmailService.send_from_task(task.id, task.task_info)
                 task.result = "Email sent"
             else:
@@ -141,7 +141,7 @@ class TaskQueueService:
 
     def _is_processor_task(self, task_name: str) -> bool:
         try:
-            from domain.processors.service import get_processor
+            from domain.processors import get_processor
 
             return get_processor(task_name) is not None
         except Exception:
@@ -180,7 +180,7 @@ class TaskQueueService:
 
     async def start_worker(self, concurrency: int | None = None):
         if concurrency is None:
-            from domain.config.service import ConfigService
+            from domain.config import ConfigService
 
             stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
            try:
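Most of the hunks above only swap an import target — a concrete module path such as `domain.tasks.service` for its package facade `domain.tasks` — while keeping the import local to the method that needs it, as the `# Local import to avoid circular dependency` comment notes. A minimal sketch of why the deferred form breaks an import cycle (module and task names here are hypothetical, not from the repo):

```python
# worker.py — illustrative module, not part of the repo.
# A top-level `from domain.virtual_fs import VirtualFSService` would raise
# ImportError if domain.virtual_fs (directly or transitively) imported this
# module back while it was still initializing.

async def run(task: dict):
    # Deferred import: resolved on first call, after both packages have
    # finished loading, so the cycle never materializes at import time.
    from domain.virtual_fs import VirtualFSService

    return await VirtualFSService.write_file(task["path"], task["data"])
```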
11
domain/virtual_fs/__init__.py
Normal file
@@ -0,0 +1,11 @@
+from .service import VirtualFSService
+from .types import DirListing, MkdirRequest, MoveRequest, SearchResultItem, VfsEntry
+
+__all__ = [
+    "VirtualFSService",
+    "DirListing",
+    "MkdirRequest",
+    "MoveRequest",
+    "SearchResultItem",
+    "VfsEntry",
+]
@@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
 
 from api.response import success
 from domain.audit import AuditAction, audit
-from domain.auth.service import get_current_active_user
-from domain.auth.types import User
-from domain.virtual_fs.service import VirtualFSService
-from domain.virtual_fs.types import MkdirRequest, MoveRequest
+from domain.auth import User, get_current_active_user
+from .service import VirtualFSService
+from .types import MkdirRequest, MoveRequest
 
 router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
@@ -4,8 +4,8 @@ from typing import Any, AsyncIterator, Union
 from fastapi import HTTPException
 from fastapi.responses import Response
 
-from domain.tasks.service import TaskService
-from domain.virtual_fs.thumbnail import is_raw_filename
+from domain.tasks import TaskService
+from .thumbnail import is_raw_filename, raw_bytes_to_jpeg
 
 from .listing import VirtualFSListingMixin
@@ -82,32 +82,9 @@ class VirtualFSFileOpsMixin(VirtualFSListingMixin):
         if not rel or rel.endswith("/"):
             raise HTTPException(400, detail="Path is a directory")
         if is_raw_filename(rel):
-            import io
-
-            import rawpy
-            from PIL import Image
-
             try:
                 raw_data = await cls.read_file(path)
-                try:
-                    with rawpy.imread(io.BytesIO(raw_data)) as raw:
-                        try:
-                            thumb = raw.extract_thumb()
-                        except rawpy.LibRawNoThumbnailError:
-                            thumb = None
-
-                        if thumb is not None and thumb.format in [rawpy.ThumbFormat.JPEG, rawpy.ThumbFormat.BITMAP]:
-                            im = Image.open(io.BytesIO(thumb.data))
-                        else:
-                            rgb = raw.postprocess(use_camera_wb=False, use_auto_wb=True, output_bps=8)
-                            im = Image.fromarray(rgb)
-                except Exception as exc:
-                    print(f"rawpy processing failed: {exc}")
-                    raise exc
-
-                buf = io.BytesIO()
-                im.save(buf, "JPEG", quality=90)
-                content = buf.getvalue()
+                content = raw_bytes_to_jpeg(raw_data, filename=rel)
                 return Response(content=content, media_type="image/jpeg")
             except Exception as exc:
                 raise HTTPException(500, detail=f"RAW file processing failed: {exc}")
@@ -3,9 +3,9 @@ from typing import Any, Dict, List, Tuple
 from fastapi import HTTPException
 
 from api.response import page
-from domain.adapters.registry import runtime_registry
-from domain.ai.service import VectorDBService
-from domain.virtual_fs.thumbnail import is_image_filename, is_video_filename
+from domain.adapters import runtime_registry
+from domain.ai import FILE_COLLECTION_NAME, VECTOR_COLLECTION_NAME, VectorDBService
+from .thumbnail import is_image_filename, is_video_filename
 from models import StorageAdapter
 
 from .resolver import VirtualFSResolverMixin
@@ -161,13 +161,19 @@ class VirtualFSListingMixin(VirtualFSResolverMixin):
     @classmethod
     async def _gather_vector_index(cls, full_path: str, limit: int = 20):
         vector_db = VectorDBService()
-        try:
-            raw_results = await vector_db.search_by_path("vector_collection", full_path, max(limit * 2, 20))
-        except Exception:
-            return None
-
         matched = []
-        if raw_results:
+        had_success = False
+        fetch_limit = max(limit * 2, 20)
+        for collection_name in (VECTOR_COLLECTION_NAME, FILE_COLLECTION_NAME):
+            try:
+                raw_results = await vector_db.search_by_path(collection_name, full_path, fetch_limit)
+            except Exception:
+                continue
+
+            if not raw_results:
+                had_success = True
+                continue
+            had_success = True
             buckets = raw_results if isinstance(raw_results, list) else [raw_results]
             for bucket in buckets:
                 if not bucket:
@@ -193,6 +199,9 @@ class VirtualFSListingMixin(VirtualFSResolverMixin):
                 entry["preview_truncated"] = len(text) > preview_limit
                 matched.append(entry)
 
+        if not had_success:
+            return None
+
         if not matched:
             return {"total": 0, "entries": [], "by_type": {}, "has_more": False}
@@ -0,0 +1 @@
+__all__: list[str] = []
@@ -2,14 +2,21 @@ import base64
 import datetime as dt
 import hashlib
 import hmac
+import json
+import os
+import re
+import shutil
+import uuid
-from typing import Dict, Iterable, List, Optional, Tuple
+import xml.etree.ElementTree as ET
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple
 
+import aiofiles
 from fastapi import APIRouter, Request, Response
 from fastapi import HTTPException
 
-from domain.config.service import ConfigService
-from domain.virtual_fs.service import VirtualFSService
+from domain.audit import AuditAction, audit
+from domain.config import ConfigService
+from domain.virtual_fs import VirtualFSService
 
 
 router = APIRouter(prefix="/s3", tags=["s3"])
@@ -18,6 +25,12 @@ router = APIRouter(prefix="/s3", tags=["s3"])
 FALSEY = {"0", "false", "off", "no"}
 _XML_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
 
+_MPU_ROOT = "data/s3_multipart"
+_MPU_META_NAME = "meta.json"
+_MPU_PART_DATA_TMPL = "part-{part_number:06d}.bin"
+_MPU_PART_META_TMPL = "part-{part_number:06d}.json"
+_MPU_PART_META_RE = re.compile(r"^part-(\d{6})\.json$")
+
 
 class S3Settings(Dict[str, str]):
     bucket: str
@@ -119,42 +132,136 @@ def _sign(key: bytes, msg: str) -> bytes:
 
 async def _authorize_sigv4(request: Request, settings: S3Settings) -> Optional[Response]:
     auth = request.headers.get("authorization")
-    if not auth:
-        return _s3_error("AccessDenied", "Missing Authorization header", status=403)
     scheme = "AWS4-HMAC-SHA256"
-    if not auth.startswith(scheme + " "):
-        return _s3_error("InvalidRequest", "Signature Version 4 is required", status=400)
+    if auth:
+        if not auth.startswith(scheme + " "):
+            return _s3_error("InvalidRequest", "Signature Version 4 is required", status=400)
+
+        parts: Dict[str, str] = {}
+        for segment in auth[len(scheme) + 1 :].split(","):
+            k, _, v = segment.strip().partition("=")
+            parts[k] = v
+
+        credential = parts.get("Credential")
+        signed_headers = parts.get("SignedHeaders")
+        signature = parts.get("Signature")
+        if not credential or not signed_headers or not signature:
+            return _s3_error("InvalidRequest", "Authorization header is malformed", status=400)
+
+        cred_parts = credential.split("/")
+        if len(cred_parts) != 5 or cred_parts[-1] != "aws4_request":
+            return _s3_error("InvalidRequest", "Credential scope is invalid", status=400)
+
+        access_key, datestamp, region, service, _ = cred_parts
+        if access_key != settings["access_key"]:
+            return _s3_error(
+                "InvalidAccessKeyId",
+                "The AWS Access Key Id you provided does not exist in our records.",
+                status=403,
+            )
+        if service != "s3":
+            return _s3_error("InvalidRequest", "Only service 's3' is supported", status=400)
+        if settings.get("region") and region != settings["region"]:
+            return _s3_error("AuthorizationHeaderMalformed", f"Region '{region}' is invalid", status=400)
+
+        amz_date = request.headers.get("x-amz-date")
+        if not amz_date or not amz_date.startswith(datestamp):
+            return _s3_error("AuthorizationHeaderMalformed", "x-amz-date does not match credential scope", status=400)
+
+        payload_hash = request.headers.get("x-amz-content-sha256")
+        if not payload_hash:
+            return _s3_error("AuthorizationHeaderMalformed", "Missing x-amz-content-sha256", status=400)
+        if payload_hash.upper().startswith("STREAMING-AWS4-HMAC-SHA256"):
+            return _s3_error("NotImplemented", "Chunked uploads are not supported", status=400)
+
+        signed_header_names = [h.strip().lower() for h in signed_headers.split(";") if h.strip()]
+        headers = {k.lower(): v for k, v in request.headers.items()}
+        canonical_headers = []
+        for name in signed_header_names:
+            value = headers.get(name)
+            if value is None:
+                return _s3_error("AuthorizationHeaderMalformed", f"Signed header '{name}' missing", status=400)
+            canonical_headers.append(f"{name}:{_normalize_ws(value)}\n")
+
+        canonical_request = "\n".join(
+            [
+                request.method,
+                _canonical_uri(request.url.path),
+                _canonical_query(request.query_params.multi_items()),
+                "".join(canonical_headers),
+                ";".join(signed_header_names),
+                payload_hash,
+            ]
+        )
+
+        hashed_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
+        scope = "/".join([datestamp, region, "s3", "aws4_request"])
+        string_to_sign = "\n".join([scheme, amz_date, scope, hashed_request])
+
+        k_date = _sign(("AWS4" + settings["secret_key"]).encode("utf-8"), datestamp)
+        k_region = hmac.new(k_date, region.encode("utf-8"), hashlib.sha256).digest()
+        k_service = hmac.new(k_region, b"s3", hashlib.sha256).digest()
+        k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
+        expected = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
+        if expected != signature:
+            return _s3_error(
+                "SignatureDoesNotMatch",
+                "The request signature we calculated does not match the signature you provided.",
+                status=403,
+            )
+        return None
+
+    params = request.query_params
+    q_multi = params.multi_items()
+    q_lower = {k.lower(): v for k, v in q_multi}
+    signature = q_lower.get("x-amz-signature")
+    if not signature:
+        return _s3_error("AccessDenied", "Missing Authorization header", status=403)
+
+    algorithm = q_lower.get("x-amz-algorithm")
+    if not algorithm or algorithm != scheme:
+        return _s3_error("InvalidRequest", "Signature Version 4 is required", status=400)
 
-    parts: Dict[str, str] = {}
-    for segment in auth[len(scheme) + 1 :].split(","):
-        k, _, v = segment.strip().partition("=")
-        parts[k] = v
-
-    credential = parts.get("Credential")
-    signed_headers = parts.get("SignedHeaders")
-    signature = parts.get("Signature")
-    if not credential or not signed_headers or not signature:
-        return _s3_error("InvalidRequest", "Authorization header is malformed", status=400)
+    credential = q_lower.get("x-amz-credential")
+    signed_headers = q_lower.get("x-amz-signedheaders")
+    amz_date = q_lower.get("x-amz-date")
+    expires_raw = q_lower.get("x-amz-expires")
+    if not credential or not signed_headers or not amz_date:
+        return _s3_error("AuthorizationQueryParametersError", "Query-string authentication is malformed", status=400)
 
     cred_parts = credential.split("/")
     if len(cred_parts) != 5 or cred_parts[-1] != "aws4_request":
-        return _s3_error("InvalidRequest", "Credential scope is invalid", status=400)
+        return _s3_error("AuthorizationQueryParametersError", "Credential scope is invalid", status=400)
 
     access_key, datestamp, region, service, _ = cred_parts
     if access_key != settings["access_key"]:
-        return _s3_error("InvalidAccessKeyId", "The AWS Access Key Id you provided does not exist in our records.", status=403)
+        return _s3_error(
+            "InvalidAccessKeyId",
+            "The AWS Access Key Id you provided does not exist in our records.",
+            status=403,
+        )
     if service != "s3":
         return _s3_error("InvalidRequest", "Only service 's3' is supported", status=400)
     if settings.get("region") and region != settings["region"]:
         return _s3_error("AuthorizationHeaderMalformed", f"Region '{region}' is invalid", status=400)
 
-    amz_date = request.headers.get("x-amz-date")
-    if not amz_date or not amz_date.startswith(datestamp):
-        return _s3_error("AuthorizationHeaderMalformed", "x-amz-date does not match credential scope", status=400)
+    if not amz_date.startswith(datestamp):
+        return _s3_error("AuthorizationQueryParametersError", "X-Amz-Date does not match credential scope", status=400)
 
-    payload_hash = request.headers.get("x-amz-content-sha256")
-    if not payload_hash:
-        return _s3_error("AuthorizationHeaderMalformed", "Missing x-amz-content-sha256", status=400)
+    if expires_raw:
+        try:
+            expires = int(expires_raw)
+        except ValueError:
+            expires = 0
+        if expires > 0:
+            try:
+                signed_at = dt.datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ")
+                if dt.datetime.utcnow() > signed_at + dt.timedelta(seconds=expires):
+                    return _s3_error("AccessDenied", "Request has expired", status=403)
+            except Exception:
+                pass
+
+    payload_hash = request.headers.get("x-amz-content-sha256") or "UNSIGNED-PAYLOAD"
     if payload_hash.upper().startswith("STREAMING-AWS4-HMAC-SHA256"):
         return _s3_error("NotImplemented", "Chunked uploads are not supported", status=400)
 
@@ -164,14 +271,15 @@ async def _authorize_sigv4(request: Request, settings: S3Settings) -> Optional[R
     for name in signed_header_names:
         value = headers.get(name)
         if value is None:
-            return _s3_error("AuthorizationHeaderMalformed", f"Signed header '{name}' missing", status=400)
+            return _s3_error("AuthorizationQueryParametersError", f"Signed header '{name}' missing", status=400)
         canonical_headers.append(f"{name}:{_normalize_ws(value)}\n")
 
+    canonical_query_items = [(k, v) for k, v in q_multi if k.lower() != "x-amz-signature"]
     canonical_request = "\n".join(
         [
             request.method,
            _canonical_uri(request.url.path),
-            _canonical_query(request.query_params.multi_items()),
+            _canonical_query(canonical_query_items),
             "".join(canonical_headers),
             ";".join(signed_header_names),
             payload_hash,
@@ -188,7 +296,11 @@ async def _authorize_sigv4(request: Request, settings: S3Settings) -> Optional[R
     k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()
     expected = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
     if expected != signature:
-        return _s3_error("SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided.", status=403)
+        return _s3_error(
+            "SignatureDoesNotMatch",
+            "The request signature we calculated does not match the signature you provided.",
+            status=403,
+        )
     return None
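The verifier above is the textbook AWS Signature Version 4 flow: hash the canonical request, build the string-to-sign, then derive a signing key through four chained HMACs and compare digests. A standalone sketch of the key derivation with placeholder credentials — the server checks the resulting hex digest against the `Signature` (header auth) or `X-Amz-Signature` (presigned URL) the client sent:

```python
import hashlib
import hmac


def sign(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


# Placeholder values; the real ones come from S3Settings.
secret_key = "foxel-secret"
datestamp, region = "20240101", "us-east-1"

# The same four-step chain the handler uses to derive k_signing.
k_date = sign(("AWS4" + secret_key).encode("utf-8"), datestamp)
k_region = hmac.new(k_date, region.encode("utf-8"), hashlib.sha256).digest()
k_service = hmac.new(k_region, b"s3", hashlib.sha256).digest()
k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()

string_to_sign = ("AWS4-HMAC-SHA256\n20240101T000000Z\n"
                  "20240101/us-east-1/s3/aws4_request\n<sha256-of-canonical-request>")
signature = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
```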
@@ -313,7 +425,382 @@ def _resource_path(bucket: str, key: Optional[str] = None) -> str:
     return f"/s3/{bucket}"
 
 
+def _safe_upload_id(upload_id: Optional[str]) -> Optional[str]:
+    if not upload_id:
+        return None
+    value = upload_id.strip()
+    if not value:
+        return None
+    if "/" in value or "\\" in value:
+        return None
+    return value
+
+
+def _mpu_dir(upload_id: str) -> str:
+    return os.path.join(_MPU_ROOT, upload_id)
+
+
+def _mpu_meta_path(upload_id: str) -> str:
+    return os.path.join(_mpu_dir(upload_id), _MPU_META_NAME)
+
+
+def _mpu_part_data_path(upload_id: str, part_number: int) -> str:
+    return os.path.join(_mpu_dir(upload_id), _MPU_PART_DATA_TMPL.format(part_number=part_number))
+
+
+def _mpu_part_meta_path(upload_id: str, part_number: int) -> str:
+    return os.path.join(_mpu_dir(upload_id), _MPU_PART_META_TMPL.format(part_number=part_number))
+
+
+async def _read_json(path: str) -> Optional[Dict[str, Any]]:
+    try:
+        async with aiofiles.open(path, "r", encoding="utf-8") as f:
+            raw = await f.read()
+        data = json.loads(raw or "{}")
+        return data if isinstance(data, dict) else None
+    except FileNotFoundError:
+        return None
+    except Exception:
+        return None
+
+
+async def _write_json(path: str, data: Dict[str, Any]) -> None:
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    async with aiofiles.open(path, "w", encoding="utf-8") as f:
+        await f.write(json.dumps(data, ensure_ascii=False))
+
+
+async def _load_mpu_meta(bucket: str, key: str, upload_id: Optional[str]) -> Tuple[Optional[Dict[str, Any]], Optional[Response]]:
+    safe_id = _safe_upload_id(upload_id)
+    if not safe_id:
+        return None, _s3_error(
+            "NoSuchUpload",
+            "The specified upload does not exist.",
+            _resource_path(bucket, key),
+            status=404,
+        )
+    meta = await _read_json(_mpu_meta_path(safe_id))
+    if not meta or meta.get("bucket") != bucket or meta.get("key") != key:
+        return None, _s3_error(
+            "NoSuchUpload",
+            "The specified upload does not exist.",
+            _resource_path(bucket, key),
+            status=404,
+        )
+    return meta, None
+
+
+def _parse_int(value: Optional[str], default: int) -> int:
+    if value is None:
+        return default
+    try:
+        return int(value)
+    except ValueError:
+        return default
+async def _create_multipart_upload(request: Request, settings: S3Settings, bucket: str, key: str) -> Response:
+    os.makedirs(_MPU_ROOT, exist_ok=True)
+    upload_id = uuid.uuid4().hex
+    dir_path = _mpu_dir(upload_id)
+    while True:
+        try:
+            os.makedirs(dir_path, exist_ok=False)
+            break
+        except FileExistsError:
+            upload_id = uuid.uuid4().hex
+            dir_path = _mpu_dir(upload_id)
+
+    meta = {
+        "bucket": bucket,
+        "key": key,
+        "virtual_path": _virtual_path(settings, key),
+        "initiated": _now_iso(),
+    }
+    await _write_json(_mpu_meta_path(upload_id), meta)
+
+    _, headers = _meta_headers()
+    xml = (
+        f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
+        f"<CreateMultipartUploadResult xmlns=\"{_XML_NS}\">"
+        f"<Bucket>{bucket}</Bucket>"
+        f"<Key>{key}</Key>"
+        f"<UploadId>{upload_id}</UploadId>"
+        f"</CreateMultipartUploadResult>"
+    )
+    headers.update({"Content-Type": "application/xml"})
+    return Response(content=xml, media_type="application/xml", headers=headers)
+
+
+async def _upload_part(request: Request, bucket: str, key: str, upload_id: Optional[str], part_number_raw: Optional[str]) -> Response:
+    part_number = _parse_int(part_number_raw, 0)
+    if part_number <= 0:
+        return _s3_error("InvalidArgument", "partNumber is invalid", _resource_path(bucket, key), status=400)
+
+    meta, err = await _load_mpu_meta(bucket, key, upload_id)
+    if err:
+        return err
+    assert meta
+    safe_id = _safe_upload_id(upload_id)
+    assert safe_id
+
+    part_path = _mpu_part_data_path(safe_id, part_number)
+    tmp_path = part_path + ".tmp"
+    md5 = hashlib.md5()
+    size = 0
+    async with aiofiles.open(tmp_path, "wb") as f:
+        async for chunk in request.stream():
+            if not chunk:
+                continue
+            await f.write(chunk)
+            md5.update(chunk)
+            size += len(chunk)
+
+    etag = '"' + md5.hexdigest() + '"'
+    os.replace(tmp_path, part_path)
+    await _write_json(
+        _mpu_part_meta_path(safe_id, part_number),
+        {"PartNumber": part_number, "ETag": etag, "Size": size, "LastModified": _now_iso()},
+    )
+
+    _, headers = _meta_headers()
+    headers.update({"ETag": etag, "Content-Length": "0"})
+    return Response(status_code=200, headers=headers)
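Together with the routing changes further down, these handlers implement enough of the multipart protocol for a stock S3 client to drive them end to end. A hedged sketch with boto3 — the endpoint, bucket, and credentials are made-up placeholders:

```python
import boto3

s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:8088/s3",  # hypothetical Foxel deployment
    aws_access_key_id="foxel-access",
    aws_secret_access_key="foxel-secret",
    region_name="us-east-1",
)

# POST /{bucket}/{key}?uploads              -> _create_multipart_upload
up = s3.create_multipart_upload(Bucket="media", Key="big.bin")
upload_id = up["UploadId"]

# PUT /{bucket}/{key}?uploadId&partNumber=n -> _upload_part, once per part
part = s3.upload_part(Bucket="media", Key="big.bin", UploadId=upload_id,
                      PartNumber=1, Body=b"x" * (5 * 1024 * 1024))

# POST /{bucket}/{key}?uploadId             -> _complete_multipart_upload
s3.complete_multipart_upload(
    Bucket="media", Key="big.bin", UploadId=upload_id,
    MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": part["ETag"]}]},
)
```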
+async def _list_parts(request: Request, settings: S3Settings, bucket: str, key: str, upload_id: Optional[str]) -> Response:
+    meta, err = await _load_mpu_meta(bucket, key, upload_id)
+    if err:
+        return err
+    assert meta
+    safe_id = _safe_upload_id(upload_id)
+    assert safe_id
+
+    dir_path = _mpu_dir(safe_id)
+    part_metas: List[Dict[str, Any]] = []
+    try:
+        filenames = os.listdir(dir_path)
+    except FileNotFoundError:
+        filenames = []
+
+    for name in filenames:
+        m = _MPU_PART_META_RE.match(name)
+        if not m:
+            continue
+        pn = int(m.group(1))
+        info = await _read_json(os.path.join(dir_path, name))
+        if not info:
+            continue
+        info.setdefault("PartNumber", pn)
+        part_metas.append(info)
+
+    part_metas.sort(key=lambda item: int(item.get("PartNumber") or 0))
+    max_parts = max(1, min(1000, _parse_int(request.query_params.get("max-parts"), 1000)))
+    marker = max(0, _parse_int(request.query_params.get("part-number-marker"), 0))
+    filtered = [p for p in part_metas if int(p.get("PartNumber") or 0) > marker]
+    is_truncated = len(filtered) > max_parts
+    shown = filtered[:max_parts]
+    next_marker = int(shown[-1]["PartNumber"]) if is_truncated and shown else 0
+
+    _, headers = _meta_headers()
+    body = [f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListPartsResult xmlns=\"{_XML_NS}\">"]
+    body.append(f"<Bucket>{bucket}</Bucket>")
+    body.append(f"<Key>{key}</Key>")
+    body.append(f"<UploadId>{safe_id}</UploadId>")
+    body.append(
+        f"<Initiator><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Initiator>"
+    )
+    body.append(
+        f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
+    )
+    body.append("<StorageClass>STANDARD</StorageClass>")
+    body.append(f"<PartNumberMarker>{marker}</PartNumberMarker>")
+    body.append(f"<NextPartNumberMarker>{next_marker}</NextPartNumberMarker>")
+    body.append(f"<MaxParts>{max_parts}</MaxParts>")
+    body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
+    for part in shown:
+        pn = int(part.get("PartNumber") or 0)
+        etag = part.get("ETag") or ""
+        size = int(part.get("Size") or 0)
+        last_modified = part.get("LastModified") or _now_iso()
+        body.append(
+            f"<Part><PartNumber>{pn}</PartNumber><LastModified>{last_modified}</LastModified><ETag>{etag}</ETag><Size>{size}</Size></Part>"
+        )
+    body.append("</ListPartsResult>")
+    xml = "".join(body)
+    headers.update({"Content-Type": "application/xml"})
+    return Response(content=xml, media_type="application/xml", headers=headers)
+
+
+async def _abort_multipart_upload(bucket: str, key: str, upload_id: Optional[str]) -> Response:
+    _, err = await _load_mpu_meta(bucket, key, upload_id)
+    if err:
+        return err
+    safe_id = _safe_upload_id(upload_id)
+    assert safe_id
+    shutil.rmtree(_mpu_dir(safe_id), ignore_errors=True)
+    _, headers = _meta_headers()
+    return Response(status_code=204, headers=headers)
+def _parse_complete_parts(body_bytes: bytes) -> List[Tuple[int, str]]:
+    if not body_bytes:
+        return []
+    root = ET.fromstring(body_bytes)
+    parts: List[Tuple[int, str]] = []
+    for part_el in root.findall(".//{*}Part"):
+        pn_el = part_el.find("{*}PartNumber")
+        etag_el = part_el.find("{*}ETag")
+        if pn_el is None or pn_el.text is None:
+            continue
+        pn = _parse_int(pn_el.text.strip(), 0)
+        if pn <= 0:
+            continue
+        etag = (etag_el.text or "").strip() if etag_el is not None else ""
+        parts.append((pn, etag))
+    parts.sort(key=lambda item: item[0])
+    return parts
+
+
+async def _complete_multipart_upload(request: Request, settings: S3Settings, bucket: str, key: str, upload_id: Optional[str]) -> Response:
+    meta, err = await _load_mpu_meta(bucket, key, upload_id)
+    if err:
+        return err
+    assert meta
+    safe_id = _safe_upload_id(upload_id)
+    assert safe_id
+
+    try:
+        body_bytes = await request.body()
+    except Exception:
+        body_bytes = b""
+
+    try:
+        parts_req = _parse_complete_parts(body_bytes)
+    except Exception:
+        return _s3_error("MalformedXML", "The XML you provided was not well-formed.", _resource_path(bucket, key), status=400)
+
+    if not parts_req:
+        return _s3_error("MalformedXML", "CompleteMultipartUpload parts missing.", _resource_path(bucket, key), status=400)
+
+    part_metas: List[Dict[str, Any]] = []
+    for pn, _etag in parts_req:
+        info = await _read_json(_mpu_part_meta_path(safe_id, pn))
+        if not info:
+            return _s3_error("InvalidPart", "One or more of the specified parts could not be found.", _resource_path(bucket, key), status=400)
+        info.setdefault("PartNumber", pn)
+        part_metas.append(info)
+
+    async def merged_iter() -> AsyncIterator[bytes]:
+        for info in part_metas:
+            pn = int(info.get("PartNumber") or 0)
+            part_path = _mpu_part_data_path(safe_id, pn)
+            async with aiofiles.open(part_path, "rb") as f:
+                while True:
+                    chunk = await f.read(1024 * 1024)
+                    if not chunk:
+                        break
+                    yield chunk
+
+    await VirtualFSService.write_file_stream(meta.get("virtual_path") or _virtual_path(settings, key), merged_iter(), overwrite=True)
+
+    etag = ""
+    if len(part_metas) == 1:
+        etag = str(part_metas[0].get("ETag") or "")
+    else:
+        md5_bytes = bytearray()
+        for info in part_metas:
+            raw = str(info.get("ETag") or "").strip().strip('"')
+            try:
+                md5_bytes.extend(bytes.fromhex(raw))
+            except ValueError:
+                pass
+        digest = hashlib.md5(bytes(md5_bytes)).hexdigest() if md5_bytes else hashlib.md5(b"").hexdigest()
+        etag = '"' + f"{digest}-{len(part_metas)}" + '"'
+
+    shutil.rmtree(_mpu_dir(safe_id), ignore_errors=True)
+
+    _, headers = _meta_headers()
+    headers.update({"Content-Type": "application/xml", "ETag": etag})
+    location = str(request.url.replace(query=""))
+    xml = (
+        f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
+        f"<CompleteMultipartUploadResult xmlns=\"{_XML_NS}\">"
+        f"<Location>{location}</Location>"
+        f"<Bucket>{bucket}</Bucket>"
+        f"<Key>{key}</Key>"
+        f"<ETag>{etag}</ETag>"
+        f"</CompleteMultipartUploadResult>"
+    )
+    return Response(content=xml, media_type="application/xml", headers=headers)
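The ETag computed above follows the common S3 multipart convention: the MD5 of the concatenated binary per-part digests, suffixed with the part count — which is also why multipart ETags are not the MD5 of the whole object. A standalone illustration of the rule:

```python
import hashlib

part_payloads = [b"a" * (5 * 1024 * 1024), b"b" * 1024]

# MD5 each part, concatenate the raw digests, MD5 again, append "-<count>".
part_digests = [hashlib.md5(p).digest() for p in part_payloads]
combined = hashlib.md5(b"".join(part_digests)).hexdigest()
etag = f'"{combined}-{len(part_payloads)}"'  # e.g. '"<32 hex chars>-2"'
```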
+async def _list_multipart_uploads(request: Request, settings: S3Settings, bucket: str) -> Response:
+    os.makedirs(_MPU_ROOT, exist_ok=True)
+    prefix = request.query_params.get("prefix") or ""
+    max_uploads = max(1, min(1000, _parse_int(request.query_params.get("max-uploads"), 1000)))
+    key_marker = request.query_params.get("key-marker") or ""
+    upload_id_marker = request.query_params.get("upload-id-marker") or ""
+
+    uploads: List[Tuple[str, str, str]] = []
+    try:
+        ids = os.listdir(_MPU_ROOT)
+    except FileNotFoundError:
+        ids = []
+
+    for uid in ids:
+        safe_id = _safe_upload_id(uid)
+        if not safe_id:
+            continue
+        meta = await _read_json(_mpu_meta_path(safe_id))
+        if not meta:
+            continue
+        if meta.get("bucket") != bucket:
+            continue
+        key = str(meta.get("key") or "")
+        if prefix and not key.startswith(prefix):
+            continue
+        initiated = str(meta.get("initiated") or _now_iso())
+        uploads.append((key, safe_id, initiated))
+
+    uploads.sort(key=lambda item: (item[0], item[1]))
+    if key_marker:
+        uploads = [
+            it
+            for it in uploads
+            if (it[0] > key_marker) or (it[0] == key_marker and it[1] > upload_id_marker)
+        ]
+
+    is_truncated = len(uploads) > max_uploads
+    shown = uploads[:max_uploads]
+    next_key_marker = shown[-1][0] if is_truncated and shown else ""
+    next_upload_id_marker = shown[-1][1] if is_truncated and shown else ""
+
+    _, headers = _meta_headers()
+    body = [f"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", f"<ListMultipartUploadsResult xmlns=\"{_XML_NS}\">"]
+    body.append(f"<Bucket>{bucket}</Bucket>")
+    body.append(f"<Prefix>{prefix}</Prefix>")
+    body.append(f"<KeyMarker>{key_marker}</KeyMarker>")
+    body.append(f"<UploadIdMarker>{upload_id_marker}</UploadIdMarker>")
+    body.append(f"<NextKeyMarker>{next_key_marker}</NextKeyMarker>")
+    body.append(f"<NextUploadIdMarker>{next_upload_id_marker}</NextUploadIdMarker>")
+    body.append(f"<MaxUploads>{max_uploads}</MaxUploads>")
+    body.append(f"<IsTruncated>{str(is_truncated).lower()}</IsTruncated>")
+    for key, uid, initiated in shown:
+        body.append(
+            f"<Upload><Key>{key}</Key><UploadId>{uid}</UploadId>"
+            f"<Initiator><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Initiator>"
+            f"<Owner><ID>{settings['access_key']}</ID><DisplayName>Foxel</DisplayName></Owner>"
+            f"<StorageClass>STANDARD</StorageClass><Initiated>{initiated}</Initiated></Upload>"
+        )
+    body.append("</ListMultipartUploadsResult>")
+    xml = "".join(body)
+    headers.update({"Content-Type": "application/xml"})
+    return Response(content=xml, media_type="application/xml", headers=headers)
+
+
 @router.get("")
+@audit(action=AuditAction.READ, description="S3: 列出桶")
 async def list_buckets(request: Request):
     if (resp := await _ensure_enabled()) is not None:
         return resp
@@ -336,6 +823,7 @@ async def list_buckets(request: Request):
 
 
 @router.get("/{bucket}")
+@audit(action=AuditAction.READ, description="S3: 列出对象")
 async def list_objects(request: Request, bucket: str):
     if (resp := await _ensure_enabled()) is not None:
         return resp
@@ -349,6 +837,8 @@ async def list_objects(request: Request, bucket: str):
         return auth
 
     params = request.query_params
+    if "uploads" in params:
+        return await _list_multipart_uploads(request, settings, bucket)
     if params.get("list-type", "2") != "2":
         return _s3_error("InvalidArgument", "Only ListObjectsV2 (list-type=2) is supported.", _resource_path(bucket), status=400)
@@ -476,12 +966,18 @@ async def _stat_object(settings: S3Settings, key: str) -> Tuple[Optional[Dict],
 
 
 @router.api_route("/{bucket}/{object_path:path}", methods=["GET", "HEAD"])
+@audit(action=AuditAction.DOWNLOAD, description="S3: 获取对象")
 async def object_get_head(request: Request, bucket: str, object_path: str):
     settings, error = await _ensure_bucket_and_auth(request, bucket)
     if error:
         return error
     assert settings
     key = object_path.lstrip("/")
+    upload_id = request.query_params.get("uploadId") or request.query_params.get("uploadid")
+    if upload_id and request.method == "GET":
+        return await _list_parts(request, settings, bucket, key, upload_id)
+    if upload_id and request.method == "HEAD":
+        return _s3_error("MethodNotAllowed", "Method Not Allowed", _resource_path(bucket, key), status=405)
     meta, err = await _stat_object(settings, key)
     if err:
         return err
@@ -500,12 +996,17 @@ async def object_get_head(request: Request, bucket: str, object_path: str):
 
 
 @router.put("/{bucket}/{object_path:path}")
+@audit(action=AuditAction.UPLOAD, description="S3: 上传对象")
 async def put_object(request: Request, bucket: str, object_path: str):
     settings, error = await _ensure_bucket_and_auth(request, bucket)
     if error:
         return error
     assert settings
     key = object_path.lstrip("/")
+    upload_id = request.query_params.get("uploadId") or request.query_params.get("uploadid")
+    part_number = request.query_params.get("partNumber") or request.query_params.get("partnumber")
+    if upload_id and part_number:
+        return await _upload_part(request, bucket, key, upload_id, part_number)
     await VirtualFSService.write_file_stream(_virtual_path(settings, key), request.stream(), overwrite=True)
     meta, err = await _stat_object(settings, key)
     if err:
@@ -519,13 +1020,35 @@ async def put_object(request: Request, bucket: str, object_path: str):
     return Response(status_code=200, headers=headers)
 
 
+@router.post("/{bucket}/{object_path:path}")
+@audit(action=AuditAction.UPLOAD, description="S3: Multipart 上传")
+async def post_object(request: Request, bucket: str, object_path: str):
+    settings, error = await _ensure_bucket_and_auth(request, bucket)
+    if error:
+        return error
+    assert settings
+    key = object_path.lstrip("/")
+
+    params = request.query_params
+    upload_id = params.get("uploadId") or params.get("uploadid")
+    if "uploads" in params:
+        return await _create_multipart_upload(request, settings, bucket, key)
+    if upload_id:
+        return await _complete_multipart_upload(request, settings, bucket, key, upload_id)
+    return _s3_error("InvalidRequest", "Unsupported POST operation.", _resource_path(bucket, key), status=400)
+
+
 @router.delete("/{bucket}/{object_path:path}")
+@audit(action=AuditAction.DELETE, description="S3: 删除对象")
 async def delete_object(request: Request, bucket: str, object_path: str):
     settings, error = await _ensure_bucket_and_auth(request, bucket)
     if error:
         return error
     assert settings
     key = object_path.lstrip("/")
+    upload_id = request.query_params.get("uploadId") or request.query_params.get("uploadid")
+    if upload_id:
+        return await _abort_multipart_upload(bucket, key, upload_id)
     try:
         await VirtualFSService.delete_path(_virtual_path(settings, key))
     except HTTPException as exc:
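All multipart operations are dispatched off query parameters on the same object routes: `GET ?uploadId` lists parts, `PUT ?uploadId&partNumber` stores a part, `POST ?uploads` / `POST ?uploadId` create and complete, and `DELETE ?uploadId` aborts. A small hedged housekeeping sketch (client setup as in the earlier multipart example; all names are placeholders):

```python
import boto3

s3 = boto3.client("s3", endpoint_url="http://localhost:8088/s3",
                  aws_access_key_id="foxel-access",
                  aws_secret_access_key="foxel-secret")

# GET /{bucket}?uploads               -> _list_multipart_uploads
pending = s3.list_multipart_uploads(Bucket="media").get("Uploads", [])

# DELETE /{bucket}/{key}?uploadId=... -> _abort_multipart_upload
for up in pending:
    s3.abort_multipart_upload(Bucket="media", Key=up["Key"],
                              UploadId=up["UploadId"])
```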
@@ -8,10 +8,10 @@ from typing import Optional
 from fastapi import APIRouter, Request, Response, HTTPException, Depends
 import xml.etree.ElementTree as ET
 
-from domain.auth.service import AuthService
-from domain.auth.types import User, UserInDB
-from domain.virtual_fs.service import VirtualFSService
-from domain.config.service import ConfigService
+from domain.audit import AuditAction, audit
+from domain.auth import AuthService, User, UserInDB
+from domain.config import ConfigService
+from domain.virtual_fs import VirtualFSService
 
 
 _WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
@@ -141,11 +141,13 @@ def _normalize_fs_path(path: str) -> str:
 
 
 @router.options("/{path:path}")
-async def options_root(path: str = "", _enabled: None = Depends(_ensure_webdav_enabled)):
+@audit(action=AuditAction.READ, description="WebDAV: OPTIONS", user_kw="user")
+async def options_root(_request: Request, path: str = "", _enabled: None = Depends(_ensure_webdav_enabled)):
     return Response(status_code=200, headers=_dav_headers())
 
 
 @router.api_route("/{path:path}", methods=["PROPFIND"])
+@audit(action=AuditAction.READ, description="WebDAV: PROPFIND", user_kw="user")
 async def propfind(
     request: Request,
     path: str,
@@ -169,12 +171,32 @@ async def propfind(
             ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
             responses.append(_build_prop_response(full_path, name, is_dir, size, mtime, ctype))
     except FileNotFoundError:
-        raise HTTPException(404, detail="Not found")
+        st = None
+    except HTTPException as e:
+        if e.status_code != 404:
+            raise
+        st = None
+
+    if st is None:
+        is_mount_root = False
+        try:
+            _, rel = await VirtualFSService.resolve_adapter_by_path(full_path)
+            is_mount_root = rel == ""
+        except HTTPException:
+            is_mount_root = False
+
+        if not is_mount_root and full_path != "/":
+            listing_probe = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1)
+            if not (listing_probe.get("items") or []):
+                raise HTTPException(404, detail="Not found")
+
+        name = "/" if full_path == "/" else (full_path.rstrip("/").rsplit("/", 1)[-1] or "/")
+        responses.append(_build_prop_response(full_path, name, True, None, 0, None))
 
     if depth in ("1", "infinity"):
         try:
             listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
-            for ent in listing["items"]:
+            for ent in (listing.get("items") or []):
                 is_dir = bool(ent.get("is_dir"))
                 name = ent.get("name")
                 child_path = full_path.rstrip("/") + "/" + name
@@ -193,6 +215,7 @@ async def propfind(
 
 
 @router.get("/{path:path}")
+@audit(action=AuditAction.DOWNLOAD, description="WebDAV: GET", user_kw="user")
 async def dav_get(
     path: str,
     request: Request,
@@ -205,8 +228,10 @@ async def dav_get(
 
 
 @router.head("/{path:path}")
+@audit(action=AuditAction.READ, description="WebDAV: HEAD", user_kw="user")
 async def dav_head(
     path: str,
+    _request: Request,
     _enabled: None = Depends(_ensure_webdav_enabled),
     user: User = Depends(_get_basic_user),
 ):
@@ -231,6 +256,7 @@ async def dav_head(
 
 
 @router.api_route("/{path:path}", methods=["PUT"])
+@audit(action=AuditAction.UPLOAD, description="WebDAV: PUT", user_kw="user")
 async def dav_put(
     path: str,
     request: Request,
@@ -247,8 +273,10 @@ async def dav_put(
 
 
 @router.api_route("/{path:path}", methods=["DELETE"])
+@audit(action=AuditAction.DELETE, description="WebDAV: DELETE", user_kw="user")
 async def dav_delete(
     path: str,
+    _request: Request,
     _enabled: None = Depends(_ensure_webdav_enabled),
     user: User = Depends(_get_basic_user),
 ):
@@ -258,8 +286,10 @@ async def dav_delete(
 
 
 @router.api_route("/{path:path}", methods=["MKCOL"])
+@audit(action=AuditAction.CREATE, description="WebDAV: MKCOL", user_kw="user")
 async def dav_mkcol(
     path: str,
+    _request: Request,
     _enabled: None = Depends(_ensure_webdav_enabled),
     user: User = Depends(_get_basic_user),
 ):
@@ -281,7 +311,13 @@ def _parse_destination(dest: str) -> str:
 
 
 @router.api_route("/{path:path}", methods=["MOVE"])
-async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_user)):
+@audit(action=AuditAction.UPDATE, description="WebDAV: MOVE", user_kw="user")
+async def dav_move(
+    path: str,
+    request: Request,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_src = _normalize_fs_path(path)
     dest_header = request.headers.get("Destination")
     dst = _parse_destination(dest_header or "")
@@ -291,7 +327,13 @@ async def dav_move(path: str, request: Request, user: User = Depends(_get_basic_
 
 
 @router.api_route("/{path:path}", methods=["COPY"])
-async def dav_copy(path: str, request: Request, user: User = Depends(_get_basic_user)):
+@audit(action=AuditAction.CREATE, description="WebDAV: COPY", user_kw="user")
+async def dav_copy(
+    path: str,
+    request: Request,
+    _enabled: None = Depends(_ensure_webdav_enabled),
+    user: User = Depends(_get_basic_user),
+):
     full_src = _normalize_fs_path(path)
     dest_header = request.headers.get("Destination")
     dst = _parse_destination(dest_header or "")
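With the `@audit` decorators attached, every DAV verb is now recorded, but the handlers still speak plain WebDAV, so any standard client works. A quick probe of the PROPFIND fallback logic above, assuming a local instance and made-up credentials:

```python
import httpx

# "/webdav" matches SPA_EXCLUDE_PREFIXES in main.py; host and credentials
# are placeholders for illustration.
resp = httpx.request(
    "PROPFIND",
    "http://localhost:8088/webdav/",
    headers={"Depth": "1"},
    auth=("admin", "secret"),
)
print(resp.status_code)   # 207 Multi-Status when the collection resolves
print(resp.text[:200])    # <multistatus> XML built by _build_prop_response
```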
@@ -16,7 +16,7 @@ class VirtualFSProcessingMixin(VirtualFSTransferMixin):
         save_to: str | None = None,
         overwrite: bool = False,
     ) -> Any:
-        from domain.processors.service import get_processor
+        from domain.processors import get_processor
 
         processor = get_processor(processor_type)
         if not processor:
@@ -3,7 +3,7 @@ from typing import Tuple
 from fastapi import HTTPException
 from fastapi.responses import Response
 
-from domain.adapters.registry import runtime_registry
+from domain.adapters import runtime_registry
 from models import StorageAdapter
 
 from .common import VirtualFSCommonMixin
@@ -4,8 +4,14 @@ import re
 from fastapi import HTTPException, UploadFile
 from fastapi.responses import Response
 
-from domain.config.service import ConfigService
-from domain.virtual_fs.thumbnail import get_or_create_thumb, is_image_filename, is_raw_filename, is_video_filename
+from domain.config import ConfigService
+from .thumbnail import (
+    get_or_create_thumb,
+    is_image_filename,
+    is_raw_filename,
+    is_video_filename,
+    raw_bytes_to_jpeg,
+)
 
 from .temp_link import VirtualFSTempLinkMixin
 
@@ -16,19 +22,9 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
         full_path = cls._normalize_path(full_path)
 
         if is_raw_filename(full_path):
-            import io
-
-            import rawpy
-            from PIL import Image
-
             try:
                 raw_data = await cls.read_file(full_path)
-                with rawpy.imread(io.BytesIO(raw_data)) as raw:
-                    rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
-                im = Image.fromarray(rgb)
-                buf = io.BytesIO()
-                im.save(buf, "JPEG", quality=90)
-                content = buf.getvalue()
+                content = raw_bytes_to_jpeg(raw_data, filename=full_path)
                 return Response(content=content, media_type="image/jpeg")
             except FileNotFoundError:
                 raise HTTPException(404, detail="File not found")
@@ -0,0 +1,3 @@
+from .search_service import VirtualFSSearchService
+
+__all__ = ["VirtualFSSearchService"]
@@ -1,8 +1,8 @@
 from fastapi import APIRouter, Depends, Query
 
-from domain.auth.service import get_current_active_user
-from domain.auth.types import User
-from domain.virtual_fs.search.search_service import VirtualFSSearchService
+from api.response import success
+from domain.auth import User, get_current_active_user
+from .search_service import VirtualFSSearchService
 
 router = APIRouter(prefix="/api/fs/search", tags=["search"])
@@ -17,10 +17,11 @@ async def search_files(
     user: User = Depends(get_current_active_user),
 ):
     if not q.strip():
-        return {"items": [], "query": q}
+        return success({"items": [], "query": q, "mode": mode})
 
     top_k = max(top_k, 1)
     page = max(page, 1)
     page_size = max(min(page_size, 100), 1)
 
-    return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
+    data = await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
+    return success(data)
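The search router now returns every payload through the `success(...)` envelope, including the empty-query short-circuit, so clients can unwrap responses uniformly. A hedged client call — host and token are placeholders, and the parameter names match the handler above:

```python
import httpx

resp = httpx.get(
    "http://localhost:8088/api/fs/search",
    params={"q": "sunset", "mode": "vector", "top_k": 10,
            "page": 1, "page_size": 20},
    headers={"Authorization": "Bearer <token>"},
)
payload = resp.json()  # success(...) envelope; exact wrapper fields live in api.response
items = payload.get("data", {}).get("items", [])  # assuming results nest under "data"
```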
@@ -1,8 +1,7 @@
 from typing import Any, Dict, List, Tuple
 
-from domain.virtual_fs.types import SearchResultItem
-from domain.ai.inference import get_text_embedding
-from domain.ai.service import VectorDBService
+from domain.ai import FILE_COLLECTION_NAME, VECTOR_COLLECTION_NAME, VectorDBService, get_text_embedding
+from ..types import SearchResultItem
 
 
 def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
@@ -53,7 +52,7 @@ async def _vector_search(query: str, top_k: int) -> List[SearchResultItem]:
         return []
 
     try:
-        raw_results = await vector_db.search_vectors("vector_collection", embedding, max(top_k, 10))
+        raw_results = await vector_db.search_vectors(VECTOR_COLLECTION_NAME, embedding, max(top_k, 10))
     except Exception:
         return []
 
@@ -68,12 +67,15 @@ async def _filename_search(query: str, page: int, page_size: int) -> Tuple[List[
     vector_db = VectorDBService()
     limit = max(page * page_size + 1, page_size * (page + 2))
     limit = min(limit, 2000)
-    try:
-        raw_results = await vector_db.search_by_path("vector_collection", query, limit)
-    except Exception:
-        return [], False
+    records: List[Dict[str, Any]] = []
+    for collection_name in (FILE_COLLECTION_NAME, VECTOR_COLLECTION_NAME):
+        try:
+            raw_results = await vector_db.search_by_path(collection_name, query, limit)
+        except Exception:
+            continue
+        if raw_results:
+            records.extend(raw_results[0] or [])
 
-    records = raw_results[0] if raw_results else []
     deduped: List[SearchResultItem] = []
     seen_paths: set[str] = set()
     for record in records or []:
@@ -5,7 +5,7 @@ import time
 
 from fastapi import HTTPException
 
-from domain.config.service import ConfigService
+from domain.config import ConfigService
 
 from .processing import VirtualFSProcessingMixin
@@ -2,10 +2,13 @@ import asyncio
 import inspect
 import io
 import hashlib
+import subprocess
+import tempfile
+from contextlib import suppress
 from pathlib import Path
 from typing import Tuple
 
+from PIL import Image
 from fastapi import HTTPException
 
 ALLOWED_EXT = {"jpg", "jpeg", "png", "webp", "gif", "bmp",
@@ -58,7 +61,6 @@ def _ensure_cache_dir(p: Path):
 
 
 def _image_to_webp(im, w: int, h: int, fit: str) -> Tuple[bytes, str]:
-    from PIL import Image
     if im.mode not in ("RGB", "RGBA"):
         im = im.convert("RGBA" if im.mode in ("P", "LA") else "RGB")
     if fit == 'cover':
@@ -81,30 +83,91 @@ def _image_to_webp(im, w: int, h: int, fit: str) -> Tuple[bytes, str]:
     return buf.getvalue(), 'image/webp'
 
 
-def generate_thumb(data: bytes, w: int, h: int, fit: str, is_raw: bool = False) -> Tuple[bytes, str]:
-    from PIL import Image
-    if is_raw:
+def _load_image_with_pillow(data: bytes):
+    im = Image.open(io.BytesIO(data))
+    im.load()
+    return im
+
+
+def _load_raw_with_ffmpeg(data: bytes, filename: str | None) -> "Image.Image":
+    src_path: str | None = None
+    dst_path: str | None = None
+    try:
+        with tempfile.NamedTemporaryFile(suffix=Path(filename or "").suffix or ".raw", delete=False) as src_tmp:
+            src_tmp.write(data)
+            src_path = src_tmp.name
+        dst_tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+        dst_path = dst_tmp.name
+        dst_tmp.close()
+
+        cmd = [
+            "ffmpeg",
+            "-y",
+            "-hide_banner",
+            "-loglevel", "error",
+            "-i", src_path,
+            "-frames:v", "1",
+            dst_path,
+        ]
         try:
-            import rawpy
-            with rawpy.imread(io.BytesIO(data)) as raw:
-                try:
-                    thumb = raw.extract_thumb()
-                except rawpy.LibRawNoThumbnailError:
-                    thumb = None
+            result = subprocess.run(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                check=True,
+            )
+        except FileNotFoundError as e:
+            raise RuntimeError("未找到 ffmpeg,可执行文件需要在 PATH 中") from e
+        except subprocess.CalledProcessError as e:
+            stderr = (e.stderr or b"").decode().strip()
+            stdout = (e.stdout or b"").decode().strip()
+            message = stderr or stdout or "ffmpeg 转换 RAW 失败"
+            raise RuntimeError(message) from e
 
-            if thumb is not None and thumb.format in [rawpy.ThumbFormat.JPEG, rawpy.ThumbFormat.BITMAP]:
-                im = Image.open(io.BytesIO(thumb.data))
-            else:
-                rgb = raw.postprocess(
-                    use_camera_wb=False, use_auto_wb=True, output_bps=8)
-                im = Image.fromarray(rgb)
-        except Exception as e:
-            print(f"rawpy processing failed: {e}")
-            raise e
+        with open(dst_path, "rb") as f:
+            img_bytes = f.read()
+        im = Image.open(io.BytesIO(img_bytes))
+        im.load()
+        return im
+    finally:
+        if dst_path:
+            with suppress(FileNotFoundError):
+                Path(dst_path).unlink()
+        if src_path:
+            with suppress(FileNotFoundError):
+                Path(src_path).unlink()
 
-    else:
-        im = Image.open(io.BytesIO(data))
 
+def load_image_from_bytes(data: bytes, *, filename: str | None = None, is_raw: bool = False):
+    if not is_raw:
+        return _load_image_with_pillow(data)
+
+    first_error: Exception | None = None
+    try:
+        return _load_image_with_pillow(data)
+    except Exception as exc:
+        first_error = exc
+
+    try:
+        return _load_raw_with_ffmpeg(data, filename)
+    except Exception as exc:
+        msg = f"RAW 解码失败: ffmpeg 处理异常 {exc}"
+        if first_error:
+            msg = f"RAW 解码失败: Pillow 异常 {first_error}; ffmpeg 异常 {exc}"
+        raise RuntimeError(msg) from exc
+
+
+def raw_bytes_to_jpeg(data: bytes, filename: str | None = None) -> bytes:
+    im = load_image_from_bytes(data, filename=filename, is_raw=True)
+    if im.mode != "RGB":
+        im = im.convert("RGB")
+    buf = io.BytesIO()
+    im.save(buf, "JPEG", quality=90)
+    return buf.getvalue()
+
+
+def generate_thumb(data: bytes, w: int, h: int, fit: str, is_raw: bool = False, filename: str | None = None) -> Tuple[bytes, str]:
+    im = load_image_from_bytes(data, filename=filename, is_raw=is_raw)
+    return _image_to_webp(im, w, h, fit)
@@ -434,7 +497,7 @@ async def get_or_create_thumb(adapter, adapter_id: int, root: str, rel: str, w:
     read_data = await adapter.read_file(root, rel)
     try:
         thumb_bytes, mime = generate_thumb(
-            read_data, w, h, fit, is_raw=is_raw_filename(rel))
+            read_data, w, h, fit, is_raw=is_raw_filename(rel), filename=rel)
     except Exception as e:
         print(e)
         raise HTTPException(
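The new loader drops the rawpy dependency entirely: Pillow is tried first (it can open some RAW-adjacent containers), and only on failure does the code shell out to ffmpeg to grab a single decoded frame. A minimal usage sketch with a made-up file path:

```python
from domain.virtual_fs.thumbnail import raw_bytes_to_jpeg

# Hypothetical input; any camera RAW the installed ffmpeg build can decode works.
with open("/tmp/sample.dng", "rb") as f:
    raw = f.read()

jpeg = raw_bytes_to_jpeg(raw, filename="sample.dng")
with open("/tmp/sample.jpg", "wb") as out:
    out.write(jpeg)
```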
@@ -273,7 +273,7 @@ class VirtualFSTransferMixin(VirtualFSFileOpsMixin):
             "overwrite": overwrite,
         }
 
-        from domain.tasks.task_queue import task_queue_service
+        from domain.tasks import task_queue_service
 
         task = await task_queue_service.add_task("cross_mount_transfer", payload)
         return {
@@ -286,7 +286,7 @@ class VirtualFSTransferMixin(VirtualFSFileOpsMixin):
 
     @classmethod
     async def run_cross_mount_transfer_task(cls, task: "Task") -> Dict[str, Any]:
-        from domain.tasks.task_queue import task_queue_service
+        from domain.tasks import task_queue_service
 
         params = task.task_info or {}
         operation = params.get("operation")
@@ -1,4 +1,5 @@
 #!/bin/bash
 set -e
 python migrate/run.py
-exec gunicorn -k uvicorn.workers.UvicornWorker -w 1 -b 0.0.0.0:80 main:app
+port="${FOXEL_PORT:-80}"
+exec gunicorn -k uvicorn.workers.UvicornWorker -w 1 -b "0.0.0.0:${port}" main:app
59
main.py
@@ -2,15 +2,16 @@ import os
 from pathlib import Path
 from contextlib import asynccontextmanager
 
-from domain.config.service import ConfigService, VERSION
-from domain.adapters.registry import runtime_registry
+from domain.adapters import runtime_registry
+from domain.config import ConfigService, VERSION
 from db.session import close_db, init_db
 from api.routers import include_routers
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import FileResponse
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException
 from fastapi.exceptions import RequestValidationError
+from starlette.exceptions import HTTPException as StarletteHTTPException
 from middleware.exception_handler import (
     global_exception_handler,
     http_exception_handler,
@@ -19,43 +20,63 @@ from middleware.exception_handler import (
 )
 import httpx
 from dotenv import load_dotenv
-from domain.tasks.task_queue import task_queue_service
+from domain.tasks import task_queue_service
 
 load_dotenv()
 
 
 class SPAStaticFiles(StaticFiles):
     async def get_response(self, path, scope):
-        response = await super().get_response(path, scope)
-        if response.status_code == 404:
-            return await super().get_response("index.html", scope)
+        try:
+            response = await super().get_response(path, scope)
+        except StarletteHTTPException as exc:
+            if exc.status_code != 404:
+                raise
+            if self._should_spa_fallback(scope):
+                return FileResponse(INDEX_FILE)
+            raise
+
+        if response.status_code == 404 and self._should_spa_fallback(scope):
+            return FileResponse(INDEX_FILE)
         return response
 
+    @staticmethod
+    def _should_spa_fallback(scope) -> bool:
+        return (
+            scope.get("method") == "GET"
+            and _request_accepts_html(scope)
+            and not (scope.get("path") or "").startswith(SPA_EXCLUDE_PREFIXES)
+            and INDEX_FILE.exists()
+        )
+
 
 INDEX_FILE = Path("web/dist/index.html")
 SPA_EXCLUDE_PREFIXES = ("/api", "/docs", "/openapi.json", "/webdav", "/s3")
 
 
-async def spa_fallback_middleware(request: Request, call_next):
-    response = await call_next(request)
-    if (
-        response.status_code == 404
-        and request.method == "GET"
-        and "text/html" in request.headers.get("accept", "")
-        and not request.url.path.startswith(SPA_EXCLUDE_PREFIXES)
-        and INDEX_FILE.exists()
-    ):
-        return FileResponse(INDEX_FILE)
-    return response
+def _request_accepts_html(scope) -> bool:
+    for k, v in scope.get("headers") or []:
+        if k == b"accept":
+            return "text/html" in v.decode("latin-1")
+    return False
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     os.makedirs("data/db", exist_ok=True)
     os.makedirs("data/plugins", exist_ok=True)
     await init_db()
     await runtime_registry.refresh()
     await ConfigService.set("APP_VERSION", VERSION)
     await task_queue_service.start_worker()
 
     # Load installed plugins
     from domain.plugins import init_plugins
     await init_plugins(app)
 
+    # Mount static files after all routes are registered (kept last so it
+    # does not shadow API routes)
+    app.mount("/", SPAStaticFiles(directory="web/dist", html=True, check_dir=False), name="static")
+
     try:
         yield
     finally:
@@ -69,7 +90,6 @@ def create_app() -> FastAPI:
         description="A highly extensible private cloud storage solution for individuals and teams",
         lifespan=lifespan,
     )
-    app.middleware("http")(spa_fallback_middleware)
     include_routers(app)
     app.add_exception_handler(HTTPException, http_exception_handler)
     app.add_exception_handler(RequestValidationError, validation_exception_handler)
@@ -86,7 +106,6 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
-app.mount("/", SPAStaticFiles(directory="web/dist", html=True, check_dir=False), name="static")
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
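The SPA fallback moves from an HTTP middleware into the static-files mount itself, keyed on the request method, its Accept header, and the excluded path prefixes. A quick behavioral check against a running instance (host and route are placeholders):

```python
import httpx

base = "http://localhost:8088"  # hypothetical deployment

# Browser-style navigation to an unknown client route: falls back to index.html.
r = httpx.get(f"{base}/photos/2024", headers={"Accept": "text/html"})
assert r.status_code == 200     # served web/dist/index.html

# Non-HTML request to the same path: no fallback, plain 404
# (expected behavior per _should_spa_fallback).
r = httpx.get(f"{base}/photos/2024", headers={"Accept": "application/json"})
assert r.status_code == 404
```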
@@ -168,24 +168,28 @@ class ShareLink(Model):

class Plugin(Model):
    id = fields.IntField(pk=True)
    url = fields.CharField(max_length=2048)
    enabled = fields.BooleanField(default=True)

    open_app = fields.BooleanField(default=False)

    key = fields.CharField(max_length=100, null=True)
    key = fields.CharField(max_length=100, unique=True)  # unique plugin identifier
    name = fields.CharField(max_length=255, null=True)
    version = fields.CharField(max_length=50, null=True)
    supported_exts = fields.JSONField(null=True)

    default_bounds = fields.JSONField(null=True)
    default_maximized = fields.BooleanField(null=True)

    icon = fields.CharField(max_length=2048, null=True)
    description = fields.TextField(null=True)
    author = fields.CharField(max_length=255, null=True)
    website = fields.CharField(max_length=2048, null=True)
    github = fields.CharField(max_length=2048, null=True)
    license = fields.CharField(max_length=100, null=True)

    # Full manifest storage
    manifest = fields.JSONField(null=True)

    # Frontend-related config (extracted from manifest.frontend)
    open_app = fields.BooleanField(default=False)
    supported_exts = fields.JSONField(null=True)
    default_bounds = fields.JSONField(null=True)
    default_maximized = fields.BooleanField(null=True)
    icon = fields.CharField(max_length=2048, null=True)

    # Tracking of loaded components
    loaded_routes = fields.JSONField(null=True)  # ["/api/plugins/xxx", ...]
    loaded_processors = fields.JSONField(null=True)  # ["processor_type", ...]

    created_at = fields.DatetimeField(auto_now_add=True)
    updated_at = fields.DatetimeField(auto_now=True)
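The regrouped columns suggest the shape of the manifest a plugin ships with. A rough TypeScript sketch of that shape, inferred only from the model fields and the "manifest.frontend" comment above; the actual .foxpkg manifest schema may differ:

```ts
// Rough shape of the manifest the columns above appear to store; inferred
// from the Plugin model fields, not a spec from the repository.
interface PluginManifestSketch {
  key: string;                 // Plugin.key, the unique identifier
  name?: string;
  version?: string;
  description?: string;
  author?: string;
  website?: string;
  github?: string;
  license?: string;
  frontend?: {                 // extracted into the dedicated columns
    open_app?: boolean;
    supported_exts?: string[];
    default_bounds?: Record<string, number>;
    default_maximized?: boolean;
    icon?: string;
  };
}
```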
@@ -3,24 +3,21 @@ name = "foxel"
version = "1"
description = "foxel.cc"
readme = "README.md"
requires-python = ">=3.13"
requires-python = ">=3.14"
dependencies = [
    "aioboto3>=15.2.0",
    "aiofiles>=25.1.0",
    "fastapi>=0.116.1",
    "passlib[bcrypt]>=1.7.4",
    "bcrypt>=3.2.2,<4.0",
    "pillow>=11.3.0",
    "pyjwt>=2.10.1",
    "pysocks>=1.7.1",
    "python-dotenv>=1.1.1",
    "python-multipart>=0.0.20",
    "qdrant-client>=1.15.1",
    "rawpy>=0.25.1",
    "telethon>=1.41.2",
    "tortoise-orm>=0.25.2",
    "uvicorn>=0.37.0",
    "pymilvus[milvus-lite]>=2.6.2",
    "aioboto3>=15.5.0",
    "bcrypt>=5.0.0",
    "fastapi>=0.127.0",
    "paramiko>=4.0.0",
    "pydantic[email]>=2.11.7",
    "pillow>=12.0.0",
    "pydantic[email]>=2.12.5",
    "pyjwt>=2.10.1",
    "pymilvus[milvus-lite]>=2.6.5",
    "pysocks>=1.7.1",
    "python-dotenv>=1.2.1",
    "python-multipart>=0.0.21",
    "qdrant-client>=1.16.2",
    "telethon>=1.42.0",
    "tortoise-orm>=0.25.3",
    "uvicorn>=0.40.0",
]
@@ -232,7 +232,7 @@ install_new_foxel() {
        if ss -tuln | grep -q ":${new_port}\b"; then
            warn "Port $new_port is already in use, please choose another."
        else
            sed -i -E "s|\"[0-9]{1,5}:80\"|\"$new_port:80\"|" compose.yaml
            sed -i -E "s|(FOXEL_HOST_PORT:-)[0-9]{1,5}|\\1$new_port|" compose.yaml
            info "Port changed to $new_port."
            break
        fi
@@ -13,8 +13,8 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from domain.auth.service import get_password_hash
from domain.config.service import VERSION
from domain.config import VERSION
from domain.auth import get_password_hash


def _project_root() -> Path:
635 web/bun.lock
File diff suppressed because it is too large
@@ -12,11 +12,15 @@ export default tseslint.config([
    extends: [
      js.configs.recommended,
      tseslint.configs.recommended,
      reactHooks.configs['recommended-latest'],
      reactRefresh.configs.vite,
    ],
    plugins: {
      'react-hooks': reactHooks,
      'react-refresh': reactRefresh,
    },
    rules: {
      '@typescript-eslint/no-explicit-any': 'off',
      'react-hooks/rules-of-hooks': 'error',
      'react-hooks/exhaustive-deps': 'warn',
      'react-refresh/only-export-components': [
        'error',
        {
@@ -6,13 +6,13 @@
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Foxel</title>
    <link rel='stylesheet'
        href='https://chinese-fonts-cdn.deno.dev/packages/maple-mono-cn/dist/MapleMono-CN-Regular/result.css' />
        href='https://foxel.cc/fonts/result.css' />
</head>

<body>
    <style>
        * {
            font-family: 'Maple Mono CN';
            font-family: 'Maple Mono Normal NL NF CN';
        }
    </style>
    <div id="root"></div>
@@ -10,30 +10,27 @@
    "preview": "vite preview"
  },
  "dependencies": {
    "@ant-design/icons": "6.x",
    "@ant-design/v5-patch-for-react-19": "^1.0.3",
    "@ant-design/icons": "6",
    "@monaco-editor/react": "^4.7.0",
    "@uiw/react-md-editor": "^4.0.8",
    "antd": "^5.27.0",
    "antd": "6",
    "artplayer": "^5.3.0",
    "date-fns": "^4.1.0",
    "monaco-editor": "^0.53.0",
    "react": "^19.1.1",
    "react-dom": "^19.1.1",
    "react": "^19.2.3",
    "react-dom": "^19.2.3",
    "react-markdown": "^10.1.0",
    "react-router": "^7.8.0"
    "react-router": "^7.11.0"
  },
  "devDependencies": {
    "@eslint/js": "^9.33.0",
    "@types/react": "^19.1.10",
    "@types/react-dom": "^19.1.7",
    "@eslint/js": "^9.39.2",
    "@types/react": "^19.2.7",
    "@types/react-dom": "^19.2.3",
    "@vitejs/plugin-react": "^5.1.2",
    "eslint": "^9.39.2",
    "eslint-plugin-react-hooks": "^5.2.0",
    "eslint-plugin-react-hooks": "^7.0.1",
    "eslint-plugin-react-refresh": "^0.4.26",
    "globals": "^16.3.0",
    "typescript": "~5.8.3",
    "typescript-eslint": "^8.39.1",
    "vite": "^7.1.2"
    "globals": "^16.5.0",
    "typescript": "~5.9.3",
    "typescript-eslint": "^8.51.0",
    "vite": "^7.3.0"
  }
}
28 web/plugin-frame.html Normal file
@@ -0,0 +1,28 @@
<!doctype html>
<html lang="zh-CN">

<head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Foxel Plugin Frame</title>
    <link rel='stylesheet' href='https://foxel.cc/fonts/result.css' />
    <style>
        html,
        body,
        #root {
            height: 100%;
            margin: 0;
        }

        * {
            font-family: 'Maple Mono Normal NL NF CN';
        }
    </style>
</head>

<body>
    <div id="root"></div>
    <script type="module" src="/src/plugin-frame.ts"></script>
</body>

</html>
121 web/src/api/agent.ts Normal file
@@ -0,0 +1,121 @@
import request, { API_BASE_URL } from './client';

export type AgentChatMessage = Record<string, any>;

export interface AgentChatContext {
  current_path?: string | null;
}

export interface AgentChatRequest {
  messages: AgentChatMessage[];
  auto_execute?: boolean;
  approved_tool_call_ids?: string[];
  rejected_tool_call_ids?: string[];
  context?: AgentChatContext;
}

export interface PendingToolCall {
  id: string;
  name: string;
  arguments: Record<string, any>;
  requires_confirmation: boolean;
}

export interface AgentChatResponse {
  messages: AgentChatMessage[];
  pending_tool_calls?: PendingToolCall[];
}

export type AgentSseEvent =
  | { event: 'assistant_start'; data: { id: string } }
  | { event: 'assistant_delta'; data: { id: string; delta: string } }
  | { event: 'assistant_end'; data: { id: string; message: AgentChatMessage } }
  | { event: 'tool_start'; data: { tool_call_id: string; name: string } }
  | { event: 'tool_end'; data: { tool_call_id: string; name: string; message: AgentChatMessage } }
  | { event: 'pending'; data: { pending_tool_calls: PendingToolCall[] } }
  | { event: 'done'; data: AgentChatResponse };

export const agentApi = {
  chat: (payload: AgentChatRequest) =>
    request<AgentChatResponse>('/agent/chat', {
      method: 'POST',
      json: payload,
    }),
  chatStream: async (
    payload: AgentChatRequest,
    onEvent: (evt: AgentSseEvent) => void,
    options?: { signal?: AbortSignal }
  ) => {
    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
      'Accept': 'text/event-stream',
    };
    const token = localStorage.getItem('token');
    if (token) headers['Authorization'] = `Bearer ${token}`;

    const resp = await fetch(`${API_BASE_URL}/agent/chat/stream`, {
      method: 'POST',
      headers,
      body: JSON.stringify(payload),
      signal: options?.signal,
    });

    if (!resp.ok) {
      let errMsg = resp.statusText;
      try {
        const data = await resp.json();
        if (Array.isArray((data as any)?.detail)) {
          errMsg = (data as any).detail.map((e: any) => e.msg || JSON.stringify(e)).join('; ');
        } else {
          errMsg = (typeof (data as any)?.detail === 'string') ? (data as any).detail : JSON.stringify(data);
        }
      } catch {
        try {
          errMsg = await resp.text();
        } catch { void 0; }
      }
      throw new Error(errMsg || `Request failed: ${resp.status}`);
    }

    const reader = resp.body?.getReader();
    if (!reader) throw new Error('Stream not supported');

    const decoder = new TextDecoder();
    let buffer = '';

    const flush = (raw: string) => {
      const lines = raw.split(/\r?\n/);
      let eventName = 'message';
      const dataLines: string[] = [];
      for (const line of lines) {
        if (line.startsWith('event:')) {
          eventName = line.slice(6).trim();
        } else if (line.startsWith('data:')) {
          dataLines.push(line.slice(5).trimStart());
        }
      }
      const dataStr = dataLines.join('\n').trim();
      if (!eventName || !dataStr) return;
      try {
        const data = JSON.parse(dataStr);
        onEvent({ event: eventName as any, data } as any);
      } catch {
        // ignore parse error
      }
    };

    while (true) {
      const { value, done } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      while (true) {
        const idx = buffer.indexOf('\n\n');
        if (idx === -1) break;
        const chunk = buffer.slice(0, idx);
        buffer = buffer.slice(idx + 2);
        if (chunk.trim()) flush(chunk);
      }
    }
    if (buffer.trim()) flush(buffer);
  },
};
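The SSE parser above splits the byte stream on blank lines, then dispatches each `event:`/`data:` pair as a typed `AgentSseEvent`. A hedged usage sketch of the new API (the prompt text and context path are illustrative, not taken from the repository):

```ts
// Hedged usage sketch for agentApi.chatStream.
async function demoAgentStream(): Promise<void> {
  const controller = new AbortController();
  let text = '';
  await agentApi.chatStream(
    {
      messages: [{ role: 'user', content: 'Summarize this folder' }],
      context: { current_path: '/photos/2024' },
    },
    (evt) => {
      switch (evt.event) {
        case 'assistant_delta':
          text += evt.data.delta; // accumulate streamed tokens
          break;
        case 'pending':
          // Tool calls awaiting user confirmation before execution.
          console.log('pending tools:', evt.data.pending_tool_calls);
          break;
        case 'done':
          console.log(text, evt.data.messages.length);
          break;
      }
    },
    { signal: controller.signal } // controller.abort() cancels the stream
  );
}
```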
@@ -13,8 +13,9 @@ export interface AIProviderPayload {
  extra_config?: Record<string, unknown> | null;
}

export interface AIProvider extends Omit<AIProviderPayload, 'extra_config'> {
export interface AIProvider extends Omit<AIProviderPayload, 'extra_config' | 'api_key'> {
  id: number;
  has_api_key: boolean;
  extra_config: Record<string, unknown>;
  created_at: string;
  updated_at: string;
@@ -1,13 +1,13 @@
import request from './client';

export async function getConfig(key: string) {
  return request<{ key: string; value: string }>('/config?key=' + encodeURIComponent(key));
  return request<{ key: string; value: string }>('/config/?key=' + encodeURIComponent(key));
}

export async function setConfig(key: string, value: string) {
export async function setConfig(key: string, value?: string | null) {
  const form = new FormData();
  form.append('key', key);
  form.append('value', value);
  form.append('value', value ?? '');
  return request('/config/', { method: 'POST', formData: form });
}
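A note on the relaxed `setConfig` signature: `null`/`undefined` are now accepted and coerced to `''`, presumably so callers can clear a value without sending the literal string "null". A hedged sketch (the key name is illustrative):

```ts
// Hedged sketch of clearing a config value with the relaxed signature.
// 'SOME_FEATURE_FLAG' is an illustrative key, not from the repository.
async function clearConfigValue(): Promise<void> {
  await setConfig('SOME_FEATURE_FLAG', null); // sent as '' rather than "null"
  const { value } = await getConfig('SOME_FEATURE_FLAG');
  console.log(value === '' ? 'cleared' : value);
}
```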
@@ -28,7 +28,67 @@ export interface RepoQueryParams {
  pageSize?: number;
}

// Data structures for the foxel-core app center
export interface FoxelCoreApp {
  key: string;
  version: string;
  name: {
    zh: string;
    en: string;
  };
  description: {
    zh: string;
    en: string;
  };
  author: string;
  website: string;
  tags: {
    zh: string[];
    en: string[];
  };
  approvedAt: number;
  detailUrl: string;
  downloadUrl: string;
}

export interface FoxelCoreAppsResponse {
  apps: FoxelCoreApp[];
}

export interface FoxelCoreAppVersion {
  version: string;
  name: {
    zh: string;
    en: string;
  };
  description: {
    zh: string;
    en: string;
  };
  author: string;
  website: string;
  tags: {
    zh: string[];
    en: string[];
  };
  approvedAt: number;
  releaseNotesMd: string | null;
}

export interface FoxelCoreAppDetail {
  key: string;
  latest: FoxelCoreAppVersion & {
    downloadUrl: string;
  };
  versions: FoxelCoreAppVersion[];
}

export interface FoxelCoreAppDetailResponse {
  app: FoxelCoreAppDetail;
}

const CENTER_BASE = 'https://center.foxel.cc';
const FOXEL_CORE_BASE = 'https://foxel.cc';

export function buildCenterUrl(path: string) {
  return new URL(path, CENTER_BASE).href;
@@ -50,3 +110,46 @@ export async function fetchRepoList(params: RepoQueryParams = {}): Promise<RepoL
  return await resp.json();
}

/**
 * Fetch the app list from the foxel-core app center
 */
export async function fetchFoxelCoreApps(query?: string): Promise<FoxelCoreApp[]> {
  const url = new URL('/api/apps', FOXEL_CORE_BASE);
  const q = query?.trim();
  if (q) {
    url.searchParams.set('q', q);
  }
  const resp = await fetch(url.href);
  if (!resp.ok) {
    throw new Error(`Failed to fetch apps: ${resp.status}`);
  }
  const data: FoxelCoreAppsResponse = await resp.json();
  return data.apps;
}

/**
 * Fetch app details (including version history) from the foxel-core app center
 */
export async function fetchFoxelCoreAppDetail(appKey: string): Promise<FoxelCoreAppDetail> {
  const url = `${FOXEL_CORE_BASE}/api/apps/${encodeURIComponent(appKey)}`;
  const resp = await fetch(url);
  if (!resp.ok) {
    throw new Error(`Failed to fetch app detail: ${resp.status}`);
  }
  const data: FoxelCoreAppDetailResponse = await resp.json();
  return data.app;
}

/**
 * Download an app package file from foxel-core
 */
export async function downloadFoxelCoreApp(app: Pick<FoxelCoreApp, 'key' | 'version' | 'downloadUrl'>): Promise<File> {
  const url = `${FOXEL_CORE_BASE}${app.downloadUrl}`;
  const resp = await fetch(url);
  if (!resp.ok) {
    throw new Error(`Failed to download app: ${resp.status}`);
  }
  const blob = await resp.blob();
  const filename = `${app.key}-${app.version}.foxpkg`;
  return new File([blob], filename, { type: 'application/octet-stream' });
}
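Taken together, these helpers form the browse-download-install pipeline for the app center. A hedged sketch of how they might be wired up; it assumes `pluginsApi.install` from web/src/api/plugins.ts (shown in the next diff) accepts the `File` produced here:

```ts
// Hedged sketch: search the app center, download the latest package,
// and hand it to the installer.
async function installFromAppCenter(searchTerm: string): Promise<void> {
  const apps = await fetchFoxelCoreApps(searchTerm);
  if (apps.length === 0) throw new Error('No matching apps');

  const detail = await fetchFoxelCoreAppDetail(apps[0].key);
  const pkg = await downloadFoxelCoreApp({
    key: detail.key,
    version: detail.latest.version,
    downloadUrl: detail.latest.downloadUrl,
  });

  const result = await pluginsApi.install(pkg); // uploads the .foxpkg
  if (!result.success) {
    console.error('Install failed:', result.message, result.errors);
  }
}
```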
@@ -2,46 +2,67 @@ import request from './client';

export interface PluginItem {
  id: number;
  url: string;
  enabled: boolean;
  key: string;
  open_app?: boolean | null;
  key?: string | null;
  name?: string | null;
  version?: string | null;
  supported_exts?: string[] | null;
  default_bounds?: Record<string, any> | null;
  default_bounds?: Record<string, number> | null;
  default_maximized?: boolean | null;
  icon?: string | null;
  description?: string | null;
  author?: string | null;
  website?: string | null;
  github?: string | null;
  license?: string | null;
  manifest?: Record<string, unknown> | null;
  loaded_routes?: string[] | null;
  loaded_processors?: string[] | null;
}

export interface PluginCreate {
  url: string;
  enabled?: boolean;
}

export interface PluginManifestUpdate {
  key?: string;
  name?: string;
  version?: string;
  open_app?: boolean;
  supported_exts?: string[];
  default_bounds?: Record<string, any>;
  default_maximized?: boolean;
  icon?: string;
  description?: string;
  author?: string;
  website?: string;
  github?: string;
export interface PluginInstallResult {
  success: boolean;
  plugin?: PluginItem;
  message?: string;
  errors?: string[];
}

export const pluginsApi = {
  /**
   * List installed plugins
   */
  list: () => request<PluginItem[]>(`/plugins`),
  create: (payload: PluginCreate) => request<PluginItem>(`/plugins`, { method: 'POST', json: payload }),
  remove: (id: number) => request(`/plugins/${id}`, { method: 'DELETE' }),
  update: (id: number, payload: PluginCreate) => request<PluginItem>(`/plugins/${id}`, { method: 'PUT', json: payload }),
  updateManifest: (id: number, payload: PluginManifestUpdate) => request<PluginItem>(`/plugins/${id}/metadata`, { method: 'POST', json: payload }),

  /**
   * Get details for a single plugin
   */
  get: (key: string) => request<PluginItem>(`/plugins/${key}`),

  /**
   * Install a plugin (upload a .foxpkg)
   */
  install: async (file: File): Promise<PluginInstallResult> => {
    const formData = new FormData();
    formData.append('file', file);
    return request<PluginInstallResult>(`/plugins/install`, {
      method: 'POST',
      formData,
    });
  },

  /**
   * Remove/uninstall a plugin
   */
  remove: (key: string) => request(`/plugins/${key}`, { method: 'DELETE' }),

  /**
   * Get the plugin bundle URL
   */
  getBundleUrl: (key: string) => `/api/plugins/${key}/bundle.js`,

  /**
   * Get a plugin asset URL
   */
  getAssetUrl: (key: string, assetPath: string) =>
    `/api/plugins/${key}/assets/${assetPath}`,
};
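A hedged sketch of how the bundle URL might be consumed: the bundle is served as JavaScript, so a dynamic ES-module import is one plausible loading path. The export shape (a `mount(root)` function) is an assumption, not taken from the repository; only the URL helper above is real:

```ts
// Hedged sketch: load a plugin's frontend bundle as an ES module.
async function loadPluginBundle(key: string): Promise<void> {
  const mod: any = await import(/* @vite-ignore */ pluginsApi.getBundleUrl(key));
  const root = document.getElementById('root');
  if (root && typeof mod.mount === 'function') {
    mod.mount(root); // hand the plugin its mount point (assumed contract)
  }
}
```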
@@ -1,35 +0,0 @@
import request from './client';

export type VideoLibraryMediaType = 'tv' | 'movie';

export interface VideoLibraryItem {
  id: string;
  type: VideoLibraryMediaType;
  title: string;
  year?: string | null;
  overview?: string | null;
  poster_path?: string | null;
  backdrop_path?: string | null;
  genres?: string[];
  tmdb_id?: number | null;
  source_path?: string | null;
  scraped_at?: string | null;
  updated_at?: string | null;
  episodes_count?: number;
  seasons_count?: number;
  vote_average?: number | null;
  vote_count?: number | null;
}

export const videoLibraryApi = {
  list: (params?: { q?: string; type?: VideoLibraryMediaType }) => {
    const search = new URLSearchParams();
    if (params?.q) search.set('q', params.q);
    if (params?.type) search.set('type', params.type);
    const suffix = search.toString();
    return request<VideoLibraryItem[]>(`/plugins/video-player/library${suffix ? `?${suffix}` : ''}`, { method: 'GET' });
  },
  get: (id: string) =>
    request<any>(`/plugins/video-player/library/${encodeURIComponent(id)}`, { method: 'GET' }),
};
Some files were not shown because too many files have changed in this diff