mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-08 09:13:23 +08:00
Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
984b7a74ae | ||
|
|
97a3c58f0f | ||
|
|
451e8555d5 | ||
|
|
f444ec46cc | ||
|
|
103beb7dad | ||
|
|
c5e4b3ef43 | ||
|
|
4014a4dd74 | ||
|
|
d0c6e1882f | ||
|
|
434715fc8b | ||
|
|
a127987a3f | ||
|
|
edf95e897d | ||
|
|
b72f8152b6 | ||
|
|
aacddb1208 | ||
|
|
1d6d793f7a | ||
|
|
d9d2ddf2d1 | ||
|
|
e6ab01ef9d | ||
|
|
4a2e01196d | ||
|
|
f22ca62902 | ||
|
|
a394ffa46b | ||
|
|
d003e53a3a | ||
|
|
060a427fe4 | ||
|
|
f4c18f991f | ||
|
|
58c2cdd440 | ||
|
|
7d861ca5f7 | ||
|
|
52bac11760 | ||
|
|
c441d8776f | ||
|
|
45e0194465 | ||
|
|
540065f195 | ||
|
|
4f86e2da4d | ||
|
|
31d347d24f | ||
|
|
7a9a20509c | ||
|
|
373b6410c2 | ||
|
|
d6eb6e1605 | ||
|
|
1d66fb56c8 | ||
|
|
bb9589fa62 | ||
|
|
ab89451b2d | ||
|
|
3e1b75d81a | ||
|
|
1679b03d3a | ||
|
|
ab6562fc79 | ||
|
|
87770176b6 | ||
|
|
e7cf8dbdb8 | ||
|
|
e7eafdee97 | ||
|
|
051b49d3f6 | ||
|
|
b059b0eb44 | ||
|
|
59ad2cb622 | ||
|
|
6b2ada0b42 | ||
|
|
a727e77341 | ||
|
|
4638356a45 | ||
|
|
e51344b43e | ||
|
|
b7685db0e8 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,7 +5,6 @@ __pycache__/
|
||||
.venv/
|
||||
.vscode/
|
||||
data/
|
||||
migrate/
|
||||
.env
|
||||
AGENTS.md
|
||||
|
||||
|
||||
@@ -11,10 +11,14 @@ from domain.processors import api as processors
|
||||
from domain.share import api as share
|
||||
from domain.tasks import api as tasks
|
||||
from domain.ai import api as ai
|
||||
from domain.agent import api as agent
|
||||
from domain.virtual_fs import api as virtual_fs
|
||||
from domain.virtual_fs.mapping import s3_api, webdav_api
|
||||
from domain.virtual_fs.search import search_api
|
||||
from domain.audit import router as audit
|
||||
from domain.audit import api as audit
|
||||
from domain.permission import api as permission
|
||||
from domain.user import api as user
|
||||
from domain.role import api as role
|
||||
|
||||
|
||||
def include_routers(app: FastAPI):
|
||||
@@ -30,9 +34,13 @@ def include_routers(app: FastAPI):
|
||||
app.include_router(backup.router)
|
||||
app.include_router(ai.router_vector_db)
|
||||
app.include_router(ai.router_ai)
|
||||
app.include_router(agent.router)
|
||||
app.include_router(plugins.router)
|
||||
app.include_router(webdav_api.router)
|
||||
app.include_router(s3_api.router)
|
||||
app.include_router(offline_downloads.router)
|
||||
app.include_router(email.router)
|
||||
app.include_router(audit)
|
||||
app.include_router(audit.router)
|
||||
app.include_router(permission.router)
|
||||
app.include_router(user.router)
|
||||
app.include_router(role.router)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from tortoise import Tortoise
|
||||
|
||||
from domain.adapters.registry import runtime_registry
|
||||
from domain.adapters import runtime_registry
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {"default": "sqlite://data/db/db.sqlite3"},
|
||||
|
||||
7
domain/__init__.py
Normal file
7
domain/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
domain:业务域层
|
||||
|
||||
约定:跨包只从各子包 `__init__.py` 导入公开 API。
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
@@ -1 +1,24 @@
|
||||
from .providers import BaseAdapter
|
||||
from .registry import (
|
||||
RuntimeRegistry,
|
||||
discover_adapters,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from .service import AdapterService
|
||||
from .types import AdapterCreate, AdapterOut
|
||||
|
||||
__all__ = [
|
||||
"BaseAdapter",
|
||||
"RuntimeRegistry",
|
||||
"discover_adapters",
|
||||
"get_config_schema",
|
||||
"get_config_schemas",
|
||||
"normalize_adapter_type",
|
||||
"runtime_registry",
|
||||
"AdapterService",
|
||||
"AdapterCreate",
|
||||
"AdapterOut",
|
||||
]
|
||||
|
||||
@@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.adapters.service import AdapterService
|
||||
from domain.adapters.types import AdapterCreate
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import AdapterPermission
|
||||
from .service import AdapterService
|
||||
from .types import AdapterCreate
|
||||
|
||||
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
|
||||
@@ -18,6 +19,7 @@ router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
description="创建存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
@require_system_permission(AdapterPermission.CREATE)
|
||||
async def create_adapter(
|
||||
request: Request,
|
||||
data: AdapterCreate,
|
||||
@@ -29,6 +31,7 @@ async def create_adapter(
|
||||
|
||||
@router.get("")
|
||||
@audit(action=AuditAction.READ, description="获取适配器列表")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def list_adapters(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
@@ -39,6 +42,7 @@ async def list_adapters(
|
||||
|
||||
@router.get("/available")
|
||||
@audit(action=AuditAction.READ, description="获取可用适配器类型")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def available_adapter_types(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
@@ -49,6 +53,7 @@ async def available_adapter_types(
|
||||
|
||||
@router.get("/{adapter_id}")
|
||||
@audit(action=AuditAction.READ, description="获取适配器详情")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def get_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
@@ -64,6 +69,7 @@ async def get_adapter(
|
||||
description="更新存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
@require_system_permission(AdapterPermission.EDIT)
|
||||
async def update_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
@@ -76,6 +82,7 @@ async def update_adapter(
|
||||
|
||||
@router.delete("/{adapter_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除存储适配器")
|
||||
@require_system_permission(AdapterPermission.DELETE)
|
||||
async def delete_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
|
||||
@@ -81,8 +81,9 @@ class AListApiAdapterBase:
|
||||
raise ValueError(f"{product_name} requires base_url http/https")
|
||||
self.username: str = str(cfg.get("username") or "")
|
||||
self.password: str = str(cfg.get("password") or "")
|
||||
if not self.username or not self.password:
|
||||
raise ValueError(f"{product_name} requires username and password")
|
||||
if (self.username and not self.password) or (self.password and not self.username):
|
||||
raise ValueError(f"{product_name} requires both username and password")
|
||||
self.use_auth: bool = bool(self.username and self.password)
|
||||
|
||||
self.timeout: float = float(cfg.get("timeout", 30))
|
||||
self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
|
||||
@@ -98,6 +99,8 @@ class AListApiAdapterBase:
|
||||
return base
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
if not self.use_auth:
|
||||
return ""
|
||||
if self._token:
|
||||
return self._token
|
||||
async with self._login_lock:
|
||||
@@ -137,12 +140,14 @@ class AListApiAdapterBase:
|
||||
) -> Any:
|
||||
token = await self._ensure_token()
|
||||
url = self.base_url + endpoint
|
||||
req_headers: Dict[str, str] = {"Authorization": token}
|
||||
req_headers: Dict[str, str] = {}
|
||||
if token:
|
||||
req_headers["Authorization"] = token
|
||||
if headers:
|
||||
req_headers.update(headers)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, json=json, headers=req_headers, files=files)
|
||||
if resp.status_code == 401 and retry:
|
||||
if resp.status_code == 401 and retry and self.use_auth:
|
||||
self._token = None
|
||||
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
|
||||
resp.raise_for_status()
|
||||
@@ -153,7 +158,7 @@ class AListApiAdapterBase:
|
||||
code = payload.get("code")
|
||||
if code in (0, 200):
|
||||
return payload.get("data")
|
||||
if code in (401, 403) and retry:
|
||||
if code in (401, 403) and retry and self.use_auth:
|
||||
self._token = None
|
||||
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
|
||||
if code == 404:
|
||||
@@ -349,10 +354,9 @@ class AListApiAdapterBase:
|
||||
|
||||
async def _upload_file(self, full_path: str, file_path: Path) -> Any:
|
||||
token = await self._ensure_token()
|
||||
headers = {
|
||||
"Authorization": token,
|
||||
"File-Path": quote(full_path, safe="/"),
|
||||
}
|
||||
headers = {"File-Path": quote(full_path, safe="/")}
|
||||
if token:
|
||||
headers["Authorization"] = token
|
||||
with file_path.open("rb") as f:
|
||||
files = {"file": (file_path.name, f, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
@@ -381,6 +385,30 @@ class AListApiAdapterBase:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
token = await self._ensure_token()
|
||||
headers = {"File-Path": quote(full_path, safe="/")}
|
||||
if token:
|
||||
headers["Authorization"] = token
|
||||
name = filename or Path(rel).name or "file"
|
||||
mime = content_type or "application/octet-stream"
|
||||
files = {"file": (name, file_obj, mime)}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict) and file_size is not None and "size" not in data:
|
||||
data["size"] = file_size
|
||||
return data
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
@@ -479,8 +507,8 @@ ADAPTER_TYPES = {"alist": AListAdapter, "openlist": OpenListAdapter}
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "base_url", "label": "基础地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:5244"},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": True},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": True},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": False, "placeholder": "留空则匿名访问"},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": False, "placeholder": "留空则匿名访问"},
|
||||
{"key": "root", "label": "根目录", "type": "string", "required": False, "default": "/"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 30},
|
||||
{"key": "enable_direct_download_307", "label": "启用 307 直链下载", "type": "boolean", "default": False},
|
||||
|
||||
@@ -250,6 +250,30 @@ class FoxelAdapter:
|
||||
return True
|
||||
raise HTTPException(502, detail="Foxel 写入失败")
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
name = filename or Path(rel).name or "file"
|
||||
mime = content_type or "application/octet-stream"
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
files = {"file": (name, file_obj, mime)}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return {"size": file_size or 0}
|
||||
raise HTTPException(502, detail="Foxel 上传失败")
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
|
||||
@@ -238,6 +238,39 @@ class FTPAdapter:
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(ftp: FTP, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _do_upload():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(ftp, parent)
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
ftp.storbinary("STOR " + path, file_obj)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_upload)
|
||||
return {"size": file_size or 0}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
# KISS: 聚合后一次性写入
|
||||
buf = bytearray()
|
||||
|
||||
@@ -114,6 +114,32 @@ class LocalAdapter:
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
fp = _safe_join(root, rel)
|
||||
pre_exists = fp.exists()
|
||||
await asyncio.to_thread(os.makedirs, fp.parent, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
|
||||
def _copy():
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with open(fp, "wb") as f:
|
||||
shutil.copyfileobj(file_obj, f)
|
||||
|
||||
await asyncio.to_thread(_copy)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
|
||||
size = file_size
|
||||
if size is None:
|
||||
try:
|
||||
size = fp.stat().st_size
|
||||
except Exception:
|
||||
size = 0
|
||||
return {"size": int(size or 0)}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
fp = _safe_join(root, rel)
|
||||
pre_exists = fp.exists()
|
||||
|
||||
@@ -453,6 +453,159 @@ class QuarkAdapter:
|
||||
yield data
|
||||
return await self.write_file_stream(root, rel, gen())
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = filename or rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
|
||||
md5 = hashlib.md5()
|
||||
sha1 = hashlib.sha1()
|
||||
total = 0
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
while True:
|
||||
chunk = file_obj.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
md5.update(chunk)
|
||||
sha1.update(chunk)
|
||||
|
||||
md5_hex = md5.hexdigest()
|
||||
sha1_hex = sha1.hexdigest()
|
||||
|
||||
# 预上传,拿到上传信息
|
||||
pre_resp = await self._upload_pre(name, total, parent_fid)
|
||||
pre_data = pre_resp.get("data", {})
|
||||
|
||||
# hash 秒传
|
||||
hash_body = {"md5": md5_hex, "sha1": sha1_hex, "task_id": pre_data.get("task_id")}
|
||||
hash_resp = await self._request("POST", "/file/update/hash", json=hash_body)
|
||||
if (hash_resp.get("data") or {}).get("finish") is True:
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return {"size": total}
|
||||
|
||||
# 分片上传
|
||||
part_size = int((pre_resp.get("metadata") or {}).get("part_size") or 0)
|
||||
if part_size <= 0:
|
||||
raise HTTPException(502, detail="Invalid part_size from Quark")
|
||||
|
||||
bucket = pre_data.get("bucket")
|
||||
obj_key = pre_data.get("obj_key")
|
||||
upload_id = pre_data.get("upload_id")
|
||||
upload_url = pre_data.get("upload_url")
|
||||
if not (bucket and obj_key and upload_id and upload_url):
|
||||
raise HTTPException(502, detail="Upload pre missing fields")
|
||||
|
||||
try:
|
||||
upload_host = upload_url.split("://", 1)[1]
|
||||
except Exception:
|
||||
upload_host = upload_url
|
||||
base_url = f"https://{bucket}.{upload_host}/{obj_key}"
|
||||
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
etags: List[str] = []
|
||||
oss_ua = "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
part_number = 1
|
||||
left = total
|
||||
while left > 0:
|
||||
sz = min(part_size, left)
|
||||
data_bytes = file_obj.read(sz)
|
||||
if len(data_bytes) != sz:
|
||||
raise IOError("Failed to read part bytes")
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta = (
|
||||
"PUT\n\n"
|
||||
f"{self._guess_mime(name)}\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?partNumber={part_number}&uploadId={upload_id}"
|
||||
)
|
||||
auth_req_body = {"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta, "task_id": pre_data.get("task_id")}
|
||||
auth_resp = await self._request("POST", "/file/upload/auth", json=auth_req_body)
|
||||
auth_key = (auth_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key:
|
||||
raise HTTPException(502, detail="upload/auth missing auth_key")
|
||||
|
||||
put_headers = {
|
||||
"Authorization": auth_key,
|
||||
"Content-Type": self._guess_mime(name),
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
put_url = f"{base_url}?partNumber={part_number}&uploadId={upload_id}"
|
||||
put_resp = await client.put(put_url, headers=put_headers, content=data_bytes)
|
||||
if put_resp.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload part failed status={put_resp.status_code} text={put_resp.text}")
|
||||
etag = put_resp.headers.get("Etag", "")
|
||||
etags.append(etag)
|
||||
left -= sz
|
||||
part_number += 1
|
||||
|
||||
parts_xml = [f"<Part>\n<PartNumber>{i+1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n" for i in range(len(etags))]
|
||||
body_xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n" + "".join(parts_xml) + "</CompleteMultipartUpload>"
|
||||
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
|
||||
callback = pre_data.get("callback") or {}
|
||||
try:
|
||||
import json as _json
|
||||
callback_b64 = base64.b64encode(_json.dumps(callback).encode("utf-8")).decode("ascii")
|
||||
except Exception:
|
||||
callback_b64 = ""
|
||||
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta_commit = (
|
||||
"POST\n"
|
||||
f"{content_md5}\n"
|
||||
"application/xml\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-callback:{callback_b64}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?uploadId={upload_id}"
|
||||
)
|
||||
auth_commit_resp = await self._request("POST", "/file/upload/auth", json={"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta_commit, "task_id": pre_data.get("task_id")})
|
||||
auth_key_commit = (auth_commit_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key_commit:
|
||||
raise HTTPException(502, detail="upload/auth(commit) missing auth_key")
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
commit_headers = {
|
||||
"Authorization": auth_key_commit,
|
||||
"Content-MD5": content_md5,
|
||||
"Content-Type": "application/xml",
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-callback": callback_b64,
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
commit_url = f"{base_url}?uploadId={upload_id}"
|
||||
r = await client.post(commit_url, headers=commit_headers, content=body_xml.encode("utf-8"))
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload commit failed status={r.status_code} text={r.text}")
|
||||
|
||||
await self._request("POST", "/file/upload/finish", json={"obj_key": obj_key, "task_id": pre_data.get("task_id")})
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception:
|
||||
pass
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return {"size": total}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
|
||||
@@ -157,6 +157,41 @@ class SFTPAdapter:
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def _do_upload():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(sftp, parent)
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with sftp.open(path, "wb") as f:
|
||||
import shutil
|
||||
shutil.copyfileobj(file_obj, f)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_upload)
|
||||
return {"size": file_size or 0}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
buf = bytearray()
|
||||
async for chunk in data_iter:
|
||||
|
||||
@@ -1,11 +1,50 @@
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import struct
|
||||
from models import StorageAdapter
|
||||
from telethon import TelegramClient
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon.sessions import StringSession
|
||||
from telethon.tl import types
|
||||
import socks
|
||||
|
||||
_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _get_session_lock(session_string: str) -> asyncio.Lock:
|
||||
lock = _SESSION_LOCKS.get(session_string)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_SESSION_LOCKS[session_string] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class _NamedFile:
|
||||
def __init__(self, file_obj, name: str):
|
||||
self._file = file_obj
|
||||
self.name = name
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
return self._file.read(*args, **kwargs)
|
||||
|
||||
def seek(self, *args, **kwargs):
|
||||
return self._file.seek(*args, **kwargs)
|
||||
|
||||
def tell(self):
|
||||
return self._file.tell()
|
||||
|
||||
def seekable(self):
|
||||
return self._file.seekable()
|
||||
|
||||
def close(self):
|
||||
return self._file.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._file, name)
|
||||
|
||||
# 适配器类型标识
|
||||
ADAPTER_TYPE = "telegram"
|
||||
|
||||
@@ -54,9 +93,93 @@ class TelegramAdapter:
|
||||
if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
|
||||
raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")
|
||||
|
||||
@staticmethod
|
||||
def _parse_legacy_session_string(value: str) -> StringSession:
|
||||
"""
|
||||
兼容旧版 session_string 格式:
|
||||
- version(1B char) + base64(data)
|
||||
- data: dc_id(1B) + ip_len(2B) + ip(ASCII, ip_len bytes) + port(2B) + auth_key(256B)
|
||||
"""
|
||||
s = (value or "").strip()
|
||||
if not s:
|
||||
raise ValueError("session_string 为空")
|
||||
|
||||
body = s[1:] if s.startswith("1") else s
|
||||
raw = base64.urlsafe_b64decode(body)
|
||||
if len(raw) < 1 + 2 + 2 + 256:
|
||||
raise ValueError("legacy session 数据长度不足")
|
||||
|
||||
dc_id = raw[0]
|
||||
ip_len = struct.unpack(">H", raw[1:3])[0]
|
||||
expected_len = 1 + 2 + ip_len + 2 + 256
|
||||
if len(raw) != expected_len:
|
||||
raise ValueError("legacy session 数据长度不匹配")
|
||||
|
||||
ip_start = 3
|
||||
ip_end = ip_start + ip_len
|
||||
ip = raw[ip_start:ip_end].decode("utf-8")
|
||||
port = struct.unpack(">H", raw[ip_end : ip_end + 2])[0]
|
||||
key = raw[ip_end + 2 : ip_end + 2 + 256]
|
||||
|
||||
sess = StringSession()
|
||||
sess.set_dc(dc_id, ip, port)
|
||||
sess.auth_key = AuthKey(key)
|
||||
return sess
|
||||
|
||||
@staticmethod
|
||||
def _pick_photo_thumb(thumbs: list | None):
|
||||
if not thumbs:
|
||||
return None
|
||||
|
||||
cached = []
|
||||
others = []
|
||||
for t in thumbs:
|
||||
if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize)):
|
||||
cached.append(t)
|
||||
elif isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive)):
|
||||
if not isinstance(t, types.PhotoSizeEmpty):
|
||||
others.append(t)
|
||||
|
||||
if cached:
|
||||
cached.sort(key=lambda x: len(getattr(x, "bytes", b"") or b""))
|
||||
return cached[-1]
|
||||
|
||||
if others:
|
||||
def _sz(x):
|
||||
if isinstance(x, types.PhotoSizeProgressive):
|
||||
return max(x.sizes or [0])
|
||||
return int(getattr(x, "size", 0) or 0)
|
||||
|
||||
others.sort(key=_sz)
|
||||
return others[-1]
|
||||
|
||||
return None
|
||||
|
||||
def _build_session(self) -> StringSession:
|
||||
s = (self.session_string or "").strip()
|
||||
if not s:
|
||||
raise ValueError("Telegram 适配器 session_string 为空")
|
||||
|
||||
try:
|
||||
return StringSession(s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 少数工具可能去掉了 version 前缀,这里做一次兼容
|
||||
if not s.startswith("1"):
|
||||
try:
|
||||
return StringSession("1" + s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return self._parse_legacy_session_string(s)
|
||||
except Exception as exc:
|
||||
raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc
|
||||
|
||||
def _get_client(self) -> TelegramClient:
|
||||
"""创建一个新的 TelegramClient 实例"""
|
||||
return TelegramClient(StringSession(self.session_string), self.api_id, self.api_hash, proxy=self.proxy)
|
||||
return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
return ""
|
||||
@@ -164,7 +287,48 @@ class TelegramAdapter:
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||||
sent = await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
if message:
|
||||
stored_name = file_like.name
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
return {"rel": actual_rel, "size": len(data)}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
client = self._get_client()
|
||||
name = filename or os.path.basename(rel) or "file"
|
||||
file_like = _NamedFile(file_obj, name)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
sent = await client.send_file(
|
||||
self.chat_id,
|
||||
file_like,
|
||||
caption=file_like.name,
|
||||
file_size=file_size,
|
||||
mime_type=content_type,
|
||||
)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
size = file_size or 0
|
||||
if message:
|
||||
stored_name = file_like.name
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
if file_meta and getattr(file_meta, "size", None):
|
||||
size = int(file_meta.size)
|
||||
return {"rel": actual_rel, "size": size}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
@@ -174,8 +338,9 @@ class TelegramAdapter:
|
||||
client = self._get_client()
|
||||
filename = os.path.basename(rel) or "file"
|
||||
import tempfile
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_path = os.path.join(temp_dir, filename)
|
||||
suffix = os.path.splitext(filename)[1]
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
temp_path = tf.name
|
||||
|
||||
total_size = 0
|
||||
try:
|
||||
@@ -186,18 +351,62 @@ class TelegramAdapter:
|
||||
total_size += len(chunk)
|
||||
|
||||
await client.connect()
|
||||
await client.send_file(self.chat_id, temp_path, caption=filename)
|
||||
sent = await client.send_file(self.chat_id, temp_path, caption=filename)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
if message:
|
||||
stored_name = filename
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
|
||||
finally:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
return total_size
|
||||
return {"rel": actual_rel, "size": total_size}
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
raise NotImplementedError("Telegram 适配器不支持创建目录。")
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
doc = message.document or message.video
|
||||
thumbs = None
|
||||
if doc and getattr(doc, "thumbs", None):
|
||||
thumbs = list(doc.thumbs or [])
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
thumbs = list(message.photo.sizes or [])
|
||||
|
||||
thumb = self._pick_photo_thumb(thumbs)
|
||||
if not thumb:
|
||||
return None
|
||||
|
||||
result = await client.download_media(message, bytes, thumb=thumb)
|
||||
if isinstance(result, (bytes, bytearray)):
|
||||
return bytes(result)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
"""删除一个文件 (即一条消息)"""
|
||||
try:
|
||||
@@ -236,6 +445,8 @@ class TelegramAdapter:
|
||||
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
lock = _get_session_lock(self.session_string)
|
||||
await lock.acquire()
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
@@ -273,7 +484,6 @@ class TelegramAdapter:
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": mime_type,
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
if range_header:
|
||||
@@ -285,7 +495,6 @@ class TelegramAdapter:
|
||||
if start >= file_size or end >= file_size or start > end:
|
||||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||||
status = 206
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid Range header")
|
||||
@@ -304,18 +513,28 @@ class TelegramAdapter:
|
||||
if downloaded >= limit:
|
||||
break
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
try:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers)
|
||||
|
||||
except HTTPException:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
|
||||
@@ -4,7 +4,7 @@ from importlib import import_module
|
||||
from typing import Callable, Dict
|
||||
|
||||
from models import StorageAdapter
|
||||
from domain.adapters.providers.base import BaseAdapter
|
||||
from .providers.base import BaseAdapter
|
||||
|
||||
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
|
||||
|
||||
@@ -21,7 +21,7 @@ def normalize_adapter_type(value: str | None) -> str | None:
|
||||
|
||||
def discover_adapters():
|
||||
"""扫描 domain.adapters.providers 包, 自动注册适配器类型、工厂与配置 schema。"""
|
||||
from domain.adapters import providers as adapters_pkg
|
||||
from . import providers as adapters_pkg
|
||||
|
||||
TYPE_MAP.clear()
|
||||
CONFIG_SCHEMAS.clear()
|
||||
|
||||
@@ -2,13 +2,13 @@ from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.adapters.registry import (
|
||||
from domain.auth import User
|
||||
from .registry import (
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from domain.adapters.types import AdapterCreate, AdapterOut
|
||||
from domain.auth.types import User
|
||||
from .types import AdapterCreate, AdapterOut
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
@@ -36,6 +36,11 @@ class AdapterService:
|
||||
missing.append(k)
|
||||
if missing:
|
||||
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
|
||||
if adapter_type in ("alist", "openlist"):
|
||||
username = out.get("username")
|
||||
password = out.get("password")
|
||||
if (username and not password) or (password and not username):
|
||||
raise HTTPException(400, detail="用户名和密码必须同时填写或同时留空")
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
|
||||
9
domain/agent/__init__.py
Normal file
9
domain/agent/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .service import AgentService
|
||||
from .types import AgentChatContext, AgentChatRequest, PendingToolCall
|
||||
|
||||
__all__ = [
|
||||
"AgentService",
|
||||
"AgentChatContext",
|
||||
"AgentChatRequest",
|
||||
"PendingToolCall",
|
||||
]
|
||||
38
domain/agent/api.py
Normal file
38
domain/agent/api.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import AgentService
|
||||
from .types import AgentChatRequest
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = await AgentService.chat(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"])
|
||||
async def chat_stream(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
return StreamingResponse(
|
||||
AgentService.chat_stream(payload, current_user),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
472
domain/agent/service.py
Normal file
472
domain/agent/service.py
Normal file
@@ -0,0 +1,472 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
|
||||
from domain.auth import User
|
||||
from .tools import get_tool, openai_tools, tool_result_to_content
|
||||
from .types import AgentChatRequest, PendingToolCall
|
||||
|
||||
|
||||
def _normalize_path(p: Optional[str]) -> Optional[str]:
|
||||
if not p:
|
||||
return None
|
||||
s = str(p).strip()
|
||||
if not s:
|
||||
return None
|
||||
s = s.replace("\\", "/")
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _build_system_prompt(current_path: Optional[str]) -> str:
|
||||
lines = [
|
||||
"你是 Foxel 的 AI 助手。",
|
||||
"你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。",
|
||||
"",
|
||||
"可用工具:",
|
||||
"- time:获取服务器当前时间(精确到秒,英文星期),支持 year/month/day/hour/minute/second 偏移。",
|
||||
"- web_fetch:抓取网页(HTTP 请求),支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS,返回状态/标题/正文/链接等。",
|
||||
"- vfs_list_dir:浏览目录(列出 entries + pagination)。",
|
||||
"- vfs_stat:查看文件/目录信息。",
|
||||
"- vfs_read_text:读取文本文件内容(不支持二进制)。",
|
||||
"- vfs_search:搜索文件(vector/filename)。",
|
||||
"- vfs_write_text:写入文本文件内容(覆盖)。",
|
||||
"- vfs_mkdir:创建目录。",
|
||||
"- vfs_delete:删除文件或目录。",
|
||||
"- vfs_move:移动路径。",
|
||||
"- vfs_copy:复制路径。",
|
||||
"- vfs_rename:重命名路径。",
|
||||
"- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。",
|
||||
"- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。",
|
||||
"",
|
||||
"规则:",
|
||||
"1) 读操作(web_fetch/vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。",
|
||||
"2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。",
|
||||
"3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
|
||||
"4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。",
|
||||
"5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
|
||||
"6) 回答语言跟随用户;用户用英文则用英文,用户用中文则用中文。回答尽量简洁。",
|
||||
]
|
||||
if current_path:
|
||||
lines.append("")
|
||||
lines.append(f"当前文件管理目录:{current_path}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list):
|
||||
return message
|
||||
|
||||
changed = False
|
||||
for idx, call in enumerate(tool_calls):
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = call.get("id")
|
||||
if isinstance(call_id, str) and call_id.strip():
|
||||
continue
|
||||
call["id"] = f"call_{idx}"
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
message["tool_calls"] = tool_calls
|
||||
return message
|
||||
|
||||
|
||||
def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall:
|
||||
call_id = str(tool_call.get("id") or "")
|
||||
fn = tool_call.get("function") or {}
|
||||
name = str((fn.get("name") if isinstance(fn, dict) else None) or "")
|
||||
raw_args = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
arguments: Dict[str, Any] = {}
|
||||
if isinstance(raw_args, str) and raw_args.strip():
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
if isinstance(parsed, dict):
|
||||
arguments = parsed
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return PendingToolCall(
|
||||
id=call_id,
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
|
||||
|
||||
def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[idx]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list) and tool_calls:
|
||||
return idx, msg
|
||||
raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
|
||||
|
||||
|
||||
def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id.strip():
|
||||
ids.add(tool_call_id)
|
||||
return ids
|
||||
|
||||
|
||||
async def _choose_chat_ability() -> str:
|
||||
tools_model = await AIProviderService.get_default_model("tools")
|
||||
return "tools" if tools_model else "chat"
|
||||
|
||||
|
||||
def _sse(event: str, data: Any) -> bytes:
|
||||
payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
|
||||
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
def _format_exc(exc: BaseException) -> str:
|
||||
text = str(exc)
|
||||
return text if text else exc.__class__.__name__
|
||||
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
|
||||
history: List[Dict[str, Any]] = list(req.messages or [])
|
||||
current_path = _normalize_path(req.context.current_path if req.context else None)
|
||||
|
||||
system_prompt = _build_system_prompt(current_path)
|
||||
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
|
||||
|
||||
new_messages: List[Dict[str, Any]] = []
|
||||
pending: List[PendingToolCall] = []
|
||||
|
||||
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
|
||||
if approved_ids or rejected_ids:
|
||||
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
|
||||
last_call_msg = _ensure_tool_call_ids(last_call_msg)
|
||||
tool_calls = last_call_msg.get("tool_calls") or []
|
||||
call_map: Dict[str, Dict[str, Any]] = {
|
||||
str(c.get("id")): c
|
||||
for c in tool_calls
|
||||
if isinstance(c, dict) and isinstance(c.get("id"), str)
|
||||
}
|
||||
|
||||
existing_ids = _existing_tool_result_ids(internal_messages)
|
||||
for call_id in approved_ids | rejected_ids:
|
||||
if call_id in existing_ids:
|
||||
continue
|
||||
tool_call = call_map.get(call_id)
|
||||
if not tool_call:
|
||||
continue
|
||||
fn = tool_call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if call_id in rejected_ids:
|
||||
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
|
||||
tools_schema = openai_tools()
|
||||
ability = await _choose_chat_ability()
|
||||
max_loops = 4
|
||||
|
||||
for _ in range(max_loops):
|
||||
try:
|
||||
assistant = await chat_completion(
|
||||
internal_messages,
|
||||
ability=ability,
|
||||
tools=tools_schema,
|
||||
tool_choice="auto",
|
||||
timeout=60.0,
|
||||
)
|
||||
except MissingModelError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc
|
||||
|
||||
assistant = _ensure_tool_call_ids(assistant)
|
||||
internal_messages.append(assistant)
|
||||
new_messages.append(assistant)
|
||||
|
||||
tool_calls = assistant.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
break
|
||||
|
||||
pending = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = str(call.get("id") or "")
|
||||
fn = call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
if spec.requires_confirmation and not req.auto_execute:
|
||||
pending.append(_extract_pending(call, True))
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
|
||||
if pending:
|
||||
break
|
||||
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
|
||||
history: List[Dict[str, Any]] = list(req.messages or [])
|
||||
current_path = _normalize_path(req.context.current_path if req.context else None)
|
||||
|
||||
system_prompt = _build_system_prompt(current_path)
|
||||
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
|
||||
|
||||
new_messages: List[Dict[str, Any]] = []
|
||||
pending: List[PendingToolCall] = []
|
||||
|
||||
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
|
||||
try:
|
||||
if approved_ids or rejected_ids:
|
||||
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
|
||||
last_call_msg = _ensure_tool_call_ids(last_call_msg)
|
||||
tool_calls = last_call_msg.get("tool_calls") or []
|
||||
call_map: Dict[str, Dict[str, Any]] = {
|
||||
str(c.get("id")): c
|
||||
for c in tool_calls
|
||||
if isinstance(c, dict) and isinstance(c.get("id"), str)
|
||||
}
|
||||
|
||||
existing_ids = _existing_tool_result_ids(internal_messages)
|
||||
for call_id in approved_ids | rejected_ids:
|
||||
if call_id in existing_ids:
|
||||
continue
|
||||
tool_call = call_map.get(call_id)
|
||||
if not tool_call:
|
||||
continue
|
||||
fn = tool_call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if call_id in rejected_ids:
|
||||
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
|
||||
|
||||
tools_schema = openai_tools()
|
||||
ability = await _choose_chat_ability()
|
||||
max_loops = 4
|
||||
|
||||
for _ in range(max_loops):
|
||||
assistant_event_id = uuid.uuid4().hex
|
||||
yield _sse("assistant_start", {"id": assistant_event_id})
|
||||
|
||||
assistant_message: Dict[str, Any] | None = None
|
||||
try:
|
||||
async for event in chat_completion_stream(
|
||||
internal_messages,
|
||||
ability=ability,
|
||||
tools=tools_schema,
|
||||
tool_choice="auto",
|
||||
timeout=60.0,
|
||||
):
|
||||
if event.get("type") == "delta":
|
||||
delta = event.get("delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
|
||||
elif event.get("type") == "message":
|
||||
msg = event.get("message")
|
||||
if isinstance(msg, dict):
|
||||
assistant_message = msg
|
||||
except MissingModelError as exc:
|
||||
raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc
|
||||
|
||||
if not assistant_message:
|
||||
assistant_message = {"role": "assistant", "content": ""}
|
||||
|
||||
assistant_message = _ensure_tool_call_ids(assistant_message)
|
||||
internal_messages.append(assistant_message)
|
||||
new_messages.append(assistant_message)
|
||||
yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})
|
||||
|
||||
tool_calls = assistant_message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
break
|
||||
|
||||
pending = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = str(call.get("id") or "")
|
||||
fn = call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
if spec.requires_confirmation and not req.auto_execute:
|
||||
pending.append(_extract_pending(call, True))
|
||||
continue
|
||||
|
||||
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
|
||||
|
||||
if pending:
|
||||
yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]})
|
||||
break
|
||||
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail
|
||||
content = detail if isinstance(detail, str) else str(detail)
|
||||
if not content.strip():
|
||||
content = f"请求失败({exc.status_code})"
|
||||
new_messages.append({"role": "assistant", "content": content})
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
return
|
||||
37
domain/agent/tools/__init__.py
Normal file
37
domain/agent/tools/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import ToolSpec, tool_result_to_content
|
||||
from .processors import TOOLS as PROCESSOR_TOOLS
|
||||
from .time import TOOLS as TIME_TOOLS
|
||||
from .vfs import TOOLS as VFS_TOOLS
|
||||
from .web_fetch import TOOLS as WEB_FETCH_TOOLS
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {}
|
||||
for group in (TIME_TOOLS, WEB_FETCH_TOOLS, PROCESSOR_TOOLS, VFS_TOOLS):
|
||||
TOOLS.update(group)
|
||||
|
||||
|
||||
def get_tool(name: str) -> Optional[ToolSpec]:
|
||||
return TOOLS.get(name)
|
||||
|
||||
|
||||
def openai_tools() -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
for spec in TOOLS.values():
|
||||
out.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"parameters": spec.parameters,
|
||||
},
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolSpec",
|
||||
"get_tool",
|
||||
"openai_tools",
|
||||
"tool_result_to_content",
|
||||
]
|
||||
149
domain/agent/tools/base.py
Normal file
149
domain/agent/tools/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolSpec:
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
requires_confirmation: bool
|
||||
handler: Callable[[Dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
def _stringify_value(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(value)
|
||||
|
||||
|
||||
def _list_to_view_items(items: List[Any]) -> List[Any]:
|
||||
normalized: List[Any] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
normalized.append({str(k): _stringify_value(v) for k, v in item.items()})
|
||||
else:
|
||||
normalized.append(_stringify_value(item))
|
||||
return normalized
|
||||
|
||||
|
||||
def _dict_to_kv_items(data: Dict[str, Any]) -> List[Dict[str, str]]:
|
||||
return [{"key": str(k), "value": _stringify_value(v)} for k, v in data.items()]
|
||||
|
||||
|
||||
def _first_list_field(data: Dict[str, Any]) -> tuple[Optional[str], Optional[List[Any]]]:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, list):
|
||||
return str(key), value
|
||||
return None, None
|
||||
|
||||
|
||||
def _build_view(data: Any) -> Dict[str, Any]:
|
||||
if data is None:
|
||||
return {"type": "kv", "items": []}
|
||||
if isinstance(data, str):
|
||||
return {"type": "text", "text": data}
|
||||
if isinstance(data, list):
|
||||
return {"type": "list", "items": _list_to_view_items(data)}
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content")
|
||||
if isinstance(content, str):
|
||||
meta = {k: _stringify_value(v) for k, v in data.items() if k != "content"}
|
||||
view: Dict[str, Any] = {"type": "text", "text": content}
|
||||
if meta:
|
||||
view["meta"] = meta
|
||||
return view
|
||||
list_key, list_val = _first_list_field(data)
|
||||
if list_key and isinstance(list_val, list):
|
||||
meta = {k: _stringify_value(v) for k, v in data.items() if k != list_key}
|
||||
view = {"type": "list", "title": list_key, "items": _list_to_view_items(list_val)}
|
||||
if meta:
|
||||
view["meta"] = meta
|
||||
return view
|
||||
return {"type": "kv", "items": _dict_to_kv_items(data)}
|
||||
return {"type": "text", "text": _stringify_value(data)}
|
||||
|
||||
|
||||
def _build_summary(view: Dict[str, Any]) -> str:
|
||||
view_type = str(view.get("type") or "")
|
||||
if view_type == "text":
|
||||
text = view.get("text")
|
||||
size = len(text) if isinstance(text, str) else 0
|
||||
return f"chars: {size}" if size else "text"
|
||||
if view_type == "list":
|
||||
items = view.get("items")
|
||||
count = len(items) if isinstance(items, list) else 0
|
||||
title = str(view.get("title") or "items")
|
||||
return f"{title}: {count}"
|
||||
if view_type == "kv":
|
||||
items = view.get("items")
|
||||
count = len(items) if isinstance(items, list) else 0
|
||||
return f"fields: {count}"
|
||||
if view_type == "error":
|
||||
return str(view.get("message") or "error")
|
||||
return ""
|
||||
|
||||
|
||||
def _build_error_payload(code: str, message: str, detail: Any = None) -> Dict[str, Any]:
|
||||
summary = "Canceled" if code == "canceled" else message or "error"
|
||||
view = {"type": "error", "message": summary}
|
||||
payload: Dict[str, Any] = {
|
||||
"ok": False,
|
||||
"summary": summary,
|
||||
"view": view,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
if detail is not None:
|
||||
payload["error"]["detail"] = detail
|
||||
return payload
|
||||
|
||||
|
||||
def _normalize_tool_result(result: Any) -> Dict[str, Any]:
|
||||
if isinstance(result, dict) and "ok" in result:
|
||||
payload = dict(result)
|
||||
if payload.get("ok") is False:
|
||||
error = payload.get("error")
|
||||
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
|
||||
payload.setdefault("summary", message or "error")
|
||||
payload.setdefault("view", {"type": "error", "message": payload["summary"]})
|
||||
return payload
|
||||
data = payload.get("data")
|
||||
if payload.get("view") is None:
|
||||
payload["view"] = _build_view(data)
|
||||
if not payload.get("summary"):
|
||||
payload["summary"] = _build_summary(payload["view"])
|
||||
return payload
|
||||
|
||||
if isinstance(result, dict) and result.get("canceled"):
|
||||
reason = _stringify_value(result.get("reason") or "canceled")
|
||||
return _build_error_payload("canceled", reason, detail=result)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
error = result.get("error")
|
||||
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
|
||||
return _build_error_payload("error", message, detail=error)
|
||||
|
||||
view = _build_view(result)
|
||||
summary = _build_summary(view)
|
||||
return {"ok": True, "summary": summary, "view": view, "data": result}
|
||||
|
||||
|
||||
def tool_result_to_content(result: Any) -> str:
|
||||
payload = _normalize_tool_result(result)
|
||||
try:
|
||||
return json.dumps(payload, ensure_ascii=False, default=str)
|
||||
except TypeError:
|
||||
return json.dumps({"ok": False, "summary": "error", "view": {"type": "error", "message": "error"}}, ensure_ascii=False)
|
||||
96
domain/agent/tools/processors.py
Normal file
96
domain/agent/tools/processors.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from domain.processors import ProcessDirectoryRequest, ProcessRequest, ProcessorService
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {"processors": ProcessorService.list_processors()}
|
||||
|
||||
|
||||
async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = str(args.get("path") or "")
|
||||
processor_type = str(args.get("processor_type") or "")
|
||||
config = args.get("config")
|
||||
if not isinstance(config, dict):
|
||||
config = {}
|
||||
|
||||
save_to = args.get("save_to")
|
||||
save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None
|
||||
|
||||
max_depth = args.get("max_depth")
|
||||
max_depth_value: Optional[int] = None
|
||||
if max_depth is not None:
|
||||
try:
|
||||
max_depth_value = int(max_depth)
|
||||
except (TypeError, ValueError):
|
||||
max_depth_value = None
|
||||
|
||||
suffix = args.get("suffix")
|
||||
suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None
|
||||
|
||||
overwrite_value = args.get("overwrite")
|
||||
overwrite = bool(overwrite_value) if overwrite_value is not None else None
|
||||
|
||||
is_dir = await VirtualFSService.path_is_directory(path)
|
||||
if is_dir and (max_depth_value is not None or suffix_value is not None):
|
||||
req = ProcessDirectoryRequest(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config,
|
||||
overwrite=True if overwrite is None else overwrite,
|
||||
max_depth=max_depth_value,
|
||||
suffix=suffix_value,
|
||||
)
|
||||
result = await ProcessorService.process_directory(req)
|
||||
return {"mode": "directory", **result}
|
||||
|
||||
req = ProcessRequest(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config,
|
||||
save_to=save_to,
|
||||
overwrite=False if overwrite is None else overwrite,
|
||||
)
|
||||
result = await ProcessorService.process_file(req)
|
||||
return {"mode": "file", **result}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"processors_list": ToolSpec(
|
||||
name="processors_list",
|
||||
description="获取可用处理器列表(type/name/config_schema 等)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_processors_list,
|
||||
),
|
||||
"processors_run": ToolSpec(
|
||||
name="processors_run",
|
||||
description=(
|
||||
"运行处理器处理文件或目录。"
|
||||
" 对目录可选 max_depth/suffix;对文件可选 overwrite/save_to。"
|
||||
" 返回任务 id(去任务队列查看进度)。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件或目录路径(绝对路径,如 /foo/bar)"},
|
||||
"processor_type": {"type": "string", "description": "处理器类型(例如 image_watermark)"},
|
||||
"config": {"type": "object", "description": "处理器配置,按 processors_list 返回的 config_schema 填写"},
|
||||
"overwrite": {"type": "boolean", "description": "是否覆盖原文件/目录内文件"},
|
||||
"save_to": {"type": "string", "description": "保存到指定路径(仅文件模式,且 overwrite=false 时使用)"},
|
||||
"max_depth": {"type": "integer", "description": "目录遍历深度(仅目录模式)"},
|
||||
"suffix": {"type": "string", "description": "目录批处理时的输出后缀(仅 produces_file 且 overwrite=false)"},
|
||||
},
|
||||
"required": ["path", "processor_type"],
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_processors_run,
|
||||
),
|
||||
}
|
||||
92
domain/agent/tools/time.py
Normal file
92
domain/agent/tools/time.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import calendar
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
def _parse_offset(args: Dict[str, Any], key: str) -> int:
|
||||
value = args.get(key)
|
||||
if value is None:
|
||||
return 0
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _add_months(dt: datetime, months: int) -> datetime:
|
||||
if months == 0:
|
||||
return dt
|
||||
total = dt.year * 12 + (dt.month - 1) + months
|
||||
year = total // 12
|
||||
month = total % 12 + 1
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
async def _time(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
now = datetime.now()
|
||||
year_offset = _parse_offset(args, "year")
|
||||
month_offset = _parse_offset(args, "month")
|
||||
day_offset = _parse_offset(args, "day")
|
||||
hour_offset = _parse_offset(args, "hour")
|
||||
minute_offset = _parse_offset(args, "minute")
|
||||
second_offset = _parse_offset(args, "second")
|
||||
|
||||
dt = _add_months(now, year_offset * 12 + month_offset)
|
||||
dt = dt + timedelta(days=day_offset, hours=hour_offset, minutes=minute_offset, seconds=second_offset)
|
||||
|
||||
weekday_names = [
|
||||
"Monday",
|
||||
"Tuesday",
|
||||
"Wednesday",
|
||||
"Thursday",
|
||||
"Friday",
|
||||
"Saturday",
|
||||
"Sunday",
|
||||
]
|
||||
weekday = weekday_names[dt.weekday()]
|
||||
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"{dt_str} · {weekday}",
|
||||
"data": {
|
||||
"datetime": dt_str,
|
||||
"weekday": weekday,
|
||||
"offset": {
|
||||
"year": year_offset,
|
||||
"month": month_offset,
|
||||
"day": day_offset,
|
||||
"hour": hour_offset,
|
||||
"minute": minute_offset,
|
||||
"second": second_offset,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"time": ToolSpec(
|
||||
name="time",
|
||||
description=(
|
||||
"获取服务器当前时间(精确到秒,含英文星期)。"
|
||||
" 支持 year/month/day/hour/minute/second 偏移(可为负数)。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"year": {"type": "integer", "description": "年偏移(可为负数)"},
|
||||
"month": {"type": "integer", "description": "月偏移(可为负数)"},
|
||||
"day": {"type": "integer", "description": "日偏移(可为负数)"},
|
||||
"hour": {"type": "integer", "description": "时偏移(可为负数)"},
|
||||
"minute": {"type": "integer", "description": "分偏移(可为负数)"},
|
||||
"second": {"type": "integer", "description": "秒偏移(可为负数)"},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_time,
|
||||
),
|
||||
}
|
||||
287
domain/agent/tools/vfs.py
Normal file
287
domain/agent/tools/vfs.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from domain.virtual_fs.search import VirtualFSSearchService
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
def _normalize_vfs_path(value: Any) -> str:
|
||||
s = str(value or "").strip().replace("\\", "/")
|
||||
if not s:
|
||||
return ""
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _require_vfs_path(value: Any, field: str) -> str:
|
||||
path = _normalize_vfs_path(value)
|
||||
if not path:
|
||||
raise ValueError(f"missing_{field}")
|
||||
return path
|
||||
|
||||
|
||||
async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _normalize_vfs_path(args.get("path") or "/") or "/"
|
||||
page = int(args.get("page") or 1)
|
||||
page_size = int(args.get("page_size") or 50)
|
||||
sort_by = str(args.get("sort_by") or "name")
|
||||
sort_order = str(args.get("sort_order") or "asc")
|
||||
return await VirtualFSService.list_directory(path, page, page_size, sort_by, sort_order)
|
||||
|
||||
|
||||
async def _vfs_stat(args: Dict[str, Any]) -> Any:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.stat(path)
|
||||
|
||||
|
||||
async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
encoding = str(args.get("encoding") or "utf-8")
|
||||
max_chars = int(args.get("max_chars") or 8000)
|
||||
|
||||
data = await VirtualFSService.read_file(path)
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
try:
|
||||
text = bytes(data).decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
return {"error": "binary_or_invalid_text", "path": path}
|
||||
elif isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
text = str(data)
|
||||
|
||||
original_len = len(text)
|
||||
truncated = original_len > max_chars
|
||||
if truncated:
|
||||
text = text[:max_chars]
|
||||
return {
|
||||
"path": path,
|
||||
"encoding": encoding,
|
||||
"content": text,
|
||||
"truncated": truncated,
|
||||
"length": original_len,
|
||||
}
|
||||
|
||||
|
||||
async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
if path == "/":
|
||||
raise ValueError("invalid_path")
|
||||
encoding = str(args.get("encoding") or "utf-8")
|
||||
content = str(args.get("content") or "")
|
||||
data = content.encode(encoding)
|
||||
await VirtualFSService.write_file(path, data)
|
||||
return {"written": True, "path": path, "encoding": encoding, "bytes": len(data)}
|
||||
|
||||
|
||||
async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.mkdir(path)
|
||||
|
||||
|
||||
async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.delete(path)
|
||||
|
||||
|
||||
async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.move(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.copy(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.rename(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
q = str(args.get("q") or "").strip()
|
||||
if not q:
|
||||
raise ValueError("missing_q")
|
||||
mode = str(args.get("mode") or "vector")
|
||||
top_k = int(args.get("top_k") or 10)
|
||||
page = int(args.get("page") or 1)
|
||||
page_size = int(args.get("page_size") or 10)
|
||||
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"vfs_list_dir": ToolSpec(
|
||||
name="vfs_list_dir",
|
||||
description="浏览目录(列出 entries + pagination)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
|
||||
"page": {"type": "integer", "description": "页码(从 1 开始)"},
|
||||
"page_size": {"type": "integer", "description": "每页条数"},
|
||||
"sort_by": {"type": "string", "description": "排序字段:name/size/mtime"},
|
||||
"sort_order": {"type": "string", "description": "排序顺序:asc/desc"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_list_dir,
|
||||
),
|
||||
"vfs_stat": ToolSpec(
|
||||
name="vfs_stat",
|
||||
description="查看文件/目录信息(size/mtime/is_dir/has_thumbnail/vector_index 等)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar.txt)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_stat,
|
||||
),
|
||||
"vfs_read_text": ToolSpec(
|
||||
name="vfs_read_text",
|
||||
description="读取文本文件内容(解码失败视为二进制,返回 error)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
|
||||
"encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
|
||||
"max_chars": {"type": "integer", "description": "最多返回的字符数(默认 8000)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_read_text,
|
||||
),
|
||||
"vfs_write_text": ToolSpec(
|
||||
name="vfs_write_text",
|
||||
description="写入文本文件内容(会覆盖目标文件)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
|
||||
"content": {"type": "string", "description": "要写入的文本内容"},
|
||||
"encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_write_text,
|
||||
),
|
||||
"vfs_mkdir": ToolSpec(
|
||||
name="vfs_mkdir",
|
||||
description="创建目录。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_mkdir,
|
||||
),
|
||||
"vfs_delete": ToolSpec(
|
||||
name="vfs_delete",
|
||||
description="删除文件或目录(由底层适配器决定是否递归)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar 或 /foo/bar.txt)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_delete,
|
||||
),
|
||||
"vfs_move": ToolSpec(
|
||||
name="vfs_move",
|
||||
description="移动路径(可能进入任务队列)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_move,
|
||||
),
|
||||
"vfs_copy": ToolSpec(
|
||||
name="vfs_copy",
|
||||
description="复制路径(可能进入任务队列)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_copy,
|
||||
),
|
||||
"vfs_rename": ToolSpec(
|
||||
name="vfs_rename",
|
||||
description="重命名路径(本质是同目录 move)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_rename,
|
||||
),
|
||||
"vfs_search": ToolSpec(
|
||||
name="vfs_search",
|
||||
description="搜索文件(mode=vector 或 filename)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"q": {"type": "string", "description": "搜索关键词"},
|
||||
"mode": {"type": "string", "description": "搜索模式:vector/filename(默认 vector)"},
|
||||
"top_k": {"type": "integer", "description": "返回数量(vector 模式使用,默认 10)"},
|
||||
"page": {"type": "integer", "description": "页码(filename 模式使用,默认 1)"},
|
||||
"page_size": {"type": "integer", "description": "分页大小(filename 模式使用,默认 10)"},
|
||||
},
|
||||
"required": ["q"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_search,
|
||||
),
|
||||
}
|
||||
182
domain/agent/tools/web_fetch.py
Normal file
182
domain/agent/tools/web_fetch.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
class _HtmlTextExtractor(HTMLParser):
|
||||
def __init__(self, base_url: str):
|
||||
super().__init__()
|
||||
self.base_url = base_url
|
||||
self.links: List[str] = []
|
||||
self._link_set: set[str] = set()
|
||||
self._title_parts: List[str] = []
|
||||
self._text_parts: List[str] = []
|
||||
self._in_title = False
|
||||
self._skip_text = False
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: List[tuple[str, str | None]]):
|
||||
tag = tag.lower()
|
||||
if tag == "title":
|
||||
self._in_title = True
|
||||
if tag in ("script", "style", "noscript"):
|
||||
self._skip_text = True
|
||||
if tag != "a":
|
||||
return
|
||||
href = ""
|
||||
for key, value in attrs:
|
||||
if key.lower() == "href":
|
||||
href = str(value or "").strip()
|
||||
break
|
||||
if not href or href.startswith("#"):
|
||||
return
|
||||
lower = href.lower()
|
||||
if lower.startswith(("javascript:", "mailto:", "tel:", "data:")):
|
||||
return
|
||||
resolved = urljoin(self.base_url, href)
|
||||
if resolved in self._link_set:
|
||||
return
|
||||
self._link_set.add(resolved)
|
||||
self.links.append(resolved)
|
||||
|
||||
def handle_endtag(self, tag: str):
|
||||
tag = tag.lower()
|
||||
if tag == "title":
|
||||
self._in_title = False
|
||||
if tag in ("script", "style", "noscript"):
|
||||
self._skip_text = False
|
||||
|
||||
def handle_data(self, data: str):
|
||||
if not data:
|
||||
return
|
||||
if self._in_title:
|
||||
self._title_parts.append(data)
|
||||
if self._skip_text:
|
||||
return
|
||||
if data.strip():
|
||||
self._text_parts.append(data)
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return " ".join(part.strip() for part in self._title_parts if part and part.strip()).strip()
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
if not self._text_parts:
|
||||
return ""
|
||||
text = " ".join(part.strip() for part in self._text_parts if part and part.strip())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
async def _web_fetch(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
url = str(args.get("url") or "").strip()
|
||||
if not url:
|
||||
raise ValueError("missing_url")
|
||||
|
||||
method = str(args.get("method") or "GET").upper()
|
||||
allowed_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
|
||||
if method not in allowed_methods:
|
||||
raise ValueError("invalid_method")
|
||||
|
||||
headers_raw = args.get("headers")
|
||||
headers = {str(k): str(v) for k, v in headers_raw.items() if v is not None} if isinstance(headers_raw, dict) else None
|
||||
params_raw = args.get("params")
|
||||
params = {str(k): str(v) for k, v in params_raw.items() if v is not None} if isinstance(params_raw, dict) else None
|
||||
json_body = args.get("json") if "json" in args else None
|
||||
body = args.get("body")
|
||||
|
||||
request_kwargs: Dict[str, Any] = {}
|
||||
if headers:
|
||||
request_kwargs["headers"] = headers
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if json_body is not None:
|
||||
request_kwargs["json"] = json_body
|
||||
elif body is not None:
|
||||
request_kwargs["content"] = str(body)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, **request_kwargs)
|
||||
|
||||
content_type = resp.headers.get("content-type") or ""
|
||||
text = resp.text or ""
|
||||
is_html = "html" in content_type.lower()
|
||||
if not is_html:
|
||||
probe = text.lstrip()[:200].lower()
|
||||
if "<html" in probe or "<!doctype html" in probe:
|
||||
is_html = True
|
||||
|
||||
html = ""
|
||||
title = ""
|
||||
links: List[str] = []
|
||||
extracted_text = text
|
||||
|
||||
if is_html and text:
|
||||
html = text
|
||||
parser = _HtmlTextExtractor(str(resp.url))
|
||||
parser.feed(text)
|
||||
title = parser.title
|
||||
links = parser.links
|
||||
extracted_text = parser.text
|
||||
|
||||
data = {
|
||||
"url": url,
|
||||
"method": method,
|
||||
"final_url": str(resp.url),
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"title": title,
|
||||
"html": html,
|
||||
"text": extracted_text,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
summary_parts = [method, str(resp.status_code)]
|
||||
if title:
|
||||
summary_parts.append(title)
|
||||
summary_parts.append(f"{len(links)} links")
|
||||
summary = " · ".join(summary_parts)
|
||||
|
||||
view = {
|
||||
"type": "text",
|
||||
"text": extracted_text,
|
||||
"meta": {
|
||||
"url": url,
|
||||
"final_url": str(resp.url),
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"title": title,
|
||||
"method": method,
|
||||
"links": len(links),
|
||||
},
|
||||
}
|
||||
return {"ok": True, "summary": summary, "view": view, "data": data}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"web_fetch": ToolSpec(
|
||||
name="web_fetch",
|
||||
description=(
|
||||
"抓取网页内容,返回状态、标题、正文、HTML、链接等信息。"
|
||||
" 支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "目标 URL"},
|
||||
"method": {"type": "string", "description": "请求方法(默认 GET)"},
|
||||
"headers": {"type": "object", "description": "请求头", "additionalProperties": {"type": "string"}},
|
||||
"params": {"type": "object", "description": "查询参数", "additionalProperties": {"type": "string"}},
|
||||
"json": {"type": "object", "description": "JSON 请求体"},
|
||||
"body": {"type": "string", "description": "原始请求体"},
|
||||
},
|
||||
"required": ["url"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_web_fetch,
|
||||
),
|
||||
}
|
||||
23
domain/agent/types.py
Normal file
23
domain/agent/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentChatContext(BaseModel):
|
||||
current_path: Optional[str] = None
|
||||
|
||||
|
||||
class AgentChatRequest(BaseModel):
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
auto_execute: bool = False
|
||||
approved_tool_call_ids: List[str] = Field(default_factory=list)
|
||||
rejected_tool_call_ids: List[str] = Field(default_factory=list)
|
||||
context: Optional[AgentChatContext] = None
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict)
|
||||
requires_confirmation: bool = True
|
||||
|
||||
@@ -1,28 +1,61 @@
|
||||
from .api import router_ai, router_vector_db
|
||||
from .inference import (
|
||||
MissingModelError,
|
||||
chat_completion,
|
||||
chat_completion_stream,
|
||||
describe_image_base64,
|
||||
get_text_embedding,
|
||||
provider_service,
|
||||
rerank_texts,
|
||||
)
|
||||
from .service import (
|
||||
AIProviderService,
|
||||
FILE_COLLECTION_NAME,
|
||||
VECTOR_COLLECTION_NAME,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
VectorDBConfigManager,
|
||||
VectorDBService,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
ABILITIES,
|
||||
normalize_capabilities,
|
||||
)
|
||||
from .types import (
|
||||
ABILITIES,
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
normalize_capabilities,
|
||||
)
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
MilvusLiteProvider,
|
||||
MilvusServerProvider,
|
||||
QdrantProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"router_ai",
|
||||
"router_vector_db",
|
||||
"MissingModelError",
|
||||
"chat_completion",
|
||||
"chat_completion_stream",
|
||||
"describe_image_base64",
|
||||
"get_text_embedding",
|
||||
"provider_service",
|
||||
"rerank_texts",
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"VECTOR_COLLECTION_NAME",
|
||||
"FILE_COLLECTION_NAME",
|
||||
"BaseVectorProvider",
|
||||
"MilvusLiteProvider",
|
||||
"MilvusServerProvider",
|
||||
"QdrantProvider",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"ABILITIES",
|
||||
"normalize_capabilities",
|
||||
"AIDefaultsUpdate",
|
||||
|
||||
@@ -5,8 +5,9 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
|
||||
from domain.ai.types import (
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import AIProviderService, VectorDBConfigManager, VectorDBService
|
||||
from .types import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
@@ -14,9 +15,7 @@ from domain.ai.types import (
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
)
|
||||
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from .vector_providers import get_provider_class, get_provider_entry, list_providers
|
||||
|
||||
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
|
||||
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
||||
@@ -251,7 +250,7 @@ async def get_vector_db_stats(request: Request, user: User = Depends(get_current
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
|
||||
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
|
||||
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
|
||||
async def list_vector_providers(request: Request):
|
||||
return success(list_providers())
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from domain.config import ConfigService
|
||||
from models.database import AIDefaultModel, AIModel, AIProvider
|
||||
|
||||
from .types import ABILITIES, normalize_capabilities
|
||||
@@ -140,7 +140,7 @@ def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
|
||||
"provider_type": provider.provider_type,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"api_key": provider.api_key,
|
||||
"has_api_key": bool(provider.api_key),
|
||||
"logo_url": provider.logo_url,
|
||||
"extra_config": provider.extra_config or {},
|
||||
"created_at": provider.created_at,
|
||||
|
||||
@@ -30,8 +30,8 @@ class AIProviderBase(BaseModel):
|
||||
@classmethod
|
||||
def normalize_format(cls, value: str) -> str:
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
raise ValueError("api_format must be 'openai' or 'gemini'")
|
||||
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
|
||||
@@ -54,8 +54,8 @@ class AIProviderUpdate(BaseModel):
|
||||
if value is None:
|
||||
return value
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
raise ValueError("api_format must be 'openai' or 'gemini'")
|
||||
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from domain.audit.decorator import audit
|
||||
from domain.audit.types import AuditAction
|
||||
from domain.audit.api import router
|
||||
from .decorator import audit
|
||||
from .types import AuditAction
|
||||
|
||||
__all__ = ["audit", "AuditAction", "router"]
|
||||
__all__ = ["audit", "AuditAction"]
|
||||
|
||||
@@ -4,10 +4,11 @@ from typing import Annotated, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from api import response
|
||||
from domain.audit.service import AuditService
|
||||
from domain.audit.types import AuditAction
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import AuditService
|
||||
from .types import AuditAction
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
@@ -28,6 +29,7 @@ def _parse_iso(value: Optional[str], field: str):
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
@require_system_permission(SystemPermission.AUDIT_VIEW)
|
||||
async def list_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
page_num: int = Query(1, ge=1, alias="page", description="页码"),
|
||||
@@ -55,6 +57,7 @@ async def list_audit_logs(
|
||||
|
||||
|
||||
@router.delete("/logs")
|
||||
@require_system_permission(SystemPermission.AUDIT_VIEW)
|
||||
async def clear_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
|
||||
@@ -7,11 +7,11 @@ import jwt
|
||||
from fastapi import Request
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.audit.service import AuditService
|
||||
from domain.audit.types import AuditAction
|
||||
from domain.auth.service import ALGORITHM
|
||||
from domain.config.service import ConfigService
|
||||
from domain.auth import ALGORITHM
|
||||
from domain.config import ConfigService
|
||||
from models.database import UserAccount
|
||||
from .service import AuditService
|
||||
from .types import AuditAction
|
||||
|
||||
|
||||
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from models.database import AuditLog
|
||||
|
||||
from domain.audit.types import AuditAction
|
||||
from .types import AuditAction
|
||||
|
||||
|
||||
class AuditService:
|
||||
|
||||
49
domain/auth/__init__.py
Normal file
49
domain/auth/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from .service import (
|
||||
ALGORITHM,
|
||||
AuthService,
|
||||
authenticate_user_db,
|
||||
create_access_token,
|
||||
get_current_active_user,
|
||||
get_current_user,
|
||||
get_password_hash,
|
||||
has_users,
|
||||
register_user,
|
||||
request_password_reset,
|
||||
reset_password_with_token,
|
||||
verify_password,
|
||||
verify_password_reset_token,
|
||||
)
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
TokenData,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ALGORITHM",
|
||||
"AuthService",
|
||||
"authenticate_user_db",
|
||||
"create_access_token",
|
||||
"get_current_active_user",
|
||||
"get_current_user",
|
||||
"get_password_hash",
|
||||
"has_users",
|
||||
"register_user",
|
||||
"request_password_reset",
|
||||
"reset_password_with_token",
|
||||
"verify_password",
|
||||
"verify_password_reset_token",
|
||||
"PasswordResetConfirm",
|
||||
"PasswordResetRequest",
|
||||
"RegisterRequest",
|
||||
"Token",
|
||||
"TokenData",
|
||||
"UpdateMeRequest",
|
||||
"User",
|
||||
"UserInDB",
|
||||
]
|
||||
@@ -5,8 +5,8 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import AuthService, get_current_active_user
|
||||
from domain.auth.types import (
|
||||
from .service import AuthService, get_current_active_user
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
@@ -18,16 +18,16 @@ from domain.auth.types import (
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", summary="注册第一个管理员用户")
|
||||
@router.post("/register", summary="注册用户(首个用户为管理员)")
|
||||
@audit(
|
||||
action=AuditAction.REGISTER,
|
||||
description="注册管理员",
|
||||
description="注册用户",
|
||||
body_fields=["username", "email", "full_name"],
|
||||
redact_fields=["password"],
|
||||
)
|
||||
async def register(request: Request, data: RegisterRequest):
|
||||
user = await AuthService.register_user(data)
|
||||
return success({"username": user.username}, msg="初始用户注册成功")
|
||||
return success({"username": user.username}, msg="注册成功")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
|
||||
@@ -11,7 +11,9 @@ from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.auth.types import (
|
||||
from domain.config import ConfigService
|
||||
from models.database import Role, UserAccount, UserRole
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
@@ -21,8 +23,6 @@ from domain.auth.types import (
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
from models.database import UserAccount
|
||||
from domain.config.service import ConfigService
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
|
||||
@@ -140,6 +140,7 @@ class AuthService:
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
disabled=user.disabled,
|
||||
is_admin=user.is_admin,
|
||||
hashed_password=user.hashed_password,
|
||||
)
|
||||
return None
|
||||
@@ -160,19 +161,60 @@ class AuthService:
|
||||
|
||||
@classmethod
|
||||
async def register_user(cls, payload: RegisterRequest):
|
||||
if await cls.has_users():
|
||||
raise HTTPException(status_code=403, detail="系统已初始化,不允许注册新用户")
|
||||
has_users = await cls.has_users()
|
||||
normalized_email = cls._normalize_email(payload.email)
|
||||
if not normalized_email:
|
||||
raise HTTPException(status_code=400, detail="邮箱不能为空")
|
||||
|
||||
if has_users:
|
||||
allow_register = str(await ConfigService.get("AUTH_ALLOW_REGISTER", "false") or "").strip().lower()
|
||||
if allow_register not in ("1", "true", "yes", "on"):
|
||||
raise HTTPException(status_code=403, detail="系统未开放注册")
|
||||
|
||||
default_role_id_raw = str(await ConfigService.get("AUTH_DEFAULT_REGISTER_ROLE_ID", "") or "").strip()
|
||||
if not default_role_id_raw:
|
||||
raise HTTPException(status_code=400, detail="未配置默认注册角色")
|
||||
try:
|
||||
default_role_id = int(default_role_id_raw)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="默认注册角色配置错误") from exc
|
||||
|
||||
role = await Role.get_or_none(id=default_role_id)
|
||||
if not role:
|
||||
raise HTTPException(status_code=400, detail="默认注册角色不存在")
|
||||
|
||||
exists = await UserAccount.get_or_none(username=payload.username)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
|
||||
existing_email = await UserAccount.get_or_none(email=normalized_email)
|
||||
if existing_email:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被使用")
|
||||
|
||||
hashed = cls.get_password_hash(payload.password)
|
||||
|
||||
# 第一个用户自动成为超级管理员(不受开放注册开关影响)
|
||||
if not has_users:
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=normalized_email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
is_admin=True,
|
||||
)
|
||||
return user
|
||||
|
||||
# 系统已初始化:按默认角色创建普通用户
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
email=normalized_email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
is_admin=False,
|
||||
)
|
||||
await UserRole.create(user_id=user.id, role_id=default_role_id)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
@@ -195,6 +237,13 @@ class AuthService:
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 更新最后登录时间
|
||||
db_user = await UserAccount.get_or_none(id=user.id)
|
||||
if db_user:
|
||||
db_user.last_login = _now()
|
||||
await db_user.save(update_fields=["last_login"])
|
||||
|
||||
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
|
||||
access_token = await cls.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
@@ -212,6 +261,7 @@ class AuthService:
|
||||
"email": getattr(user, "email", None),
|
||||
"full_name": getattr(user, "full_name", None),
|
||||
"gravatar_url": gravatar_url,
|
||||
"is_admin": getattr(user, "is_admin", False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -324,7 +374,7 @@ class AuthService:
|
||||
|
||||
@classmethod
|
||||
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
|
||||
from domain.email.service import EmailService
|
||||
from domain.email import EmailService
|
||||
|
||||
app_domain = await ConfigService.get("APP_DOMAIN", None)
|
||||
base_url = (app_domain or "http://localhost:5173").rstrip("/")
|
||||
|
||||
@@ -16,6 +16,7 @@ class User(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool | None = None
|
||||
is_admin: bool = False
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
@@ -25,7 +26,7 @@ class UserInDB(User):
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str | None = None
|
||||
email: str
|
||||
full_name: str | None = None
|
||||
|
||||
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
from .service import BackupService
|
||||
from .types import BackupData
|
||||
|
||||
__all__ = [
|
||||
"BackupService",
|
||||
"BackupData",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Request, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.backup.service import BackupService
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import BackupService
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/backup",
|
||||
@@ -16,8 +19,13 @@ router = APIRouter(
|
||||
|
||||
@router.get("/export", summary="导出全站数据")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
|
||||
async def export_backup(request: Request):
|
||||
data = await BackupService.export_data()
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def export_backup(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
sections: list[str] | None = Query(default=None),
|
||||
):
|
||||
data = await BackupService.export_data(sections=sections)
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
|
||||
return JSONResponse(content=data.model_dump(), headers=headers)
|
||||
@@ -25,6 +33,12 @@ async def export_backup(request: Request):
|
||||
|
||||
@router.post("/import", summary="导入数据")
|
||||
@audit(action=AuditAction.UPLOAD, description="导入备份")
|
||||
async def import_backup(request: Request, file: UploadFile = File(...)):
|
||||
await BackupService.import_from_bytes(file.filename, await file.read())
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def import_backup(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
file: UploadFile = File(...),
|
||||
mode: str = Form("replace"),
|
||||
):
|
||||
await BackupService.import_from_bytes(file.filename, await file.read(), mode=mode)
|
||||
return {"message": "数据导入成功。"}
|
||||
|
||||
@@ -4,8 +4,8 @@ from datetime import datetime
|
||||
from fastapi import HTTPException
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.backup.types import BackupData
|
||||
from domain.config.service import VERSION
|
||||
from domain.config import VERSION
|
||||
from .types import BackupData
|
||||
from models.database import (
|
||||
AIDefaultModel,
|
||||
AIModel,
|
||||
@@ -20,18 +20,64 @@ from models.database import (
|
||||
|
||||
|
||||
class BackupService:
|
||||
ALL_SECTIONS = (
|
||||
"storage_adapters",
|
||||
"user_accounts",
|
||||
"automation_tasks",
|
||||
"share_links",
|
||||
"configurations",
|
||||
"ai_providers",
|
||||
"ai_models",
|
||||
"ai_default_models",
|
||||
"plugins",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def export_data(cls) -> BackupData:
|
||||
async def export_data(cls, sections: list[str] | None = None) -> BackupData:
|
||||
sections = cls._normalize_sections(sections)
|
||||
section_set = set(sections)
|
||||
async with in_transaction():
|
||||
adapters = await StorageAdapter.all().values()
|
||||
users = await UserAccount.all().values()
|
||||
tasks = await AutomationTask.all().values()
|
||||
shares = await ShareLink.all().values()
|
||||
configs = await Configuration.all().values()
|
||||
providers = await AIProvider.all().values()
|
||||
models = await AIModel.all().values()
|
||||
default_models = await AIDefaultModel.all().values()
|
||||
plugins = await Plugin.all().values()
|
||||
adapters = (
|
||||
await StorageAdapter.all().values()
|
||||
if "storage_adapters" in section_set
|
||||
else []
|
||||
)
|
||||
users = (
|
||||
await UserAccount.all().values()
|
||||
if "user_accounts" in section_set
|
||||
else []
|
||||
)
|
||||
tasks = (
|
||||
await AutomationTask.all().values()
|
||||
if "automation_tasks" in section_set
|
||||
else []
|
||||
)
|
||||
shares = (
|
||||
await ShareLink.all().values()
|
||||
if "share_links" in section_set
|
||||
else []
|
||||
)
|
||||
configs = (
|
||||
await Configuration.all().values()
|
||||
if "configurations" in section_set
|
||||
else []
|
||||
)
|
||||
providers = (
|
||||
await AIProvider.all().values()
|
||||
if "ai_providers" in section_set
|
||||
else []
|
||||
)
|
||||
models = (
|
||||
await AIModel.all().values() if "ai_models" in section_set else []
|
||||
)
|
||||
default_models = (
|
||||
await AIDefaultModel.all().values()
|
||||
if "ai_default_models" in section_set
|
||||
else []
|
||||
)
|
||||
plugins = (
|
||||
await Plugin.all().values() if "plugins" in section_set else []
|
||||
)
|
||||
|
||||
share_links = cls._serialize_datetime_fields(
|
||||
shares, ["created_at", "expires_at"]
|
||||
@@ -51,6 +97,7 @@ class BackupService:
|
||||
|
||||
return BackupData(
|
||||
version=VERSION,
|
||||
sections=sections,
|
||||
storage_adapters=list(adapters),
|
||||
user_accounts=list(users),
|
||||
automation_tasks=list(tasks),
|
||||
@@ -63,106 +110,195 @@ class BackupService:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def import_from_bytes(cls, filename: str, content: bytes) -> None:
|
||||
async def import_from_bytes(
|
||||
cls, filename: str, content: bytes, mode: str = "replace"
|
||||
) -> None:
|
||||
if not filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
|
||||
try:
|
||||
raw_data = json.loads(content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="无法解析JSON文件")
|
||||
await cls.import_data(BackupData(**raw_data))
|
||||
await cls.import_data(BackupData(**raw_data), mode=mode)
|
||||
|
||||
@classmethod
|
||||
async def import_data(cls, payload: BackupData) -> None:
|
||||
async def import_data(cls, payload: BackupData, mode: str = "replace") -> None:
|
||||
sections = cls._normalize_sections(payload.sections)
|
||||
if mode not in {"replace", "merge"}:
|
||||
raise HTTPException(status_code=400, detail="无效的导入模式")
|
||||
|
||||
share_links = (
|
||||
cls._parse_datetime_fields(payload.share_links, ["created_at", "expires_at"])
|
||||
if payload.share_links
|
||||
else []
|
||||
)
|
||||
ai_providers = (
|
||||
cls._parse_datetime_fields(payload.ai_providers, ["created_at", "updated_at"])
|
||||
if payload.ai_providers
|
||||
else []
|
||||
)
|
||||
ai_models = (
|
||||
cls._parse_datetime_fields(payload.ai_models, ["created_at", "updated_at"])
|
||||
if payload.ai_models
|
||||
else []
|
||||
)
|
||||
ai_default_models = (
|
||||
cls._parse_datetime_fields(
|
||||
payload.ai_default_models, ["created_at", "updated_at"]
|
||||
)
|
||||
if payload.ai_default_models
|
||||
else []
|
||||
)
|
||||
plugins = (
|
||||
cls._parse_datetime_fields(payload.plugins, ["created_at", "updated_at"])
|
||||
if payload.plugins
|
||||
else []
|
||||
)
|
||||
|
||||
async with in_transaction() as conn:
|
||||
await ShareLink.all().using_db(conn).delete()
|
||||
await AutomationTask.all().using_db(conn).delete()
|
||||
await StorageAdapter.all().using_db(conn).delete()
|
||||
await UserAccount.all().using_db(conn).delete()
|
||||
await Configuration.all().using_db(conn).delete()
|
||||
await AIDefaultModel.all().using_db(conn).delete()
|
||||
await AIModel.all().using_db(conn).delete()
|
||||
await AIProvider.all().using_db(conn).delete()
|
||||
await Plugin.all().using_db(conn).delete()
|
||||
if mode == "replace":
|
||||
if "share_links" in sections:
|
||||
await ShareLink.all().using_db(conn).delete()
|
||||
if "automation_tasks" in sections:
|
||||
await AutomationTask.all().using_db(conn).delete()
|
||||
if "storage_adapters" in sections:
|
||||
await StorageAdapter.all().using_db(conn).delete()
|
||||
if "user_accounts" in sections:
|
||||
await UserAccount.all().using_db(conn).delete()
|
||||
if "configurations" in sections:
|
||||
await Configuration.all().using_db(conn).delete()
|
||||
if "ai_default_models" in sections:
|
||||
await AIDefaultModel.all().using_db(conn).delete()
|
||||
if "ai_models" in sections:
|
||||
await AIModel.all().using_db(conn).delete()
|
||||
if "ai_providers" in sections:
|
||||
await AIProvider.all().using_db(conn).delete()
|
||||
if "plugins" in sections:
|
||||
await Plugin.all().using_db(conn).delete()
|
||||
|
||||
if payload.configurations:
|
||||
await Configuration.bulk_create(
|
||||
[Configuration(**config) for config in payload.configurations],
|
||||
using_db=conn,
|
||||
)
|
||||
if "configurations" in sections and payload.configurations:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
Configuration, payload.configurations, conn
|
||||
)
|
||||
else:
|
||||
await Configuration.bulk_create(
|
||||
[Configuration(**config) for config in payload.configurations],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.user_accounts:
|
||||
await UserAccount.bulk_create(
|
||||
[UserAccount(**user) for user in payload.user_accounts],
|
||||
using_db=conn,
|
||||
)
|
||||
if "user_accounts" in sections and payload.user_accounts:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(UserAccount, payload.user_accounts, conn)
|
||||
else:
|
||||
await UserAccount.bulk_create(
|
||||
[UserAccount(**user) for user in payload.user_accounts],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.storage_adapters:
|
||||
await StorageAdapter.bulk_create(
|
||||
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
|
||||
using_db=conn,
|
||||
)
|
||||
if "storage_adapters" in sections and payload.storage_adapters:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
StorageAdapter, payload.storage_adapters, conn
|
||||
)
|
||||
else:
|
||||
await StorageAdapter.bulk_create(
|
||||
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.automation_tasks:
|
||||
await AutomationTask.bulk_create(
|
||||
[AutomationTask(**task) for task in payload.automation_tasks],
|
||||
using_db=conn,
|
||||
)
|
||||
if "automation_tasks" in sections and payload.automation_tasks:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
AutomationTask, payload.automation_tasks, conn
|
||||
)
|
||||
else:
|
||||
await AutomationTask.bulk_create(
|
||||
[AutomationTask(**task) for task in payload.automation_tasks],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.share_links:
|
||||
await ShareLink.bulk_create(
|
||||
[
|
||||
ShareLink(**share)
|
||||
for share in cls._parse_datetime_fields(
|
||||
payload.share_links, ["created_at", "expires_at"]
|
||||
)
|
||||
],
|
||||
using_db=conn,
|
||||
)
|
||||
if "share_links" in sections and share_links:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(ShareLink, share_links, conn)
|
||||
else:
|
||||
await ShareLink.bulk_create(
|
||||
[ShareLink(**share) for share in share_links],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.ai_providers:
|
||||
await AIProvider.bulk_create(
|
||||
[
|
||||
AIProvider(**item)
|
||||
for item in cls._parse_datetime_fields(
|
||||
payload.ai_providers, ["created_at", "updated_at"]
|
||||
)
|
||||
],
|
||||
using_db=conn,
|
||||
)
|
||||
if "ai_providers" in sections and ai_providers:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(AIProvider, ai_providers, conn)
|
||||
else:
|
||||
await AIProvider.bulk_create(
|
||||
[AIProvider(**item) for item in ai_providers],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.ai_models:
|
||||
await AIModel.bulk_create(
|
||||
[
|
||||
AIModel(**item)
|
||||
for item in cls._parse_datetime_fields(
|
||||
payload.ai_models, ["created_at", "updated_at"]
|
||||
)
|
||||
],
|
||||
using_db=conn,
|
||||
)
|
||||
if "ai_models" in sections and ai_models:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(AIModel, ai_models, conn)
|
||||
else:
|
||||
await AIModel.bulk_create(
|
||||
[AIModel(**item) for item in ai_models],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.ai_default_models:
|
||||
await AIDefaultModel.bulk_create(
|
||||
[
|
||||
AIDefaultModel(**item)
|
||||
for item in cls._parse_datetime_fields(
|
||||
payload.ai_default_models, ["created_at", "updated_at"]
|
||||
)
|
||||
],
|
||||
using_db=conn,
|
||||
)
|
||||
if "ai_default_models" in sections and ai_default_models:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
AIDefaultModel, ai_default_models, conn
|
||||
)
|
||||
else:
|
||||
await AIDefaultModel.bulk_create(
|
||||
[AIDefaultModel(**item) for item in ai_default_models],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if payload.plugins:
|
||||
await Plugin.bulk_create(
|
||||
[
|
||||
Plugin(**item)
|
||||
for item in cls._parse_datetime_fields(
|
||||
payload.plugins, ["created_at", "updated_at"]
|
||||
)
|
||||
],
|
||||
using_db=conn,
|
||||
)
|
||||
if "plugins" in sections and plugins:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(Plugin, plugins, conn)
|
||||
else:
|
||||
await Plugin.bulk_create(
|
||||
[Plugin(**item) for item in plugins],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_sections(cls, sections: list[str] | None) -> list[str]:
|
||||
if not sections:
|
||||
return list(cls.ALL_SECTIONS)
|
||||
normalized = [item for item in sections if item]
|
||||
invalid = [item for item in normalized if item not in cls.ALL_SECTIONS]
|
||||
if invalid:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"无效的备份分区: {', '.join(invalid)}"
|
||||
)
|
||||
result: list[str] = []
|
||||
seen = set()
|
||||
for item in normalized:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def _merge_records(model, records: list[dict], using_db) -> None:
|
||||
for record in records:
|
||||
data = dict(record)
|
||||
record_id = data.pop("id", None)
|
||||
if record_id is None:
|
||||
await model.create(using_db=using_db, **data)
|
||||
continue
|
||||
updated = (
|
||||
await model.filter(id=record_id)
|
||||
.using_db(using_db)
|
||||
.update(**data)
|
||||
)
|
||||
if updated == 0:
|
||||
await model.create(using_db=using_db, id=record_id, **data)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_datetime_fields(
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class BackupData(BaseModel):
|
||||
version: str | None = None
|
||||
sections: list[str] = Field(default_factory=list)
|
||||
storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
|
||||
user_accounts: list[dict[str, Any]] = Field(default_factory=list)
|
||||
automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
10
domain/config/__init__.py
Normal file
10
domain/config/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import ConfigService, VERSION
|
||||
from .types import ConfigItem, LatestVersionInfo, SystemStatus
|
||||
|
||||
__all__ = [
|
||||
"ConfigService",
|
||||
"VERSION",
|
||||
"ConfigItem",
|
||||
"LatestVersionInfo",
|
||||
"SystemStatus",
|
||||
]
|
||||
@@ -4,16 +4,26 @@ from fastapi import APIRouter, Depends, Form, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.config.service import ConfigService
|
||||
from domain.config.types import ConfigItem
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import ConfigService
|
||||
from .types import ConfigItem
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
PUBLIC_CONFIG_KEYS = [
|
||||
"THEME_MODE",
|
||||
"THEME_PRIMARY_COLOR",
|
||||
"THEME_BORDER_RADIUS",
|
||||
"THEME_CUSTOM_TOKENS",
|
||||
"THEME_CUSTOM_CSS",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取配置")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def get_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -25,6 +35,7 @@ async def get_config(
|
||||
|
||||
@router.post("/")
|
||||
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def set_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -37,6 +48,7 @@ async def set_config(
|
||||
|
||||
@router.get("/all")
|
||||
@audit(action=AuditAction.READ, description="获取全部配置")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def get_all_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -44,6 +56,18 @@ async def get_all_config(
|
||||
configs = await ConfigService.get_all()
|
||||
return success(configs)
|
||||
|
||||
@router.get("/public")
|
||||
@audit(action=AuditAction.READ, description="获取公开配置")
|
||||
async def get_public_config(
|
||||
request: Request,
|
||||
):
|
||||
data = {}
|
||||
for key in PUBLIC_CONFIG_KEYS:
|
||||
value = await ConfigService.get(key)
|
||||
if value is not None:
|
||||
data[key] = value
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@audit(action=AuditAction.READ, description="获取系统状态")
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Any, Dict, Optional
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from domain.config.types import LatestVersionInfo, SystemStatus
|
||||
from .types import LatestVersionInfo, SystemStatus
|
||||
from models.database import Configuration, UserAccount
|
||||
|
||||
load_dotenv(dotenv_path=".env")
|
||||
|
||||
VERSION = "v1.6.0"
|
||||
VERSION = "v1.7.4"
|
||||
|
||||
|
||||
class ConfigService:
|
||||
|
||||
20
domain/email/__init__.py
Normal file
20
domain/email/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .service import EmailService, EmailTemplateRenderer
|
||||
from .types import (
|
||||
EmailConfig,
|
||||
EmailSecurity,
|
||||
EmailSendPayload,
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmailService",
|
||||
"EmailTemplateRenderer",
|
||||
"EmailConfig",
|
||||
"EmailSecurity",
|
||||
"EmailSendPayload",
|
||||
"EmailTemplatePreviewPayload",
|
||||
"EmailTemplateUpdate",
|
||||
"EmailTestRequest",
|
||||
]
|
||||
@@ -2,10 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.email.service import EmailService, EmailTemplateRenderer
|
||||
from domain.email.types import (
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import EmailService, EmailTemplateRenderer
|
||||
from .types import (
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
|
||||
@@ -7,8 +7,8 @@ from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload
|
||||
from domain.config import ConfigService
|
||||
from .types import EmailConfig, EmailSecurity, EmailSendPayload
|
||||
|
||||
|
||||
class EmailTemplateRenderer:
|
||||
@@ -104,7 +104,7 @@ class EmailService:
|
||||
template: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
from domain.tasks.task_queue import TaskProgress, task_queue_service
|
||||
from domain.tasks import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(
|
||||
recipients=recipients,
|
||||
@@ -126,7 +126,7 @@ class EmailService:
|
||||
|
||||
@classmethod
|
||||
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
|
||||
from domain.tasks.task_queue import TaskProgress, task_queue_service
|
||||
from domain.tasks import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(**data)
|
||||
|
||||
|
||||
7
domain/offline_downloads/__init__.py
Normal file
7
domain/offline_downloads/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .service import OfflineDownloadService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
__all__ = [
|
||||
"OfflineDownloadService",
|
||||
"OfflineDownloadCreate",
|
||||
]
|
||||
@@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.offline_downloads.service import OfflineDownloadService
|
||||
from domain.offline_downloads.types import OfflineDownloadCreate
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_path_permission
|
||||
from domain.permission.types import PathAction
|
||||
from .service import OfflineDownloadService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
@@ -23,6 +24,7 @@ router = APIRouter(
|
||||
description="创建离线下载任务",
|
||||
body_fields=["url", "dest_dir", "filename"],
|
||||
)
|
||||
@require_path_permission(PathAction.WRITE, "payload.dest_dir")
|
||||
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
|
||||
data = await OfflineDownloadService.create_download(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
@@ -7,11 +7,10 @@ import aiofiles
|
||||
import aiohttp
|
||||
from fastapi import Depends, HTTPException
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.offline_downloads.types import OfflineDownloadCreate
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.tasks import Task, TaskProgress, task_queue_service
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
|
||||
class OfflineDownloadService:
|
||||
|
||||
10
domain/permission/__init__.py
Normal file
10
domain/permission/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import PermissionService
|
||||
from .matcher import PathMatcher
|
||||
from .decorator import require_path_permission, require_system_permission
|
||||
|
||||
__all__ = [
|
||||
"PermissionService",
|
||||
"PathMatcher",
|
||||
"require_system_permission",
|
||||
"require_path_permission",
|
||||
]
|
||||
41
domain/permission/api.py
Normal file
41
domain/permission/api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from .service import PermissionService
|
||||
from .types import (
|
||||
PathPermissionCheck,
|
||||
PathPermissionResult,
|
||||
UserPermissions,
|
||||
PermissionInfo,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["permissions"])
|
||||
|
||||
|
||||
@router.get("/permissions", response_model=list[PermissionInfo])
|
||||
async def get_all_permissions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> list[PermissionInfo]:
|
||||
"""获取所有权限定义"""
|
||||
return await PermissionService.get_all_permissions()
|
||||
|
||||
|
||||
@router.get("/me/permissions", response_model=UserPermissions)
|
||||
async def get_my_permissions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> UserPermissions:
|
||||
"""获取当前用户的有效权限"""
|
||||
return await PermissionService.get_user_permissions(current_user.id)
|
||||
|
||||
|
||||
@router.post("/me/check-path", response_model=PathPermissionResult)
|
||||
async def check_path_permission(
|
||||
data: PathPermissionCheck,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> PathPermissionResult:
|
||||
"""检查当前用户对某路径的权限"""
|
||||
return await PermissionService.check_path_permission_detailed(
|
||||
current_user.id, data.path, data.action
|
||||
)
|
||||
103
domain/permission/decorator.py
Normal file
103
domain/permission/decorator.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .service import PermissionService
|
||||
|
||||
|
||||
def _get_user_id(user: Any) -> int | None:
|
||||
if user is None:
|
||||
return None
|
||||
if isinstance(user, Mapping):
|
||||
raw = user.get("id") or user.get("user_id")
|
||||
return int(raw) if isinstance(raw, int) else None
|
||||
value = getattr(user, "id", None) or getattr(user, "user_id", None)
|
||||
return int(value) if isinstance(value, int) else None
|
||||
|
||||
|
||||
def _resolve_expr(bound_args: Mapping[str, Any], expr: str) -> Any:
|
||||
parts = [p for p in (expr or "").split(".") if p]
|
||||
if not parts:
|
||||
return None
|
||||
cur: Any = bound_args.get(parts[0])
|
||||
for part in parts[1:]:
|
||||
if cur is None:
|
||||
return None
|
||||
if isinstance(cur, Mapping):
|
||||
cur = cur.get(part)
|
||||
else:
|
||||
cur = getattr(cur, part, None)
|
||||
return cur
|
||||
|
||||
|
||||
def require_system_permission(permission_code: str, *, user_kw: str = "current_user"):
|
||||
"""
|
||||
在 endpoint 内部执行系统/适配器权限校验。
|
||||
|
||||
设计目标:
|
||||
- 保持和当前“在函数体内手写 require_*”一致的行为:失败会被外层 @audit 捕获记录
|
||||
- 不依赖 FastAPI dependencies(避免权限失败发生在 endpoint 之外)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
user_id = _get_user_id(bound.arguments.get(user_kw))
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
await PermissionService.require_system_permission(user_id, permission_code)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_path_permission(action: str, path_expr: str, *, user_kw: str = "current_user"):
|
||||
"""
|
||||
在 endpoint 内部执行路径权限校验。
|
||||
|
||||
path_expr 支持:
|
||||
- "full_path"
|
||||
- "body.src" / "body.dst"
|
||||
- "payload.paths"(list[str] 会逐个检查)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
user_id = _get_user_id(bound.arguments.get(user_kw))
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
value = _resolve_expr(bound.arguments, path_expr)
|
||||
paths: Iterable[Any]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
paths = value
|
||||
else:
|
||||
paths = [value]
|
||||
|
||||
for path in paths:
|
||||
if path is None:
|
||||
raise HTTPException(status_code=400, detail="Missing path")
|
||||
await PermissionService.require_path_permission(user_id, str(path), action)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
158
domain/permission/matcher.py
Normal file
158
domain/permission/matcher.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import re
|
||||
import fnmatch
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class PathMatcher:
|
||||
"""路径匹配器,支持精确匹配、通配符匹配和正则匹配"""
|
||||
|
||||
@classmethod
|
||||
def normalize_path(cls, path: str) -> str:
|
||||
"""规范化路径"""
|
||||
if not path:
|
||||
return "/"
|
||||
# 确保以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
# 移除末尾的 /(除了根路径)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def get_parent_path(cls, path: str) -> str | None:
|
||||
"""获取父目录路径"""
|
||||
path = cls.normalize_path(path)
|
||||
if path == "/":
|
||||
return None
|
||||
parent = "/".join(path.rsplit("/", 1)[:-1])
|
||||
return parent if parent else "/"
|
||||
|
||||
@classmethod
|
||||
def match_pattern(cls, path: str, pattern: str, is_regex: bool = False) -> bool:
|
||||
"""
|
||||
匹配路径和模式
|
||||
|
||||
Args:
|
||||
path: 要匹配的路径
|
||||
pattern: 匹配模式
|
||||
is_regex: 是否为正则表达式
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
path = cls.normalize_path(path)
|
||||
pattern = cls.normalize_path(pattern)
|
||||
|
||||
if is_regex:
|
||||
return cls._match_regex(path, pattern)
|
||||
else:
|
||||
return cls._match_glob(path, pattern)
|
||||
|
||||
@classmethod
|
||||
def _match_regex(cls, path: str, pattern: str) -> bool:
|
||||
"""正则表达式匹配"""
|
||||
try:
|
||||
# 限制正则表达式的复杂度,防止 ReDoS 攻击
|
||||
if len(pattern) > 500:
|
||||
return False
|
||||
regex = re.compile(pattern)
|
||||
return bool(regex.match(path))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _match_glob(cls, path: str, pattern: str) -> bool:
|
||||
"""
|
||||
通配符匹配
|
||||
|
||||
支持的语法:
|
||||
- * : 匹配单层目录中的任意字符
|
||||
- ** : 匹配任意层级目录
|
||||
- ? : 匹配单个字符
|
||||
"""
|
||||
# 精确匹配
|
||||
if pattern == path:
|
||||
return True
|
||||
|
||||
# 处理 ** 通配符
|
||||
if "**" in pattern:
|
||||
return cls._match_double_star(path, pattern)
|
||||
|
||||
# 使用 fnmatch 进行标准通配符匹配
|
||||
return fnmatch.fnmatch(path, pattern)
|
||||
|
||||
@classmethod
|
||||
def _match_double_star(cls, path: str, pattern: str) -> bool:
|
||||
"""处理 ** 通配符匹配"""
|
||||
# 将 ** 替换为特殊标记
|
||||
parts = pattern.split("**")
|
||||
|
||||
if len(parts) == 2:
|
||||
prefix, suffix = parts
|
||||
# 移除 prefix 末尾的 / 和 suffix 开头的 /
|
||||
prefix = prefix.rstrip("/") if prefix else ""
|
||||
suffix = suffix.lstrip("/") if suffix else ""
|
||||
|
||||
# 检查前缀匹配
|
||||
if prefix and not path.startswith(prefix):
|
||||
return False
|
||||
|
||||
# 如果没有后缀,只需要前缀匹配
|
||||
if not suffix:
|
||||
return True
|
||||
|
||||
# 检查后缀匹配
|
||||
remaining = path[len(prefix):].lstrip("/") if prefix else path.lstrip("/")
|
||||
|
||||
# 后缀可以出现在任意位置
|
||||
if "*" in suffix or "?" in suffix:
|
||||
# 后缀包含通配符,逐层检查
|
||||
path_parts = remaining.split("/")
|
||||
suffix_parts = suffix.split("/")
|
||||
|
||||
# 简化处理:检查路径的最后几层是否与后缀匹配
|
||||
if len(path_parts) >= len(suffix_parts):
|
||||
tail = "/".join(path_parts[-len(suffix_parts):])
|
||||
return fnmatch.fnmatch(tail, suffix)
|
||||
return False
|
||||
else:
|
||||
# 后缀是精确字符串
|
||||
return remaining.endswith(suffix) or ("/" + suffix) in remaining or remaining == suffix
|
||||
|
||||
# 多个 ** 的情况,使用简化匹配
|
||||
regex_pattern = pattern.replace("**", ".*").replace("*", "[^/]*").replace("?", ".")
|
||||
try:
|
||||
return bool(re.match(f"^{regex_pattern}$", path))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_pattern_specificity(cls, pattern: str, is_regex: bool = False) -> int:
|
||||
"""
|
||||
计算模式的具体程度(用于优先级排序)
|
||||
|
||||
返回值越大表示模式越具体
|
||||
"""
|
||||
pattern = cls.normalize_path(pattern)
|
||||
|
||||
if is_regex:
|
||||
# 正则表达式具体程度较低
|
||||
return len(pattern) // 2
|
||||
|
||||
# 精确路径最具体
|
||||
if "*" not in pattern and "?" not in pattern:
|
||||
return len(pattern) * 10
|
||||
|
||||
# 计算非通配符部分的长度
|
||||
specificity = 0
|
||||
parts = pattern.split("/")
|
||||
for part in parts:
|
||||
if part == "**":
|
||||
specificity += 1
|
||||
elif "*" in part or "?" in part:
|
||||
specificity += 5
|
||||
else:
|
||||
specificity += 10
|
||||
|
||||
return specificity
|
||||
340
domain/permission/service.py
Normal file
340
domain/permission/service.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.database import (
|
||||
UserAccount,
|
||||
UserRole,
|
||||
RolePermission,
|
||||
PathRule,
|
||||
)
|
||||
from .matcher import PathMatcher
|
||||
from .types import (
|
||||
PathAction,
|
||||
PathRuleInfo,
|
||||
PathPermissionResult,
|
||||
UserPermissions,
|
||||
PermissionInfo,
|
||||
PERMISSION_DEFINITIONS,
|
||||
)
|
||||
|
||||
|
||||
class PermissionService:
|
||||
"""权限检查服务"""
|
||||
|
||||
# 权限检查结果缓存(简单的内存缓存)
|
||||
_cache: dict[str, tuple[bool, float]] = {}
|
||||
_cache_ttl = 300 # 5分钟缓存
|
||||
|
||||
@classmethod
|
||||
async def check_path_permission(
|
||||
cls, user_id: int, path: str, action: str
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户对路径的操作权限
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
path: 要检查的路径
|
||||
action: 操作类型 (read/write/delete/share)
|
||||
|
||||
Returns:
|
||||
是否有权限
|
||||
"""
|
||||
import time
|
||||
|
||||
# 检查缓存
|
||||
cache_key = f"{user_id}:{path}:{action}"
|
||||
if cache_key in cls._cache:
|
||||
result, timestamp = cls._cache[cache_key]
|
||||
if time.time() - timestamp < cls._cache_ttl:
|
||||
return result
|
||||
|
||||
# 获取用户
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# 超级管理员直接放行
|
||||
if user.is_admin:
|
||||
cls._cache[cache_key] = (True, time.time())
|
||||
return True
|
||||
|
||||
# 获取用户所有角色
|
||||
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
if not role_ids:
|
||||
cls._cache[cache_key] = (False, time.time())
|
||||
return False
|
||||
|
||||
# 获取所有角色的路径规则
|
||||
path_rules = await PathRule.filter(role_id__in=role_ids).order_by("-priority")
|
||||
|
||||
# 规范化路径
|
||||
normalized_path = PathMatcher.normalize_path(path)
|
||||
|
||||
# 按优先级和具体程度匹配
|
||||
result = cls._match_path_rules(normalized_path, action, list(path_rules))
|
||||
|
||||
# 如果没有匹配到规则,检查父目录(继承)
|
||||
if result is None:
|
||||
parent_path = PathMatcher.get_parent_path(normalized_path)
|
||||
if parent_path:
|
||||
result = await cls.check_path_permission(user_id, parent_path, action)
|
||||
else:
|
||||
result = False # 默认拒绝
|
||||
|
||||
cls._cache[cache_key] = (result, time.time())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _match_path_rules(
|
||||
cls, path: str, action: str, rules: List[PathRule]
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
匹配路径规则
|
||||
|
||||
Returns:
|
||||
True/False 表示明确的权限结果,None 表示没有匹配到规则
|
||||
"""
|
||||
# 按优先级和具体程度排序
|
||||
sorted_rules = sorted(
|
||||
rules,
|
||||
key=lambda r: (
|
||||
r.priority,
|
||||
PathMatcher.get_pattern_specificity(r.path_pattern, r.is_regex),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
for rule in sorted_rules:
|
||||
if PathMatcher.match_pattern(path, rule.path_pattern, rule.is_regex):
|
||||
# 匹配到规则,检查具体操作权限
|
||||
if action == PathAction.READ:
|
||||
return rule.can_read
|
||||
elif action == PathAction.WRITE:
|
||||
return rule.can_write
|
||||
elif action == PathAction.DELETE:
|
||||
return rule.can_delete
|
||||
elif action == PathAction.SHARE:
|
||||
return rule.can_share
|
||||
else:
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def check_system_permission(cls, user_id: int, permission_code: str) -> bool:
|
||||
"""检查用户的系统/适配器权限"""
|
||||
# 获取用户
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# 超级管理员直接放行
|
||||
if user.is_admin:
|
||||
return True
|
||||
|
||||
# 获取用户所有角色
|
||||
user_roles = await UserRole.filter(user_id=user_id)
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
if not role_ids:
|
||||
return False
|
||||
|
||||
role_permission = await RolePermission.filter(
|
||||
role_id__in=role_ids, permission_code=permission_code
|
||||
).first()
|
||||
|
||||
return role_permission is not None
|
||||
|
||||
@classmethod
|
||||
async def require_path_permission(
|
||||
cls, user_id: int, path: str, action: str
|
||||
) -> None:
|
||||
"""要求用户具有路径权限,否则抛出 403"""
|
||||
if not await cls.check_path_permission(user_id, path, action):
|
||||
raise HTTPException(403, detail=f"没有权限执行此操作: {action}")
|
||||
|
||||
@classmethod
|
||||
async def require_system_permission(
|
||||
cls, user_id: int, permission_code: str
|
||||
) -> None:
|
||||
"""要求用户具有系统权限,否则抛出 403"""
|
||||
if not await cls.check_system_permission(user_id, permission_code):
|
||||
raise HTTPException(403, detail=f"没有权限: {permission_code}")
|
||||
|
||||
@classmethod
|
||||
async def get_user_permissions(cls, user_id: int) -> UserPermissions:
|
||||
"""获取用户的所有权限"""
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
# 超级管理员拥有所有权限
|
||||
if user.is_admin:
|
||||
all_permission_codes = [item["code"] for item in PERMISSION_DEFINITIONS]
|
||||
all_path_rules = await PathRule.all()
|
||||
return UserPermissions(
|
||||
user_id=user_id,
|
||||
is_admin=True,
|
||||
permissions=all_permission_codes,
|
||||
path_rules=[
|
||||
PathRuleInfo(
|
||||
id=r.id,
|
||||
role_id=r.role_id,
|
||||
path_pattern=r.path_pattern,
|
||||
is_regex=r.is_regex,
|
||||
can_read=r.can_read,
|
||||
can_write=r.can_write,
|
||||
can_delete=r.can_delete,
|
||||
can_share=r.can_share,
|
||||
priority=r.priority,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in all_path_rules
|
||||
],
|
||||
)
|
||||
|
||||
# 获取用户角色
|
||||
user_roles = await UserRole.filter(user_id=user_id)
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
# 获取权限
|
||||
permissions = []
|
||||
if role_ids:
|
||||
role_permissions = await RolePermission.filter(role_id__in=role_ids)
|
||||
permissions = sorted(set(rp.permission_code for rp in role_permissions))
|
||||
|
||||
# 获取路径规则
|
||||
path_rules = []
|
||||
if role_ids:
|
||||
rules = await PathRule.filter(role_id__in=role_ids)
|
||||
path_rules = [
|
||||
PathRuleInfo(
|
||||
id=r.id,
|
||||
role_id=r.role_id,
|
||||
path_pattern=r.path_pattern,
|
||||
is_regex=r.is_regex,
|
||||
can_read=r.can_read,
|
||||
can_write=r.can_write,
|
||||
can_delete=r.can_delete,
|
||||
can_share=r.can_share,
|
||||
priority=r.priority,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in rules
|
||||
]
|
||||
|
||||
return UserPermissions(
|
||||
user_id=user_id,
|
||||
is_admin=False,
|
||||
permissions=permissions,
|
||||
path_rules=path_rules,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_all_permissions(cls) -> List[PermissionInfo]:
|
||||
"""获取所有权限定义"""
|
||||
return [
|
||||
PermissionInfo(
|
||||
code=item["code"],
|
||||
name=item["name"],
|
||||
category=item["category"],
|
||||
description=item.get("description"),
|
||||
)
|
||||
for item in PERMISSION_DEFINITIONS
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def check_path_permission_detailed(
|
||||
cls, user_id: int, path: str, action: str
|
||||
) -> PathPermissionResult:
|
||||
"""检查路径权限并返回详细结果"""
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
return PathPermissionResult(path=path, action=action, allowed=False)
|
||||
|
||||
# 超级管理员
|
||||
if user.is_admin:
|
||||
return PathPermissionResult(path=path, action=action, allowed=True)
|
||||
|
||||
# 获取用户角色
|
||||
user_roles = await UserRole.filter(user_id=user_id)
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
if not role_ids:
|
||||
return PathPermissionResult(path=path, action=action, allowed=False)
|
||||
|
||||
# 获取路径规则
|
||||
path_rules = await PathRule.filter(role_id__in=role_ids).order_by("-priority")
|
||||
normalized_path = PathMatcher.normalize_path(path)
|
||||
|
||||
# 查找匹配的规则
|
||||
matched_rule = None
|
||||
for rule in sorted(
|
||||
path_rules,
|
||||
key=lambda r: (
|
||||
r.priority,
|
||||
PathMatcher.get_pattern_specificity(r.path_pattern, r.is_regex),
|
||||
),
|
||||
reverse=True,
|
||||
):
|
||||
if PathMatcher.match_pattern(
|
||||
normalized_path, rule.path_pattern, rule.is_regex
|
||||
):
|
||||
matched_rule = rule
|
||||
break
|
||||
|
||||
# 检查权限
|
||||
allowed = False
|
||||
if matched_rule:
|
||||
if action == PathAction.READ:
|
||||
allowed = matched_rule.can_read
|
||||
elif action == PathAction.WRITE:
|
||||
allowed = matched_rule.can_write
|
||||
elif action == PathAction.DELETE:
|
||||
allowed = matched_rule.can_delete
|
||||
elif action == PathAction.SHARE:
|
||||
allowed = matched_rule.can_share
|
||||
|
||||
rule_info = None
|
||||
if matched_rule:
|
||||
rule_info = PathRuleInfo(
|
||||
id=matched_rule.id,
|
||||
role_id=matched_rule.role_id,
|
||||
path_pattern=matched_rule.path_pattern,
|
||||
is_regex=matched_rule.is_regex,
|
||||
can_read=matched_rule.can_read,
|
||||
can_write=matched_rule.can_write,
|
||||
can_delete=matched_rule.can_delete,
|
||||
can_share=matched_rule.can_share,
|
||||
priority=matched_rule.priority,
|
||||
created_at=matched_rule.created_at,
|
||||
)
|
||||
|
||||
return PathPermissionResult(
|
||||
path=path, action=action, allowed=allowed, matched_rule=rule_info
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls, user_id: int | None = None) -> None:
|
||||
"""清除权限缓存"""
|
||||
if user_id is None:
|
||||
cls._cache.clear()
|
||||
else:
|
||||
# 清除特定用户的缓存
|
||||
keys_to_delete = [k for k in cls._cache if k.startswith(f"{user_id}:")]
|
||||
for k in keys_to_delete:
|
||||
del cls._cache[k]
|
||||
|
||||
@classmethod
|
||||
async def filter_paths_by_permission(
|
||||
cls, user_id: int, paths: List[str], action: str
|
||||
) -> List[str]:
|
||||
"""过滤出用户有权限的路径列表"""
|
||||
result = []
|
||||
for path in paths:
|
||||
if await cls.check_path_permission(user_id, path, action):
|
||||
result.append(path)
|
||||
return result
|
||||
107
domain/permission/types.py
Normal file
107
domain/permission/types.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# 权限操作类型
|
||||
class PathAction:
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
DELETE = "delete"
|
||||
SHARE = "share"
|
||||
|
||||
|
||||
# 系统权限代码
|
||||
class SystemPermission:
|
||||
USER_CREATE = "system.user.create"
|
||||
USER_EDIT = "system.user.edit"
|
||||
USER_DELETE = "system.user.delete"
|
||||
USER_LIST = "system.user.list"
|
||||
ROLE_MANAGE = "system.role.manage"
|
||||
CONFIG_EDIT = "system.config.edit"
|
||||
AUDIT_VIEW = "system.audit.view"
|
||||
|
||||
|
||||
# 适配器权限代码
|
||||
class AdapterPermission:
|
||||
CREATE = "adapter.create"
|
||||
EDIT = "adapter.edit"
|
||||
DELETE = "adapter.delete"
|
||||
LIST = "adapter.list"
|
||||
|
||||
|
||||
# 所有权限定义
|
||||
PERMISSION_DEFINITIONS = [
|
||||
# 系统权限
|
||||
{"code": SystemPermission.USER_CREATE, "name": "创建用户", "category": "system", "description": "允许创建新用户"},
|
||||
{"code": SystemPermission.USER_EDIT, "name": "编辑用户", "category": "system", "description": "允许编辑用户信息"},
|
||||
{"code": SystemPermission.USER_DELETE, "name": "删除用户", "category": "system", "description": "允许删除用户"},
|
||||
{"code": SystemPermission.USER_LIST, "name": "查看用户列表", "category": "system", "description": "允许查看用户列表"},
|
||||
{"code": SystemPermission.ROLE_MANAGE, "name": "管理角色和权限", "category": "system", "description": "允许管理角色和权限配置"},
|
||||
{"code": SystemPermission.CONFIG_EDIT, "name": "修改系统配置", "category": "system", "description": "允许修改系统配置"},
|
||||
{"code": SystemPermission.AUDIT_VIEW, "name": "查看审计日志", "category": "system", "description": "允许查看审计日志"},
|
||||
# 适配器权限
|
||||
{"code": AdapterPermission.CREATE, "name": "创建存储适配器", "category": "adapter", "description": "允许创建存储适配器"},
|
||||
{"code": AdapterPermission.EDIT, "name": "编辑存储适配器", "category": "adapter", "description": "允许编辑存储适配器"},
|
||||
{"code": AdapterPermission.DELETE, "name": "删除存储适配器", "category": "adapter", "description": "允许删除存储适配器"},
|
||||
{"code": AdapterPermission.LIST, "name": "查看存储适配器列表", "category": "adapter", "description": "允许查看存储适配器列表"},
|
||||
]
|
||||
|
||||
|
||||
# Pydantic 模型
|
||||
class PermissionInfo(BaseModel):
|
||||
code: str
|
||||
name: str
|
||||
category: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class PathRuleInfo(BaseModel):
|
||||
id: int
|
||||
role_id: int
|
||||
path_pattern: str
|
||||
is_regex: bool
|
||||
can_read: bool
|
||||
can_write: bool
|
||||
can_delete: bool
|
||||
can_share: bool
|
||||
priority: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PathRuleCreate(BaseModel):
|
||||
path_pattern: str
|
||||
is_regex: bool = False
|
||||
can_read: bool = True
|
||||
can_write: bool = False
|
||||
can_delete: bool = False
|
||||
can_share: bool = False
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class PathRuleUpdate(BaseModel):
|
||||
path_pattern: str | None = None
|
||||
is_regex: bool | None = None
|
||||
can_read: bool | None = None
|
||||
can_write: bool | None = None
|
||||
can_delete: bool | None = None
|
||||
can_share: bool | None = None
|
||||
priority: int | None = None
|
||||
|
||||
|
||||
class PathPermissionCheck(BaseModel):
|
||||
path: str
|
||||
action: str
|
||||
|
||||
|
||||
class PathPermissionResult(BaseModel):
|
||||
path: str
|
||||
action: str
|
||||
allowed: bool
|
||||
matched_rule: PathRuleInfo | None = None
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
user_id: int
|
||||
is_admin: bool
|
||||
permissions: list[str] # 系统/适配器权限代码列表
|
||||
path_rules: list[PathRuleInfo] # 路径权限规则
|
||||
@@ -4,9 +4,9 @@ Foxel 插件系统
|
||||
提供 .foxpkg 插件包的安装、管理和运行时加载功能。
|
||||
"""
|
||||
|
||||
from domain.plugins.loader import PluginLoader, PluginLoadError
|
||||
from domain.plugins.service import PluginService
|
||||
from domain.plugins.startup import init_plugins, load_installed_plugins
|
||||
from .loader import PluginLoadError, PluginLoader
|
||||
from .service import PluginService
|
||||
from .startup import init_plugins, load_installed_plugins
|
||||
|
||||
__all__ = [
|
||||
"PluginLoader",
|
||||
|
||||
@@ -2,14 +2,17 @@
|
||||
插件管理 API 路由
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Annotated, List
|
||||
|
||||
from fastapi import APIRouter, File, Request, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, Request, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.plugins.service import PluginService
|
||||
from domain.plugins.types import (
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import PluginService
|
||||
from .types import (
|
||||
PluginInstallResult,
|
||||
PluginOut,
|
||||
)
|
||||
@@ -22,7 +25,12 @@ router = APIRouter(prefix="/api/plugins", tags=["plugins"])
|
||||
|
||||
@router.post("/install", response_model=PluginInstallResult)
|
||||
@audit(action=AuditAction.CREATE, description="安装插件包")
|
||||
async def install_plugin(request: Request, file: UploadFile = File(...)):
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def install_plugin(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""
|
||||
安装 .foxpkg 插件包
|
||||
|
||||
@@ -37,14 +45,21 @@ async def install_plugin(request: Request, file: UploadFile = File(...)):
|
||||
|
||||
@router.get("", response_model=List[PluginOut])
|
||||
@audit(action=AuditAction.READ, description="获取插件列表")
|
||||
async def list_plugins(request: Request):
|
||||
async def list_plugins(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""获取已安装的插件列表"""
|
||||
return await PluginService.list_plugins()
|
||||
|
||||
|
||||
@router.get("/{key_or_id}", response_model=PluginOut)
|
||||
@audit(action=AuditAction.READ, description="获取插件详情")
|
||||
async def get_plugin(request: Request, key_or_id: str):
|
||||
async def get_plugin(
|
||||
request: Request,
|
||||
key_or_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""获取单个插件详情"""
|
||||
return await PluginService.get_plugin(key_or_id)
|
||||
|
||||
@@ -54,7 +69,12 @@ async def get_plugin(request: Request, key_or_id: str):
|
||||
|
||||
@router.delete("/{key_or_id}")
|
||||
@audit(action=AuditAction.DELETE, description="卸载插件")
|
||||
async def delete_plugin(request: Request, key_or_id: str):
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def delete_plugin(
|
||||
request: Request,
|
||||
key_or_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""卸载插件"""
|
||||
await PluginService.delete(key_or_id)
|
||||
return {"code": 0, "msg": "ok"}
|
||||
@@ -67,10 +87,12 @@ async def delete_plugin(request: Request, key_or_id: str):
|
||||
async def get_bundle(request: Request, key_or_id: str):
|
||||
"""获取插件前端 bundle"""
|
||||
path = await PluginService.get_bundle_path(key_or_id)
|
||||
v = (request.query_params.get("v") or "").strip()
|
||||
cache_control = "public, max-age=31536000, immutable" if v else "no-cache"
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type="application/javascript",
|
||||
headers={"Cache-Control": "no-store"},
|
||||
headers={"Cache-Control": cache_control},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from domain.plugins.types import (
|
||||
from .types import (
|
||||
ManifestProcessorConfig,
|
||||
ManifestRouteConfig,
|
||||
PluginManifest,
|
||||
@@ -344,7 +344,7 @@ class PluginLoader:
|
||||
supported_exts = getattr(module, "SUPPORTED_EXTS", [])
|
||||
|
||||
# 注册到处理器注册表
|
||||
from domain.processors.registry import CONFIG_SCHEMAS, TYPE_MAP
|
||||
from domain.processors import CONFIG_SCHEMAS, TYPE_MAP
|
||||
|
||||
processor_type = processor_config.type
|
||||
TYPE_MAP[processor_type] = factory
|
||||
@@ -401,7 +401,7 @@ class PluginLoader:
|
||||
"""
|
||||
# 卸载处理器
|
||||
if manifest and manifest.backend and manifest.backend.processors:
|
||||
from domain.processors.registry import CONFIG_SCHEMAS, TYPE_MAP
|
||||
from domain.processors import CONFIG_SCHEMAS, TYPE_MAP
|
||||
|
||||
for proc_config in manifest.backend.processors:
|
||||
proc_type = proc_config.type
|
||||
|
||||
@@ -12,8 +12,8 @@ from typing import List, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.plugins.loader import PluginLoadError, PluginLoader
|
||||
from domain.plugins.types import (
|
||||
from .loader import PluginLoadError, PluginLoader
|
||||
from .types import (
|
||||
PluginInstallResult,
|
||||
PluginManifest,
|
||||
PluginOut,
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from domain.plugins.loader import PluginLoadError, PluginLoader
|
||||
from domain.plugins.types import PluginManifest
|
||||
from .loader import PluginLoadError, PluginLoader
|
||||
from .types import PluginManifest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
@@ -113,4 +113,3 @@ async def init_plugins(app: "FastAPI") -> None:
|
||||
logger.warning(f" - {error}")
|
||||
else:
|
||||
logger.info(f"插件加载完成,共 {loaded_count} 个插件")
|
||||
|
||||
|
||||
35
domain/processors/__init__.py
Normal file
35
domain/processors/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from .base import BaseProcessor
|
||||
from .registry import (
|
||||
CONFIG_SCHEMAS,
|
||||
TYPE_MAP,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
get_last_discovery_errors,
|
||||
get_module_path,
|
||||
reload_processors,
|
||||
)
|
||||
from .service import (
|
||||
ProcessorService,
|
||||
get_processor,
|
||||
list_processors,
|
||||
reload_processor_modules,
|
||||
)
|
||||
from .types import ProcessDirectoryRequest, ProcessRequest, UpdateSourceRequest
|
||||
|
||||
__all__ = [
|
||||
"BaseProcessor",
|
||||
"CONFIG_SCHEMAS",
|
||||
"TYPE_MAP",
|
||||
"get_config_schema",
|
||||
"get_config_schemas",
|
||||
"get_last_discovery_errors",
|
||||
"get_module_path",
|
||||
"reload_processors",
|
||||
"ProcessorService",
|
||||
"get_processor",
|
||||
"list_processors",
|
||||
"reload_processor_modules",
|
||||
"ProcessDirectoryRequest",
|
||||
"ProcessRequest",
|
||||
"UpdateSourceRequest",
|
||||
]
|
||||
@@ -4,10 +4,15 @@ from fastapi import APIRouter, Body, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.processors.service import ProcessorService
|
||||
from domain.processors.types import (
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_path_permission
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.service import PermissionService
|
||||
from domain.permission.types import PathAction
|
||||
from domain.permission.types import SystemPermission
|
||||
from domain.processors.registry import get_config_schema
|
||||
from .service import ProcessorService
|
||||
from .types import (
|
||||
ProcessDirectoryRequest,
|
||||
ProcessRequest,
|
||||
UpdateSourceRequest,
|
||||
@@ -32,11 +37,18 @@ async def list_processors(
|
||||
description="处理单个文件",
|
||||
body_fields=["path", "processor_type", "save_to", "overwrite"],
|
||||
)
|
||||
@require_path_permission(PathAction.READ, "req.path")
|
||||
async def process_file_with_processor(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessRequest = Body(...),
|
||||
):
|
||||
meta = get_config_schema(req.processor_type) or {}
|
||||
if meta.get("produces_file"):
|
||||
if req.overwrite:
|
||||
await PermissionService.require_path_permission(current_user.id, req.path, PathAction.WRITE)
|
||||
elif req.save_to:
|
||||
await PermissionService.require_path_permission(current_user.id, req.save_to, PathAction.WRITE)
|
||||
data = await ProcessorService.process_file(req)
|
||||
return success(data)
|
||||
|
||||
@@ -47,17 +59,22 @@ async def process_file_with_processor(
|
||||
description="批量处理目录",
|
||||
body_fields=["path", "processor_type", "overwrite", "max_depth", "suffix"],
|
||||
)
|
||||
@require_path_permission(PathAction.READ, "req.path")
|
||||
async def process_directory_with_processor(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessDirectoryRequest = Body(...),
|
||||
):
|
||||
meta = get_config_schema(req.processor_type) or {}
|
||||
if meta.get("produces_file"):
|
||||
await PermissionService.require_path_permission(current_user.id, req.path, PathAction.WRITE)
|
||||
data = await ProcessorService.process_directory(req)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/source/{processor_type}")
|
||||
@audit(action=AuditAction.READ, description="获取处理器源码")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def get_processor_source(
|
||||
request: Request,
|
||||
processor_type: str,
|
||||
@@ -69,6 +86,7 @@ async def get_processor_source(
|
||||
|
||||
@router.put("/source/{processor_type}")
|
||||
@audit(action=AuditAction.UPDATE, description="更新处理器源码")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def update_processor_source(
|
||||
request: Request,
|
||||
processor_type: str,
|
||||
@@ -81,6 +99,7 @@ async def update_processor_source(
|
||||
|
||||
@router.post("/reload")
|
||||
@audit(action=AuditAction.UPDATE, description="重载处理器模块")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def reload_processor_modules(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
|
||||
@@ -8,12 +8,14 @@ from fastapi.responses import Response
|
||||
from PIL import Image
|
||||
|
||||
from ..base import BaseProcessor
|
||||
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
|
||||
from domain.ai.service import (
|
||||
VectorDBService,
|
||||
from domain.ai import (
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
VECTOR_COLLECTION_NAME,
|
||||
FILE_COLLECTION_NAME,
|
||||
VECTOR_COLLECTION_NAME,
|
||||
VectorDBService,
|
||||
describe_image_base64,
|
||||
get_text_embedding,
|
||||
provider_service,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,8 +114,15 @@ class VectorIndexProcessor:
|
||||
}
|
||||
]
|
||||
produces_file = False
|
||||
requires_input_bytes = False
|
||||
|
||||
async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
|
||||
async def ensure_input_bytes() -> bytes:
|
||||
if input_bytes:
|
||||
return input_bytes
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
return await VirtualFSService.read_file(path)
|
||||
|
||||
action = config.get("action", "create")
|
||||
index_type = config.get("index_type", "vector")
|
||||
vector_db = VectorDBService()
|
||||
@@ -157,7 +166,8 @@ class VectorIndexProcessor:
|
||||
await vector_db.delete_vector(vector_collection, path)
|
||||
|
||||
if file_ext in ["jpg", "jpeg", "png", "bmp"]:
|
||||
processed_bytes, compression = _compress_image_for_embedding(input_bytes)
|
||||
file_bytes = await ensure_input_bytes()
|
||||
processed_bytes, compression = _compress_image_for_embedding(file_bytes)
|
||||
base64_image = base64.b64encode(processed_bytes).decode("utf-8")
|
||||
description = await describe_image_base64(base64_image)
|
||||
embedding = await get_text_embedding(description)
|
||||
@@ -178,7 +188,8 @@ class VectorIndexProcessor:
|
||||
|
||||
if file_ext in ["txt", "md"]:
|
||||
try:
|
||||
text = input_bytes.decode("utf-8")
|
||||
file_bytes = await ensure_input_bytes()
|
||||
text = file_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return Response(content="文本文件解码失败", status_code=400)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from domain.processors.base import BaseProcessor
|
||||
from .base import BaseProcessor
|
||||
|
||||
ProcessorFactory = Callable[[], BaseProcessor]
|
||||
TYPE_MAP: Dict[str, ProcessorFactory] = {}
|
||||
@@ -16,7 +16,7 @@ LAST_DISCOVERY_ERRORS: list[str] = []
|
||||
|
||||
def discover_processors(force_reload: bool = False) -> list[str]:
|
||||
"""扫描并缓存可用的处理器模块。"""
|
||||
from domain.processors import builtin as processors_pkg
|
||||
from . import builtin as processors_pkg
|
||||
|
||||
TYPE_MAP.clear()
|
||||
CONFIG_SCHEMAS.clear()
|
||||
|
||||
@@ -3,20 +3,20 @@ from typing import List, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from domain.processors.registry import (
|
||||
from domain.tasks import task_queue_service
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from .registry import (
|
||||
get,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
get_module_path,
|
||||
reload_processors,
|
||||
)
|
||||
from domain.processors.types import (
|
||||
from .types import (
|
||||
ProcessDirectoryRequest,
|
||||
ProcessRequest,
|
||||
UpdateSourceRequest,
|
||||
)
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.tasks.task_queue import task_queue_service
|
||||
|
||||
|
||||
class ProcessorService:
|
||||
@@ -85,6 +85,44 @@ class ProcessorService:
|
||||
suffix = raw_suffix
|
||||
overwrite = req.overwrite
|
||||
|
||||
if produces_file:
|
||||
if not overwrite and not suffix:
|
||||
raise HTTPException(400, detail="Suffix is required when not overwriting files")
|
||||
else:
|
||||
overwrite = False
|
||||
suffix = None
|
||||
payload = {
|
||||
"path": req.path,
|
||||
"processor_type": req.processor_type,
|
||||
"config": req.config,
|
||||
"overwrite": overwrite,
|
||||
"max_depth": req.max_depth,
|
||||
"suffix": suffix,
|
||||
}
|
||||
task = await task_queue_service.add_task("process_directory_scan", payload)
|
||||
return {"task_id": task.id}
|
||||
|
||||
@classmethod
|
||||
async def scan_directory(cls, req: ProcessDirectoryRequest):
|
||||
if req.max_depth is not None and req.max_depth < 0:
|
||||
raise HTTPException(400, detail="max_depth must be >= 0")
|
||||
|
||||
is_dir = await VirtualFSService.path_is_directory(req.path)
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Path must be a directory")
|
||||
|
||||
schema = get_config_schema(req.processor_type)
|
||||
_processor = get(req.processor_type)
|
||||
if not schema or not _processor:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
|
||||
produces_file = bool(schema.get("produces_file"))
|
||||
raw_suffix = req.suffix if req.suffix is not None else None
|
||||
if raw_suffix is not None and raw_suffix.strip() == "":
|
||||
raw_suffix = None
|
||||
suffix = raw_suffix
|
||||
overwrite = req.overwrite
|
||||
|
||||
if produces_file:
|
||||
if not overwrite and not suffix:
|
||||
raise HTTPException(400, detail="Suffix is required when not overwriting files")
|
||||
@@ -133,7 +171,7 @@ class ProcessorService:
|
||||
new_name = f"{name}{suffix_str}"
|
||||
return str(path_obj.with_name(new_name))
|
||||
|
||||
scheduled_tasks: List[str] = []
|
||||
scheduled_count = 0
|
||||
stack: List[Tuple[str, int]] = [(rel, 0)]
|
||||
page_size = 200
|
||||
|
||||
@@ -161,7 +199,7 @@ class ProcessorService:
|
||||
save_to = None
|
||||
if produces_file and not overwrite and suffix:
|
||||
save_to = apply_suffix(absolute_path, suffix)
|
||||
task = await task_queue_service.add_task(
|
||||
await task_queue_service.add_task(
|
||||
"process_file",
|
||||
{
|
||||
"path": absolute_path,
|
||||
@@ -171,16 +209,13 @@ class ProcessorService:
|
||||
"overwrite": overwrite,
|
||||
},
|
||||
)
|
||||
scheduled_tasks.append(task.id)
|
||||
scheduled_count += 1
|
||||
|
||||
if total is None or page * page_size >= total:
|
||||
break
|
||||
page += 1
|
||||
|
||||
return {
|
||||
"task_ids": scheduled_tasks,
|
||||
"scheduled": len(scheduled_tasks),
|
||||
}
|
||||
return {"scheduled": scheduled_count}
|
||||
|
||||
@classmethod
|
||||
async def get_source(cls, processor_type: str):
|
||||
|
||||
1
domain/repositories/__init__.py
Normal file
1
domain/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__: list[str] = []
|
||||
3
domain/role/__init__.py
Normal file
3
domain/role/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .service import RoleService
|
||||
|
||||
__all__ = ["RoleService"]
|
||||
119
domain/role/api.py
Normal file
119
domain/role/api.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import PathRuleCreate, PathRuleInfo, SystemPermission
|
||||
from domain.user.service import UserService
|
||||
from domain.user.types import UserInfo
|
||||
|
||||
from .service import RoleService
|
||||
from .types import RoleCreate, RoleDetail, RoleInfo, RolePermissionsUpdate, RoleUpdate
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["role"])
|
||||
|
||||
|
||||
@router.get("/roles", response_model=list[RoleInfo])
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def list_roles(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> list[RoleInfo]:
|
||||
return await RoleService.get_all_roles()
|
||||
|
||||
|
||||
@router.get("/roles/{role_id}", response_model=RoleDetail)
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def get_role(
|
||||
role_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> RoleDetail:
|
||||
return await RoleService.get_role(role_id)
|
||||
|
||||
|
||||
@router.get("/roles/{role_id}/users", response_model=list[UserInfo])
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def list_role_users(
|
||||
role_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[UserInfo]:
|
||||
return await UserService.get_users_by_role(role_id)
|
||||
|
||||
|
||||
@router.post("/roles", response_model=RoleInfo)
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def create_role(
|
||||
data: RoleCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> RoleInfo:
|
||||
return await RoleService.create_role(data)
|
||||
|
||||
|
||||
@router.put("/roles/{role_id}", response_model=RoleInfo)
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def update_role(
|
||||
role_id: int,
|
||||
data: RoleUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> RoleInfo:
|
||||
return await RoleService.update_role(role_id, data)
|
||||
|
||||
|
||||
@router.delete("/roles/{role_id}")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def delete_role(
|
||||
role_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> dict:
|
||||
await RoleService.delete_role(role_id)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/roles/{role_id}/permissions", response_model=list[str])
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def set_role_permissions(
|
||||
role_id: int,
|
||||
data: RolePermissionsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[str]:
|
||||
return await RoleService.set_role_permissions(role_id, data.permission_codes)
|
||||
|
||||
|
||||
@router.get("/roles/{role_id}/path-rules", response_model=list[PathRuleInfo])
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def get_role_path_rules(
|
||||
role_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[PathRuleInfo]:
|
||||
return await RoleService.get_role_path_rules(role_id)
|
||||
|
||||
|
||||
@router.post("/roles/{role_id}/path-rules", response_model=PathRuleInfo)
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def add_path_rule(
|
||||
role_id: int,
|
||||
data: PathRuleCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> PathRuleInfo:
|
||||
return await RoleService.add_path_rule(role_id, data)
|
||||
|
||||
|
||||
@router.put("/path-rules/{rule_id}", response_model=PathRuleInfo)
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def update_path_rule(
|
||||
rule_id: int,
|
||||
data: PathRuleCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> PathRuleInfo:
|
||||
return await RoleService.update_path_rule(rule_id, data)
|
||||
|
||||
|
||||
@router.delete("/path-rules/{rule_id}")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def delete_path_rule(
|
||||
rule_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> dict:
|
||||
await RoleService.delete_path_rule(rule_id)
|
||||
return {"success": True}
|
||||
288
domain/role/service.py
Normal file
288
domain/role/service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from typing import List
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.database import Role, RolePermission, PathRule, UserRole
|
||||
from domain.permission.service import PermissionService
|
||||
from domain.permission.types import PathRuleCreate, PathRuleInfo, PERMISSION_DEFINITIONS
|
||||
from .types import RoleInfo, RoleDetail, RoleCreate, RoleUpdate, SystemRoles
|
||||
|
||||
|
||||
class RoleService:
|
||||
"""角色管理服务"""
|
||||
|
||||
@classmethod
|
||||
async def get_all_roles(cls) -> List[RoleInfo]:
|
||||
"""获取所有角色"""
|
||||
roles = await Role.all().order_by("id")
|
||||
return [
|
||||
RoleInfo(
|
||||
id=r.id,
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
is_system=r.is_system,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in roles
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def get_role(cls, role_id: int) -> RoleDetail:
|
||||
"""获取角色详情"""
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
# 获取权限
|
||||
role_permissions = await RolePermission.filter(role_id=role_id)
|
||||
permissions = sorted(set(rp.permission_code for rp in role_permissions))
|
||||
|
||||
# 获取路径规则数量
|
||||
path_rules_count = await PathRule.filter(role_id=role_id).count()
|
||||
|
||||
return RoleDetail(
|
||||
id=role.id,
|
||||
name=role.name,
|
||||
description=role.description,
|
||||
is_system=role.is_system,
|
||||
created_at=role.created_at,
|
||||
permissions=permissions,
|
||||
path_rules_count=path_rules_count,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def create_role(cls, data: RoleCreate) -> RoleInfo:
|
||||
"""创建角色"""
|
||||
# 检查名称是否已存在
|
||||
existing = await Role.get_or_none(name=data.name)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="角色名称已存在")
|
||||
|
||||
role = await Role.create(
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
is_system=False,
|
||||
)
|
||||
|
||||
return RoleInfo(
|
||||
id=role.id,
|
||||
name=role.name,
|
||||
description=role.description,
|
||||
is_system=role.is_system,
|
||||
created_at=role.created_at,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update_role(cls, role_id: int, data: RoleUpdate) -> RoleInfo:
|
||||
"""更新角色"""
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
if data.name is not None:
|
||||
# 检查名称是否与其他角色冲突
|
||||
existing = await Role.filter(name=data.name).exclude(id=role_id).first()
|
||||
if existing:
|
||||
raise HTTPException(400, detail="角色名称已存在")
|
||||
role.name = data.name
|
||||
|
||||
if data.description is not None:
|
||||
role.description = data.description
|
||||
|
||||
await role.save()
|
||||
|
||||
return RoleInfo(
|
||||
id=role.id,
|
||||
name=role.name,
|
||||
description=role.description,
|
||||
is_system=role.is_system,
|
||||
created_at=role.created_at,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def delete_role(cls, role_id: int) -> None:
|
||||
"""删除角色"""
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
if role.is_system:
|
||||
raise HTTPException(400, detail="系统内置角色不可删除")
|
||||
|
||||
# 检查是否有用户使用此角色
|
||||
user_count = await UserRole.filter(role_id=role_id).count()
|
||||
if user_count > 0:
|
||||
raise HTTPException(400, detail=f"有 {user_count} 个用户正在使用此角色,无法删除")
|
||||
|
||||
await role.delete()
|
||||
# 清除权限缓存
|
||||
PermissionService.clear_cache()
|
||||
|
||||
@classmethod
|
||||
async def set_role_permissions(cls, role_id: int, permission_codes: List[str]) -> List[str]:
|
||||
"""设置角色的权限"""
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
all_permission_codes = {item["code"] for item in PERMISSION_DEFINITIONS}
|
||||
invalid_codes = set(permission_codes) - all_permission_codes
|
||||
if invalid_codes:
|
||||
raise HTTPException(400, detail=f"无效的权限代码: {', '.join(invalid_codes)}")
|
||||
|
||||
# 删除现有权限
|
||||
await RolePermission.filter(role_id=role_id).delete()
|
||||
|
||||
# 添加新权限
|
||||
for code in permission_codes:
|
||||
await RolePermission.create(
|
||||
role_id=role_id,
|
||||
permission_code=code,
|
||||
)
|
||||
|
||||
# 清除权限缓存
|
||||
PermissionService.clear_cache()
|
||||
|
||||
return list(permission_codes)
|
||||
|
||||
@classmethod
|
||||
async def get_role_path_rules(cls, role_id: int) -> List[PathRuleInfo]:
|
||||
"""获取角色的路径规则"""
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
rules = await PathRule.filter(role_id=role_id).order_by("-priority", "id")
|
||||
return [
|
||||
PathRuleInfo(
|
||||
id=r.id,
|
||||
role_id=r.role_id,
|
||||
path_pattern=r.path_pattern,
|
||||
is_regex=r.is_regex,
|
||||
can_read=r.can_read,
|
||||
can_write=r.can_write,
|
||||
can_delete=r.can_delete,
|
||||
can_share=r.can_share,
|
||||
priority=r.priority,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in rules
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def add_path_rule(cls, role_id: int, data: PathRuleCreate) -> PathRuleInfo:
|
||||
"""添加路径规则"""
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
# 验证路径模式
|
||||
if data.is_regex:
|
||||
import re
|
||||
try:
|
||||
re.compile(data.path_pattern)
|
||||
except re.error as e:
|
||||
raise HTTPException(400, detail=f"无效的正则表达式: {e}")
|
||||
|
||||
rule = await PathRule.create(
|
||||
role_id=role_id,
|
||||
path_pattern=data.path_pattern,
|
||||
is_regex=data.is_regex,
|
||||
can_read=data.can_read,
|
||||
can_write=data.can_write,
|
||||
can_delete=data.can_delete,
|
||||
can_share=data.can_share,
|
||||
priority=data.priority,
|
||||
)
|
||||
|
||||
# 清除权限缓存
|
||||
PermissionService.clear_cache()
|
||||
|
||||
return PathRuleInfo(
|
||||
id=rule.id,
|
||||
role_id=rule.role_id,
|
||||
path_pattern=rule.path_pattern,
|
||||
is_regex=rule.is_regex,
|
||||
can_read=rule.can_read,
|
||||
can_write=rule.can_write,
|
||||
can_delete=rule.can_delete,
|
||||
can_share=rule.can_share,
|
||||
priority=rule.priority,
|
||||
created_at=rule.created_at,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update_path_rule(cls, rule_id: int, data: PathRuleCreate) -> PathRuleInfo:
|
||||
"""更新路径规则"""
|
||||
rule = await PathRule.get_or_none(id=rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(404, detail="路径规则不存在")
|
||||
|
||||
# 验证路径模式
|
||||
if data.is_regex:
|
||||
import re
|
||||
try:
|
||||
re.compile(data.path_pattern)
|
||||
except re.error as e:
|
||||
raise HTTPException(400, detail=f"无效的正则表达式: {e}")
|
||||
|
||||
rule.path_pattern = data.path_pattern
|
||||
rule.is_regex = data.is_regex
|
||||
rule.can_read = data.can_read
|
||||
rule.can_write = data.can_write
|
||||
rule.can_delete = data.can_delete
|
||||
rule.can_share = data.can_share
|
||||
rule.priority = data.priority
|
||||
await rule.save()
|
||||
|
||||
# 清除权限缓存
|
||||
PermissionService.clear_cache()
|
||||
|
||||
return PathRuleInfo(
|
||||
id=rule.id,
|
||||
role_id=rule.role_id,
|
||||
path_pattern=rule.path_pattern,
|
||||
is_regex=rule.is_regex,
|
||||
can_read=rule.can_read,
|
||||
can_write=rule.can_write,
|
||||
can_delete=rule.can_delete,
|
||||
can_share=rule.can_share,
|
||||
priority=rule.priority,
|
||||
created_at=rule.created_at,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def delete_path_rule(cls, rule_id: int) -> None:
|
||||
"""删除路径规则"""
|
||||
rule = await PathRule.get_or_none(id=rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(404, detail="路径规则不存在")
|
||||
|
||||
await rule.delete()
|
||||
# 清除权限缓存
|
||||
PermissionService.clear_cache()
|
||||
|
||||
@classmethod
|
||||
async def ensure_system_roles(cls) -> None:
|
||||
"""确保系统内置角色存在"""
|
||||
system_roles = [
|
||||
{
|
||||
"name": SystemRoles.ADMIN,
|
||||
"description": "管理员角色,拥有所有系统和适配器权限",
|
||||
"is_system": True,
|
||||
},
|
||||
{
|
||||
"name": SystemRoles.USER,
|
||||
"description": "普通用户角色,需要管理员配置路径权限",
|
||||
"is_system": True,
|
||||
},
|
||||
{
|
||||
"name": SystemRoles.VIEWER,
|
||||
"description": "只读用户角色,仅可查看文件",
|
||||
"is_system": True,
|
||||
},
|
||||
]
|
||||
|
||||
for role_data in system_roles:
|
||||
existing = await Role.get_or_none(name=role_data["name"])
|
||||
if not existing:
|
||||
await Role.create(**role_data)
|
||||
36
domain/role/types.py
Normal file
36
domain/role/types.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class RoleInfo(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
is_system: bool
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class RoleDetail(RoleInfo):
|
||||
permissions: list[str] # 权限代码列表
|
||||
path_rules_count: int
|
||||
|
||||
|
||||
class RoleCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class RoleUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class RolePermissionsUpdate(BaseModel):
|
||||
permission_codes: list[str]
|
||||
|
||||
|
||||
# 预置角色名称
|
||||
class SystemRoles:
|
||||
ADMIN = "Admin"
|
||||
USER = "User"
|
||||
VIEWER = "Viewer"
|
||||
10
domain/share/__init__.py
Normal file
10
domain/share/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import ShareService
|
||||
from .types import ShareCreate, ShareInfo, ShareInfoWithPassword, SharePassword
|
||||
|
||||
__all__ = [
|
||||
"ShareService",
|
||||
"ShareCreate",
|
||||
"ShareInfo",
|
||||
"ShareInfoWithPassword",
|
||||
"SharePassword",
|
||||
]
|
||||
@@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.share.service import ShareService
|
||||
from domain.share.types import (
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_path_permission
|
||||
from domain.permission.types import PathAction
|
||||
from .service import ShareService
|
||||
from .types import (
|
||||
ShareCreate,
|
||||
ShareInfo,
|
||||
ShareInfoWithPassword,
|
||||
@@ -25,6 +26,7 @@ router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
|
||||
description="创建分享链接",
|
||||
body_fields=["name", "paths", "expires_in_days", "access_type"],
|
||||
)
|
||||
@require_path_permission(PathAction.SHARE, "payload.paths")
|
||||
async def create_share(
|
||||
request: Request,
|
||||
payload: ShareCreate,
|
||||
|
||||
@@ -7,7 +7,7 @@ import bcrypt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import Response
|
||||
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from models.database import ShareLink, UserAccount
|
||||
|
||||
|
||||
|
||||
26
domain/tasks/__init__.py
Normal file
26
domain/tasks/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from .service import TaskService
|
||||
from .scheduler import task_scheduler
|
||||
from .task_queue import Task, TaskProgress, TaskStatus, task_queue_service
|
||||
from .types import (
|
||||
AutomationTaskBase,
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskRead,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
TaskQueueSettingsResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TaskService",
|
||||
"Task",
|
||||
"TaskProgress",
|
||||
"TaskStatus",
|
||||
"task_queue_service",
|
||||
"task_scheduler",
|
||||
"AutomationTaskBase",
|
||||
"AutomationTaskCreate",
|
||||
"AutomationTaskRead",
|
||||
"AutomationTaskUpdate",
|
||||
"TaskQueueSettings",
|
||||
"TaskQueueSettingsResponse",
|
||||
]
|
||||
@@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.tasks.service import TaskService
|
||||
from domain.tasks.types import (
|
||||
from domain.auth import get_current_active_user
|
||||
from .service import TaskService
|
||||
from .types import (
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
@@ -59,8 +59,7 @@ async def get_task_status(task_id: str, request: Request, current_user: CurrentU
|
||||
body_fields=[
|
||||
"name",
|
||||
"event",
|
||||
"path_pattern",
|
||||
"filename_regex",
|
||||
"trigger_config",
|
||||
"processor_type",
|
||||
"processor_config",
|
||||
"enabled",
|
||||
@@ -93,8 +92,7 @@ async def list_tasks(request: Request, current_user: CurrentUser):
|
||||
body_fields=[
|
||||
"name",
|
||||
"event",
|
||||
"path_pattern",
|
||||
"filename_regex",
|
||||
"trigger_config",
|
||||
"processor_type",
|
||||
"processor_config",
|
||||
"enabled",
|
||||
|
||||
102
domain/tasks/scheduler.py
Normal file
102
domain/tasks/scheduler.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from croniter import croniter
|
||||
|
||||
from models.database import AutomationTask
|
||||
from .task_queue import task_queue_service
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronTaskItem:
|
||||
task_id: int
|
||||
processor_type: str
|
||||
path: str
|
||||
cron: croniter
|
||||
next_run: datetime
|
||||
|
||||
|
||||
class AutomationTaskScheduler:
|
||||
def __init__(self):
|
||||
self._items: list[CronTaskItem] = []
|
||||
self._worker: asyncio.Task | None = None
|
||||
self._reload_event = asyncio.Event()
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._worker and not self._worker.done():
|
||||
return
|
||||
self._stop_event.clear()
|
||||
await self._load_tasks()
|
||||
self._worker = asyncio.create_task(self._run_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
if not self._worker:
|
||||
return
|
||||
self._stop_event.set()
|
||||
self._reload_event.set()
|
||||
await self._worker
|
||||
self._worker = None
|
||||
|
||||
def refresh(self) -> None:
|
||||
if self._worker and not self._worker.done():
|
||||
self._reload_event.set()
|
||||
|
||||
async def _load_tasks(self) -> None:
|
||||
tasks = await AutomationTask.filter(event="cron", enabled=True)
|
||||
items: list[CronTaskItem] = []
|
||||
now = datetime.now()
|
||||
for task in tasks:
|
||||
trigger = task.trigger_config or {}
|
||||
if not isinstance(trigger, dict):
|
||||
continue
|
||||
cron_expr = trigger.get("cron_expr")
|
||||
path = trigger.get("path")
|
||||
if not cron_expr or not path:
|
||||
continue
|
||||
cron = self._build_cron(cron_expr, now)
|
||||
if not cron:
|
||||
continue
|
||||
next_run = cron.get_next(datetime)
|
||||
items.append(
|
||||
CronTaskItem(
|
||||
task_id=task.id,
|
||||
processor_type=task.processor_type,
|
||||
path=path,
|
||||
cron=cron,
|
||||
next_run=next_run,
|
||||
)
|
||||
)
|
||||
self._items = items
|
||||
|
||||
def _build_cron(self, expr: str, base_time: datetime) -> croniter | None:
|
||||
expr = str(expr or "").strip()
|
||||
if not expr:
|
||||
return None
|
||||
parts = [p for p in expr.split() if p]
|
||||
if len(parts) not in (5, 6):
|
||||
return None
|
||||
second_at_beginning = len(parts) == 6
|
||||
try:
|
||||
return croniter(expr, base_time, second_at_beginning=second_at_beginning)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
if self._reload_event.is_set():
|
||||
self._reload_event.clear()
|
||||
await self._load_tasks()
|
||||
now = datetime.now()
|
||||
for item in list(self._items):
|
||||
if item.next_run <= now:
|
||||
await task_queue_service.add_task(
|
||||
item.processor_type,
|
||||
{"task_id": item.task_id, "path": item.path},
|
||||
)
|
||||
item.next_run = item.cron.get_next(datetime)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
task_scheduler = AutomationTaskScheduler()
|
||||
@@ -3,17 +3,17 @@ from typing import Annotated, Any, Dict, Optional
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.config.service import ConfigService
|
||||
from domain.tasks.types import (
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.config import ConfigService
|
||||
from .scheduler import task_scheduler
|
||||
from .task_queue import task_queue_service
|
||||
from .types import (
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
TaskQueueSettingsResponse,
|
||||
)
|
||||
from models.database import AutomationTask
|
||||
from domain.tasks.task_queue import task_queue_service
|
||||
|
||||
|
||||
class TaskService:
|
||||
@@ -47,6 +47,7 @@ class TaskService:
|
||||
@classmethod
|
||||
async def create_task(cls, payload: AutomationTaskCreate, user: Optional[User]) -> AutomationTask:
|
||||
task = await AutomationTask.create(**payload.model_dump())
|
||||
task_scheduler.refresh()
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
@@ -70,6 +71,7 @@ class TaskService:
|
||||
for key, value in update_data.items():
|
||||
setattr(task, key, value)
|
||||
await task.save()
|
||||
task_scheduler.refresh()
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
@@ -77,6 +79,7 @@ class TaskService:
|
||||
deleted_count = await AutomationTask.filter(id=task_id).delete()
|
||||
if not deleted_count:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
task_scheduler.refresh()
|
||||
|
||||
@classmethod
|
||||
async def trigger_tasks(cls, event: str, path: str):
|
||||
@@ -87,11 +90,16 @@ class TaskService:
|
||||
|
||||
@classmethod
|
||||
def match(cls, task: AutomationTask, path: str) -> bool:
|
||||
if task.path_pattern and not path.startswith(task.path_pattern):
|
||||
trigger_config = task.trigger_config or {}
|
||||
if not isinstance(trigger_config, dict):
|
||||
trigger_config = {}
|
||||
path_prefix = trigger_config.get("path_prefix")
|
||||
filename_regex = trigger_config.get("filename_regex")
|
||||
if path_prefix and not path.startswith(path_prefix):
|
||||
return False
|
||||
if task.filename_regex:
|
||||
if filename_regex:
|
||||
filename = path.split("/")[-1]
|
||||
if not re.match(task.filename_regex, filename):
|
||||
if not re.match(filename_regex, filename):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class TaskQueueService:
|
||||
|
||||
try:
|
||||
# Local import to avoid circular dependency during module load.
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
|
||||
if task.name == "process_file":
|
||||
params = task.task_info
|
||||
@@ -86,37 +86,38 @@ class TaskQueueService:
|
||||
overwrite=params.get("overwrite", False),
|
||||
)
|
||||
task.result = result
|
||||
elif task.name == "process_directory_scan":
|
||||
from domain.processors import ProcessDirectoryRequest, ProcessorService
|
||||
|
||||
params = task.task_info or {}
|
||||
req = ProcessDirectoryRequest(**params)
|
||||
task.result = await ProcessorService.scan_directory(req)
|
||||
elif task.name == "automation_task" or self._is_processor_task(task.name):
|
||||
from models.database import AutomationTask
|
||||
from domain.processors.service import get_processor
|
||||
|
||||
params = task.task_info
|
||||
auto_task = await AutomationTask.get(id=params["task_id"])
|
||||
path = params["path"]
|
||||
|
||||
processor_type = auto_task.processor_type if task.name == "automation_task" else task.name
|
||||
processor = get_processor(processor_type)
|
||||
if not processor:
|
||||
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
|
||||
|
||||
if processor_type != auto_task.processor_type:
|
||||
processor_type = auto_task.processor_type
|
||||
processor = get_processor(processor_type)
|
||||
if not processor:
|
||||
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
|
||||
|
||||
requires_input_bytes = bool(getattr(processor, "requires_input_bytes", True))
|
||||
file_content = b""
|
||||
if requires_input_bytes:
|
||||
file_content = await VirtualFSService.read_file(path)
|
||||
result = await processor.process(file_content, path, auto_task.processor_config)
|
||||
|
||||
save_to = auto_task.processor_config.get("save_to")
|
||||
if save_to and getattr(processor, "produces_file", False):
|
||||
await VirtualFSService.write_file(save_to, result)
|
||||
processor_type = auto_task.processor_type
|
||||
config = auto_task.processor_config or {}
|
||||
save_to = config.get("save_to") if isinstance(config, dict) else None
|
||||
overwrite = bool(config.get("overwrite")) if isinstance(config, dict) else False
|
||||
try:
|
||||
if await VirtualFSService.path_is_directory(path):
|
||||
overwrite = True
|
||||
except Exception:
|
||||
pass
|
||||
await VirtualFSService.process_file(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config if isinstance(config, dict) else {},
|
||||
save_to=save_to,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
task.result = "Automation task completed"
|
||||
elif task.name == "offline_http_download":
|
||||
from domain.offline_downloads.service import OfflineDownloadService
|
||||
from domain.offline_downloads import OfflineDownloadService
|
||||
|
||||
result_path = await OfflineDownloadService.run_http_download(task)
|
||||
task.result = {"path": result_path}
|
||||
@@ -124,12 +125,11 @@ class TaskQueueService:
|
||||
result = await VirtualFSService.run_cross_mount_transfer_task(task)
|
||||
task.result = result
|
||||
elif task.name == "send_email":
|
||||
from domain.email.service import EmailService
|
||||
from domain.email import EmailService
|
||||
await EmailService.send_from_task(task.id, task.task_info)
|
||||
task.result = "Email sent"
|
||||
else:
|
||||
raise ValueError(f"Unknown task name: {task.name}")
|
||||
|
||||
task.status = TaskStatus.SUCCESS
|
||||
|
||||
except Exception as e:
|
||||
@@ -141,7 +141,7 @@ class TaskQueueService:
|
||||
|
||||
def _is_processor_task(self, task_name: str) -> bool:
|
||||
try:
|
||||
from domain.processors.service import get_processor
|
||||
from domain.processors import get_processor
|
||||
|
||||
return get_processor(task_name) is not None
|
||||
except Exception:
|
||||
@@ -180,7 +180,7 @@ class TaskQueueService:
|
||||
|
||||
async def start_worker(self, concurrency: int | None = None):
|
||||
if concurrency is None:
|
||||
from domain.config.service import ConfigService
|
||||
from domain.config import ConfigService
|
||||
|
||||
stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
|
||||
try:
|
||||
|
||||
@@ -6,8 +6,7 @@ from pydantic import BaseModel, Field
|
||||
class AutomationTaskBase(BaseModel):
|
||||
name: str
|
||||
event: str
|
||||
path_pattern: Optional[str] = None
|
||||
filename_regex: Optional[str] = None
|
||||
trigger_config: Dict[str, Any] = {}
|
||||
processor_type: str
|
||||
processor_config: Dict[str, Any] = {}
|
||||
enabled: bool = True
|
||||
@@ -22,6 +21,7 @@ class AutomationTaskUpdate(AutomationTaskBase):
|
||||
event: Optional[str] = None
|
||||
processor_type: Optional[str] = None
|
||||
processor_config: Optional[Dict[str, Any]] = None
|
||||
trigger_config: Optional[Dict[str, Any]] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
|
||||
4
domain/user/__init__.py
Normal file
4
domain/user/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .service import UserService
|
||||
|
||||
__all__ = ["UserService"]
|
||||
|
||||
79
domain/user/api.py
Normal file
79
domain/user/api.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
|
||||
from .service import UserService
|
||||
from .types import UserCreate, UserDetail, UserInfo, UserRoleAssign, UserUpdate
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["user"])
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserInfo])
|
||||
@require_system_permission(SystemPermission.USER_LIST)
|
||||
async def list_users(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> list[UserInfo]:
|
||||
return await UserService.get_all_users()
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserDetail)
|
||||
@require_system_permission(SystemPermission.USER_LIST)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> UserDetail:
|
||||
return await UserService.get_user(user_id)
|
||||
|
||||
|
||||
@router.post("/users", response_model=UserDetail)
|
||||
@require_system_permission(SystemPermission.USER_CREATE)
|
||||
async def create_user(
|
||||
data: UserCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> UserDetail:
|
||||
return await UserService.create_user(data, current_user.id)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", response_model=UserDetail)
|
||||
@require_system_permission(SystemPermission.USER_EDIT)
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
data: UserUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> UserDetail:
|
||||
return await UserService.update_user(user_id, data, current_user.id)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
@require_system_permission(SystemPermission.USER_DELETE)
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> dict:
|
||||
await UserService.delete_user(user_id, current_user.id)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/roles", response_model=list[str])
|
||||
@require_system_permission(SystemPermission.USER_EDIT)
|
||||
async def set_user_roles(
|
||||
user_id: int,
|
||||
data: UserRoleAssign,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[str]:
|
||||
return await UserService.set_user_roles(user_id, data.role_ids)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}/roles/{role_id}", response_model=list[str])
|
||||
@require_system_permission(SystemPermission.USER_EDIT)
|
||||
async def remove_user_role(
|
||||
user_id: int,
|
||||
role_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[str]:
|
||||
return await UserService.remove_user_role(user_id, role_id)
|
||||
190
domain/user/service.py
Normal file
190
domain/user/service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.auth.service import AuthService
|
||||
from domain.permission.service import PermissionService
|
||||
from models.database import Role, UserAccount, UserRole
|
||||
|
||||
from .types import UserCreate, UserDetail, UserInfo, UserUpdate
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户管理服务"""
|
||||
|
||||
@classmethod
|
||||
async def get_all_users(cls) -> List[UserInfo]:
|
||||
users = await UserAccount.all().order_by("id")
|
||||
return [
|
||||
UserInfo(
|
||||
id=u.id,
|
||||
username=u.username,
|
||||
email=u.email,
|
||||
full_name=u.full_name,
|
||||
disabled=u.disabled,
|
||||
is_admin=u.is_admin,
|
||||
created_at=u.created_at,
|
||||
last_login=u.last_login,
|
||||
)
|
||||
for u in users
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def get_user(cls, user_id: int) -> UserDetail:
|
||||
user = await UserAccount.get_or_none(id=user_id).prefetch_related("created_by")
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
|
||||
roles = [ur.role.name for ur in user_roles]
|
||||
|
||||
created_by_username = None
|
||||
if user.created_by_id:
|
||||
creator = await UserAccount.get_or_none(id=user.created_by_id)
|
||||
if creator:
|
||||
created_by_username = creator.username
|
||||
|
||||
return UserDetail(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
disabled=user.disabled,
|
||||
is_admin=user.is_admin,
|
||||
created_at=user.created_at,
|
||||
last_login=user.last_login,
|
||||
roles=roles,
|
||||
created_by_username=created_by_username,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_users_by_role(cls, role_id: int) -> List[UserInfo]:
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if not role:
|
||||
raise HTTPException(404, detail="角色不存在")
|
||||
|
||||
user_roles = await UserRole.filter(role_id=role_id).prefetch_related("user")
|
||||
users = [ur.user for ur in user_roles if ur.user]
|
||||
users.sort(key=lambda u: u.id)
|
||||
return [
|
||||
UserInfo(
|
||||
id=u.id,
|
||||
username=u.username,
|
||||
email=u.email,
|
||||
full_name=u.full_name,
|
||||
disabled=u.disabled,
|
||||
is_admin=u.is_admin,
|
||||
created_at=u.created_at,
|
||||
last_login=u.last_login,
|
||||
)
|
||||
for u in users
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def create_user(cls, data: UserCreate, creator_id: int) -> UserDetail:
|
||||
existing = await UserAccount.get_or_none(username=data.username)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="用户名已存在")
|
||||
|
||||
if data.email:
|
||||
existing_email = await UserAccount.get_or_none(email=data.email)
|
||||
if existing_email:
|
||||
raise HTTPException(400, detail="邮箱已被使用")
|
||||
|
||||
hashed_password = AuthService.get_password_hash(data.password)
|
||||
user = await UserAccount.create(
|
||||
username=data.username,
|
||||
email=data.email,
|
||||
full_name=data.full_name,
|
||||
hashed_password=hashed_password,
|
||||
disabled=data.disabled,
|
||||
is_admin=data.is_admin,
|
||||
created_by_id=creator_id,
|
||||
)
|
||||
|
||||
if data.role_ids:
|
||||
for role_id in data.role_ids:
|
||||
role = await Role.get_or_none(id=role_id)
|
||||
if role:
|
||||
await UserRole.create(user_id=user.id, role_id=role_id)
|
||||
|
||||
return await cls.get_user(user.id)
|
||||
|
||||
@classmethod
|
||||
async def update_user(cls, user_id: int, data: UserUpdate, operator_id: int) -> UserDetail:
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
if data.is_admin is not None and user_id == operator_id:
|
||||
raise HTTPException(400, detail="不能修改自己的管理员状态")
|
||||
|
||||
if data.email is not None:
|
||||
existing = await UserAccount.filter(email=data.email).exclude(id=user_id).first()
|
||||
if existing:
|
||||
raise HTTPException(400, detail="邮箱已被使用")
|
||||
user.email = data.email
|
||||
|
||||
if data.full_name is not None:
|
||||
user.full_name = data.full_name
|
||||
|
||||
if data.password is not None:
|
||||
user.hashed_password = AuthService.get_password_hash(data.password)
|
||||
|
||||
if data.is_admin is not None:
|
||||
user.is_admin = data.is_admin
|
||||
|
||||
if data.disabled is not None:
|
||||
if user_id == operator_id and data.disabled:
|
||||
raise HTTPException(400, detail="不能禁用自己")
|
||||
user.disabled = data.disabled
|
||||
|
||||
await user.save()
|
||||
|
||||
PermissionService.clear_cache(user_id)
|
||||
return await cls.get_user(user_id)
|
||||
|
||||
@classmethod
|
||||
async def delete_user(cls, user_id: int, operator_id: int) -> None:
|
||||
if user_id == operator_id:
|
||||
raise HTTPException(400, detail="不能删除自己")
|
||||
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
await UserRole.filter(user_id=user_id).delete()
|
||||
await user.delete()
|
||||
PermissionService.clear_cache(user_id)
|
||||
|
||||
@classmethod
|
||||
async def set_user_roles(cls, user_id: int, role_ids: List[int]) -> List[str]:
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
roles = await Role.filter(id__in=role_ids)
|
||||
valid_role_ids = {r.id for r in roles}
|
||||
invalid_ids = set(role_ids) - valid_role_ids
|
||||
if invalid_ids:
|
||||
raise HTTPException(400, detail=f"无效的角色ID: {invalid_ids}")
|
||||
|
||||
await UserRole.filter(user_id=user_id).delete()
|
||||
for role_id in role_ids:
|
||||
await UserRole.create(user_id=user_id, role_id=role_id)
|
||||
|
||||
PermissionService.clear_cache(user_id)
|
||||
return [r.name for r in roles if r.id in role_ids]
|
||||
|
||||
@classmethod
|
||||
async def remove_user_role(cls, user_id: int, role_id: int) -> List[str]:
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
await UserRole.filter(user_id=user_id, role_id=role_id).delete()
|
||||
PermissionService.clear_cache(user_id)
|
||||
|
||||
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
|
||||
return [ur.role.name for ur in user_roles]
|
||||
|
||||
42
domain/user/types.py
Normal file
42
domain/user/types.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool
|
||||
is_admin: bool
|
||||
created_at: datetime
|
||||
last_login: datetime | None = None
|
||||
|
||||
|
||||
class UserDetail(UserInfo):
|
||||
roles: list[str]
|
||||
created_by_username: str | None = None
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
is_admin: bool = False
|
||||
disabled: bool = False
|
||||
role_ids: list[int] = []
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
password: str | None = None
|
||||
is_admin: bool | None = None
|
||||
disabled: bool | None = None
|
||||
|
||||
|
||||
class UserRoleAssign(BaseModel):
|
||||
role_ids: list[int]
|
||||
|
||||
11
domain/virtual_fs/__init__.py
Normal file
11
domain/virtual_fs/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .service import VirtualFSService
|
||||
from .types import DirListing, MkdirRequest, MoveRequest, SearchResultItem, VfsEntry
|
||||
|
||||
__all__ = [
|
||||
"VirtualFSService",
|
||||
"DirListing",
|
||||
"MkdirRequest",
|
||||
"MoveRequest",
|
||||
"SearchResultItem",
|
||||
"VfsEntry",
|
||||
]
|
||||
@@ -4,16 +4,18 @@ from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.virtual_fs.types import MkdirRequest, MoveRequest
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_path_permission
|
||||
from domain.permission.types import PathAction
|
||||
from .service import VirtualFSService
|
||||
from .types import MkdirRequest, MoveRequest
|
||||
|
||||
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
|
||||
|
||||
|
||||
@router.get("/file/{full_path:path}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
|
||||
@require_path_permission(PathAction.READ, "full_path")
|
||||
async def get_file(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
@@ -45,6 +47,7 @@ async def stream_endpoint(
|
||||
|
||||
@router.get("/temp-link/{full_path:path}")
|
||||
@audit(action=AuditAction.SHARE, description="创建临时链接")
|
||||
@require_path_permission(PathAction.READ, "full_path")
|
||||
async def get_temp_link(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
@@ -64,8 +67,19 @@ async def access_public_file(
|
||||
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/public/{token}/{filename}")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
|
||||
async def access_public_file_with_name(
|
||||
token: str,
|
||||
filename: str,
|
||||
request: Request,
|
||||
):
|
||||
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
|
||||
|
||||
|
||||
@router.get("/stat/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="查看文件信息")
|
||||
@require_path_permission(PathAction.READ, "full_path")
|
||||
async def get_file_stat(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
@@ -77,6 +91,7 @@ async def get_file_stat(
|
||||
|
||||
@router.post("/file/{full_path:path}")
|
||||
@audit(action=AuditAction.UPLOAD, description="上传文件")
|
||||
@require_path_permission(PathAction.WRITE, "full_path")
|
||||
async def put_file(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -90,6 +105,7 @@ async def put_file(
|
||||
|
||||
@router.post("/mkdir")
|
||||
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
|
||||
@require_path_permission(PathAction.WRITE, "body.path")
|
||||
async def api_mkdir(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -101,6 +117,8 @@ async def api_mkdir(
|
||||
|
||||
@router.post("/move")
|
||||
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
|
||||
@require_path_permission(PathAction.WRITE, "body.dst")
|
||||
@require_path_permission(PathAction.DELETE, "body.src")
|
||||
async def api_move(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -113,6 +131,7 @@ async def api_move(
|
||||
|
||||
@router.post("/rename")
|
||||
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
|
||||
@require_path_permission(PathAction.WRITE, "body.src")
|
||||
async def api_rename(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -125,6 +144,8 @@ async def api_rename(
|
||||
|
||||
@router.post("/copy")
|
||||
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
|
||||
@require_path_permission(PathAction.WRITE, "body.dst")
|
||||
@require_path_permission(PathAction.READ, "body.src")
|
||||
async def api_copy(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -137,6 +158,7 @@ async def api_copy(
|
||||
|
||||
@router.post("/upload/{full_path:path}")
|
||||
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
|
||||
@require_path_permission(PathAction.WRITE, "full_path")
|
||||
async def upload_stream(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -151,6 +173,7 @@ async def upload_stream(
|
||||
|
||||
@router.get("/{full_path:path}")
|
||||
@audit(action=AuditAction.READ, description="浏览目录")
|
||||
@require_path_permission(PathAction.READ, "full_path")
|
||||
async def browse_fs(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -160,12 +183,15 @@ async def browse_fs(
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
|
||||
):
|
||||
data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
|
||||
data = await VirtualFSService.list_directory_with_permission(
|
||||
full_path, current_user.id, page_num, page_size, sort_by, sort_order
|
||||
)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.delete("/{full_path:path}")
|
||||
@audit(action=AuditAction.DELETE, description="删除路径")
|
||||
@require_path_permission(PathAction.DELETE, "full_path")
|
||||
async def api_delete(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -185,5 +211,8 @@ async def root_listing(
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
|
||||
):
|
||||
data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
|
||||
# 根目录不需要权限检查,但需要过滤无权限的子目录
|
||||
data = await VirtualFSService.list_directory_with_permission(
|
||||
"/", current_user.id, page_num, page_size, sort_by, sort_order
|
||||
)
|
||||
return success(data)
|
||||
|
||||
@@ -4,13 +4,36 @@ from typing import Any, AsyncIterator, Union
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from domain.tasks.service import TaskService
|
||||
from domain.virtual_fs.thumbnail import is_raw_filename, raw_bytes_to_jpeg
|
||||
from domain.tasks import TaskService
|
||||
from .thumbnail import is_raw_filename, raw_bytes_to_jpeg
|
||||
|
||||
from .listing import VirtualFSListingMixin
|
||||
|
||||
|
||||
class VirtualFSFileOpsMixin(VirtualFSListingMixin):
|
||||
@classmethod
|
||||
def _normalize_written_result(
|
||||
cls,
|
||||
original_path: str,
|
||||
adapter_model: Any,
|
||||
result: Any,
|
||||
size_hint: int,
|
||||
) -> tuple[str, int]:
|
||||
final_path = original_path
|
||||
size = size_hint
|
||||
if isinstance(result, dict):
|
||||
rel_override = result.get("rel")
|
||||
if isinstance(rel_override, str) and rel_override:
|
||||
final_path = cls._build_absolute_path(adapter_model.path, rel_override)
|
||||
else:
|
||||
path_override = result.get("path")
|
||||
if isinstance(path_override, str) and path_override:
|
||||
final_path = cls._normalize_path(path_override)
|
||||
size_val = result.get("size")
|
||||
if isinstance(size_val, int):
|
||||
size = size_val
|
||||
return final_path, size
|
||||
|
||||
@classmethod
|
||||
async def read_file(cls, path: str) -> Union[bytes, Any]:
|
||||
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
|
||||
@@ -21,16 +44,18 @@ class VirtualFSFileOpsMixin(VirtualFSListingMixin):
|
||||
|
||||
@classmethod
|
||||
async def write_file(cls, path: str, data: bytes):
|
||||
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
|
||||
adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(path)
|
||||
if rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
write_func = await cls._ensure_method(adapter_instance, "write_file")
|
||||
await write_func(root, rel, data)
|
||||
await TaskService.trigger_tasks("file_written", path)
|
||||
result = await write_func(root, rel, data)
|
||||
final_path, size = cls._normalize_written_result(path, adapter_model, result, len(data))
|
||||
await TaskService.trigger_tasks("file_written", final_path)
|
||||
return {"path": final_path, "size": size}
|
||||
|
||||
@classmethod
|
||||
async def write_file_stream(cls, path: str, data_iter: AsyncIterator[bytes], overwrite: bool = True):
|
||||
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
|
||||
adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(path)
|
||||
if rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
exists_func = getattr(adapter_instance, "exists", None)
|
||||
@@ -46,18 +71,23 @@ class VirtualFSFileOpsMixin(VirtualFSListingMixin):
|
||||
size = 0
|
||||
stream_func = getattr(adapter_instance, "write_file_stream", None)
|
||||
if callable(stream_func):
|
||||
size = await stream_func(root, rel, data_iter)
|
||||
result = await stream_func(root, rel, data_iter)
|
||||
if isinstance(result, dict):
|
||||
size = int(result.get("size") or 0)
|
||||
else:
|
||||
size = int(result or 0)
|
||||
else:
|
||||
buf = bytearray()
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
buf.extend(chunk)
|
||||
write_func = await cls._ensure_method(adapter_instance, "write_file")
|
||||
await write_func(root, rel, bytes(buf))
|
||||
result = await write_func(root, rel, bytes(buf))
|
||||
size = len(buf)
|
||||
|
||||
await TaskService.trigger_tasks("file_written", path)
|
||||
return size
|
||||
final_path, size = cls._normalize_written_result(path, adapter_model, result, size)
|
||||
await TaskService.trigger_tasks("file_written", final_path)
|
||||
return {"path": final_path, "size": size}
|
||||
|
||||
@classmethod
|
||||
async def make_dir(cls, path: str):
|
||||
|
||||
@@ -3,9 +3,11 @@ from typing import Any, Dict, List, Tuple
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.response import page
|
||||
from domain.adapters.registry import runtime_registry
|
||||
from domain.ai.service import VectorDBService, VECTOR_COLLECTION_NAME, FILE_COLLECTION_NAME
|
||||
from domain.virtual_fs.thumbnail import is_image_filename, is_video_filename
|
||||
from domain.adapters import runtime_registry
|
||||
from domain.ai import FILE_COLLECTION_NAME, VECTOR_COLLECTION_NAME, VectorDBService
|
||||
from domain.permission.service import PermissionService
|
||||
from domain.permission.types import PathAction
|
||||
from .thumbnail import is_image_filename, is_video_filename
|
||||
from models import StorageAdapter
|
||||
|
||||
from .resolver import VirtualFSResolverMixin
|
||||
@@ -225,7 +227,10 @@ class VirtualFSListingMixin(VirtualFSResolverMixin):
|
||||
stat_func = getattr(adapter_instance, "stat_file", None)
|
||||
if not callable(stat_func):
|
||||
raise HTTPException(501, detail="Adapter does not implement stat_file")
|
||||
info = await stat_func(root, rel)
|
||||
try:
|
||||
info = await stat_func(root, rel)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(404, detail=str(exc))
|
||||
|
||||
if isinstance(info, dict):
|
||||
info.setdefault("path", path)
|
||||
@@ -242,3 +247,54 @@ class VirtualFSListingMixin(VirtualFSResolverMixin):
|
||||
info["vector_index"] = vector_index
|
||||
|
||||
return info
|
||||
|
||||
@classmethod
|
||||
async def list_virtual_dir_with_permission(
|
||||
cls,
|
||||
path: str,
|
||||
user_id: int,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Dict:
|
||||
"""
|
||||
带权限过滤的目录列表
|
||||
|
||||
过滤掉用户没有读取权限的条目
|
||||
"""
|
||||
# 首先获取完整的目录列表
|
||||
result = await cls.list_virtual_dir(path, page_num, page_size, sort_by, sort_order)
|
||||
|
||||
# 检查用户是否是管理员(管理员可以看到所有内容)
|
||||
from models.database import UserAccount
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if user and user.is_admin:
|
||||
return result
|
||||
|
||||
# 过滤无权限的条目
|
||||
items = result.get("items", [])
|
||||
if not items:
|
||||
return result
|
||||
|
||||
norm = cls._normalize_path(path).rstrip("/") or "/"
|
||||
filtered_items = []
|
||||
|
||||
for item in items:
|
||||
item_name = item.get("name", "")
|
||||
if norm == "/":
|
||||
item_path = f"/{item_name}"
|
||||
else:
|
||||
item_path = f"{norm}/{item_name}"
|
||||
|
||||
# 检查用户是否有读取权限
|
||||
has_permission = await PermissionService.check_path_permission(
|
||||
user_id, item_path, PathAction.READ
|
||||
)
|
||||
if has_permission:
|
||||
filtered_items.append(item)
|
||||
|
||||
# 更新结果
|
||||
result["items"] = filtered_items
|
||||
|
||||
return result
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -15,8 +15,8 @@ from fastapi import APIRouter, Request, Response
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.config.service import ConfigService
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.config import ConfigService
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/s3", tags=["s3"])
|
||||
|
||||
@@ -9,10 +9,11 @@ from fastapi import APIRouter, Request, Response, HTTPException, Depends
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth.service import AuthService
|
||||
from domain.auth.types import User, UserInDB
|
||||
from domain.virtual_fs.service import VirtualFSService
|
||||
from domain.config.service import ConfigService
|
||||
from domain.auth import AuthService, User, UserInDB
|
||||
from domain.config import ConfigService
|
||||
from domain.permission.service import PermissionService
|
||||
from domain.permission.types import PathAction
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
|
||||
|
||||
_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
|
||||
@@ -66,11 +67,26 @@ async def _get_basic_user(request: Request) -> User:
|
||||
if not user_or_false:
|
||||
raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
u: UserInDB = user_or_false
|
||||
return User(id=u.id, username=u.username, email=u.email, full_name=u.full_name, disabled=u.disabled)
|
||||
return User(
|
||||
id=u.id,
|
||||
username=u.username,
|
||||
email=u.email,
|
||||
full_name=u.full_name,
|
||||
disabled=u.disabled,
|
||||
is_admin=u.is_admin,
|
||||
)
|
||||
elif scheme_lower == "bearer":
|
||||
if not param:
|
||||
raise HTTPException(401, detail="Invalid Bearer token")
|
||||
return User(id=0, username="bearer", email=None, full_name=None, disabled=False)
|
||||
u = await AuthService.get_current_user(param)
|
||||
return User(
|
||||
id=u.id,
|
||||
username=u.username,
|
||||
email=u.email,
|
||||
full_name=u.full_name,
|
||||
disabled=u.disabled,
|
||||
is_admin=u.is_admin,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(401, detail="Unsupported auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
|
||||
|
||||
@@ -156,6 +172,8 @@ async def propfind(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
if full_path != "/":
|
||||
await PermissionService.require_path_permission(user.id, full_path, PathAction.READ)
|
||||
depth = request.headers.get("Depth", "1").lower()
|
||||
if depth not in ("0", "1", "infinity"):
|
||||
depth = "1"
|
||||
@@ -172,12 +190,34 @@ async def propfind(
|
||||
ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
|
||||
responses.append(_build_prop_response(full_path, name, is_dir, size, mtime, ctype))
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
st = None
|
||||
except HTTPException as e:
|
||||
if e.status_code != 404:
|
||||
raise
|
||||
st = None
|
||||
|
||||
if st is None:
|
||||
is_mount_root = False
|
||||
try:
|
||||
_, rel = await VirtualFSService.resolve_adapter_by_path(full_path)
|
||||
is_mount_root = rel == ""
|
||||
except HTTPException:
|
||||
is_mount_root = False
|
||||
|
||||
if not is_mount_root and full_path != "/":
|
||||
listing_probe = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1)
|
||||
if not (listing_probe.get("items") or []):
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
name = "/" if full_path == "/" else (full_path.rstrip("/").rsplit("/", 1)[-1] or "/")
|
||||
responses.append(_build_prop_response(full_path, name, True, None, 0, None))
|
||||
|
||||
if depth in ("1", "infinity"):
|
||||
try:
|
||||
listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
|
||||
for ent in listing["items"]:
|
||||
listing = await VirtualFSService.list_virtual_dir_with_permission(
|
||||
full_path, user.id, page_num=1, page_size=1000
|
||||
)
|
||||
for ent in (listing.get("items") or []):
|
||||
is_dir = bool(ent.get("is_dir"))
|
||||
name = ent.get("name")
|
||||
child_path = full_path.rstrip("/") + "/" + name
|
||||
@@ -204,6 +244,8 @@ async def dav_get(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
if full_path != "/":
|
||||
await PermissionService.require_path_permission(user.id, full_path, PathAction.READ)
|
||||
range_header = request.headers.get("Range")
|
||||
return await VirtualFSService.stream_file(full_path, range_header)
|
||||
|
||||
@@ -217,6 +259,8 @@ async def dav_head(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
if full_path != "/":
|
||||
await PermissionService.require_path_permission(user.id, full_path, PathAction.READ)
|
||||
try:
|
||||
st = await VirtualFSService.stat_file(full_path)
|
||||
except FileNotFoundError:
|
||||
@@ -245,6 +289,7 @@ async def dav_put(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await PermissionService.require_path_permission(user.id, full_path, PathAction.WRITE)
|
||||
async def body_iter():
|
||||
async for chunk in request.stream():
|
||||
if chunk:
|
||||
@@ -262,6 +307,7 @@ async def dav_delete(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await PermissionService.require_path_permission(user.id, full_path, PathAction.DELETE)
|
||||
await VirtualFSService.delete_path(full_path)
|
||||
return Response(status_code=204, headers=_dav_headers())
|
||||
|
||||
@@ -275,6 +321,7 @@ async def dav_mkcol(
|
||||
user: User = Depends(_get_basic_user),
|
||||
):
|
||||
full_path = _normalize_fs_path(path)
|
||||
await PermissionService.require_path_permission(user.id, full_path, PathAction.WRITE)
|
||||
await VirtualFSService.make_dir(full_path)
|
||||
return Response(status_code=201, headers=_dav_headers())
|
||||
|
||||
@@ -303,6 +350,8 @@ async def dav_move(
|
||||
dest_header = request.headers.get("Destination")
|
||||
dst = _parse_destination(dest_header or "")
|
||||
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
|
||||
await PermissionService.require_path_permission(user.id, full_src, PathAction.DELETE)
|
||||
await PermissionService.require_path_permission(user.id, dst, PathAction.WRITE)
|
||||
await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
|
||||
return Response(status_code=204, headers=_dav_headers())
|
||||
|
||||
@@ -319,5 +368,7 @@ async def dav_copy(
|
||||
dest_header = request.headers.get("Destination")
|
||||
dst = _parse_destination(dest_header or "")
|
||||
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
|
||||
await PermissionService.require_path_permission(user.id, full_src, PathAction.READ)
|
||||
await PermissionService.require_path_permission(user.id, dst, PathAction.WRITE)
|
||||
await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
|
||||
return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())
|
||||
|
||||
@@ -16,7 +16,7 @@ class VirtualFSProcessingMixin(VirtualFSTransferMixin):
|
||||
save_to: str | None = None,
|
||||
overwrite: bool = False,
|
||||
) -> Any:
|
||||
from domain.processors.service import get_processor
|
||||
from domain.processors import get_processor
|
||||
|
||||
processor = get_processor(processor_type)
|
||||
if not processor:
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Tuple
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from domain.adapters.registry import runtime_registry
|
||||
from domain.adapters import runtime_registry
|
||||
from models import StorageAdapter
|
||||
|
||||
from .common import VirtualFSCommonMixin
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import mimetypes
|
||||
import re
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.responses import Response
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from domain.virtual_fs.thumbnail import (
|
||||
from domain.config import ConfigService
|
||||
from domain.tasks import TaskService
|
||||
from .thumbnail import (
|
||||
get_or_create_thumb,
|
||||
is_image_filename,
|
||||
is_raw_filename,
|
||||
@@ -112,12 +114,14 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
|
||||
async def create_temp_link(cls, full_path: str, expires_in: int):
|
||||
full_path = cls._normalize_path(full_path)
|
||||
token = await cls.generate_temp_link_token(full_path, expires_in=expires_in)
|
||||
filename = full_path.rstrip("/").split("/")[-1]
|
||||
filename_part = f"/{quote(filename, safe='')}" if filename else ""
|
||||
file_domain = await ConfigService.get("FILE_DOMAIN")
|
||||
if file_domain:
|
||||
file_domain = file_domain.rstrip("/")
|
||||
url = f"{file_domain}/api/fs/public/{token}"
|
||||
url = f"{file_domain}/api/fs/public/{token}{filename_part}"
|
||||
else:
|
||||
url = f"/api/fs/public/{token}"
|
||||
url = f"/api/fs/public/{token}{filename_part}"
|
||||
return {"token": token, "path": full_path, "url": url}
|
||||
|
||||
@classmethod
|
||||
@@ -128,12 +132,17 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
|
||||
raise exc
|
||||
|
||||
try:
|
||||
return await cls.stream_file(path, range_header)
|
||||
response = await cls.stream_file(path, range_header)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found via token")
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"File access error: {exc}")
|
||||
|
||||
filename = path.rstrip("/").split("/")[-1]
|
||||
if filename and not response.headers.get("Content-Disposition"):
|
||||
response.headers["Content-Disposition"] = f"inline; filename*=UTF-8''{quote(filename, safe='')}"
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
async def stat(cls, full_path: str):
|
||||
full_path = cls._normalize_path(full_path)
|
||||
@@ -142,8 +151,15 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
|
||||
@classmethod
|
||||
async def write_uploaded_file(cls, full_path: str, data: bytes):
|
||||
full_path = cls._normalize_path(full_path)
|
||||
await cls.write_file(full_path, data)
|
||||
return {"written": True, "path": full_path, "size": len(data)}
|
||||
result = await cls.write_file(full_path, data)
|
||||
path = full_path
|
||||
size = len(data)
|
||||
if isinstance(result, dict):
|
||||
path = result.get("path") or path
|
||||
size_val = result.get("size")
|
||||
if isinstance(size_val, int):
|
||||
size = size_val
|
||||
return {"written": True, "path": path, "size": size}
|
||||
|
||||
@classmethod
|
||||
async def mkdir(cls, path: str):
|
||||
@@ -201,7 +217,7 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
|
||||
full_path = cls._normalize_path(full_path)
|
||||
if full_path.endswith("/"):
|
||||
raise HTTPException(400, detail="Path must be a file")
|
||||
adapter, _m, root, rel = await cls.resolve_adapter_and_rel(full_path)
|
||||
adapter, adapter_model, root, rel = await cls.resolve_adapter_and_rel(full_path)
|
||||
exists_func = getattr(adapter, "exists", None)
|
||||
if not overwrite and callable(exists_func):
|
||||
try:
|
||||
@@ -212,6 +228,21 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
upload_func = getattr(adapter, "write_upload_file", None)
|
||||
if callable(upload_func):
|
||||
try:
|
||||
await file.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
size_hint = getattr(file, "size", None)
|
||||
if not isinstance(size_hint, int):
|
||||
size_hint = None
|
||||
filename = file.filename or (rel.rsplit("/", 1)[-1] if rel else "file")
|
||||
result = await upload_func(root, rel, file.file, filename, size_hint, file.content_type)
|
||||
final_path, size = cls._normalize_written_result(full_path, adapter_model, result, size_hint or 0)
|
||||
await TaskService.trigger_tasks("file_written", final_path)
|
||||
return {"uploaded": True, "path": final_path, "size": size, "overwrite": overwrite}
|
||||
|
||||
async def gen():
|
||||
while True:
|
||||
chunk = await file.read(chunk_size)
|
||||
@@ -219,8 +250,17 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
|
||||
break
|
||||
yield chunk
|
||||
|
||||
size = await cls.write_file_stream(full_path, gen(), overwrite=overwrite)
|
||||
return {"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite}
|
||||
result = await cls.write_file_stream(full_path, gen(), overwrite=overwrite)
|
||||
path = full_path
|
||||
size = 0
|
||||
if isinstance(result, dict):
|
||||
path = result.get("path") or path
|
||||
size_val = result.get("size")
|
||||
if isinstance(size_val, int):
|
||||
size = size_val
|
||||
else:
|
||||
size = int(result or 0)
|
||||
return {"uploaded": True, "path": path, "size": size, "overwrite": overwrite}
|
||||
|
||||
@classmethod
|
||||
async def list_directory(cls, full_path: str, page_num: int, page_size: int, sort_by: str, sort_order: str):
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .search_service import VirtualFSSearchService
|
||||
|
||||
__all__ = ["VirtualFSSearchService"]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from api.response import success
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from domain.virtual_fs.search.search_service import VirtualFSSearchService
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission.service import PermissionService
|
||||
from domain.permission.types import PathAction
|
||||
from .search_service import VirtualFSSearchService
|
||||
|
||||
router = APIRouter(prefix="/api/fs/search", tags=["search"])
|
||||
|
||||
@@ -25,4 +26,14 @@ async def search_files(
|
||||
page_size = max(min(page_size, 100), 1)
|
||||
|
||||
data = await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
|
||||
items = data.get("items") if isinstance(data, dict) else None
|
||||
if isinstance(items, list) and items:
|
||||
filtered = []
|
||||
for item in items:
|
||||
path = getattr(item, "path", None)
|
||||
if not path:
|
||||
continue
|
||||
if await PermissionService.check_path_permission(user.id, str(path), PathAction.READ):
|
||||
filtered.append(item)
|
||||
data["items"] = filtered
|
||||
return success(data)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from domain.virtual_fs.types import SearchResultItem
|
||||
from domain.ai.inference import get_text_embedding
|
||||
from domain.ai.service import VectorDBService, VECTOR_COLLECTION_NAME, FILE_COLLECTION_NAME
|
||||
from domain.ai import FILE_COLLECTION_NAME, VECTOR_COLLECTION_NAME, VectorDBService, get_text_embedding
|
||||
from ..types import SearchResultItem
|
||||
|
||||
|
||||
def _normalize_result(raw: Dict[str, Any], source: str, fallback_score: float = 0.0) -> SearchResultItem:
|
||||
|
||||
@@ -18,4 +18,40 @@ class VirtualFSService(
|
||||
VirtualFSResolverMixin,
|
||||
VirtualFSCommonMixin,
|
||||
):
|
||||
pass
|
||||
@classmethod
|
||||
async def list_directory(
|
||||
cls,
|
||||
path: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
):
|
||||
"""列出目录内容"""
|
||||
return await cls.list_virtual_dir(path, page_num, page_size, sort_by, sort_order)
|
||||
|
||||
@classmethod
|
||||
async def list_directory_with_permission(
|
||||
cls,
|
||||
path: str,
|
||||
user_id: int,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
):
|
||||
"""列出目录内容(带权限过滤)"""
|
||||
full_path = cls._normalize_path(path).rstrip("/") or "/"
|
||||
result = await cls.list_virtual_dir_with_permission(
|
||||
full_path, user_id, page_num, page_size, sort_by, sort_order
|
||||
)
|
||||
return {
|
||||
"path": full_path,
|
||||
"entries": result.get("items", []) if isinstance(result, dict) else [],
|
||||
"pagination": {
|
||||
"total": result.get("total", 0) if isinstance(result, dict) else 0,
|
||||
"page": result.get("page", page_num) if isinstance(result, dict) else page_num,
|
||||
"page_size": result.get("page_size", page_size) if isinstance(result, dict) else page_size,
|
||||
"pages": result.get("pages", 0) if isinstance(result, dict) else 0,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from domain.config import ConfigService
|
||||
|
||||
from .processing import VirtualFSProcessingMixin
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user