Compare commits


54 Commits
v1.5.5 ... v1

Author SHA1 Message Date
shiyu
984b7a74ae fix: update Divider component in AppSettingsTab to use titlePlacement prop 2026-02-09 13:00:43 +08:00
shiyu
97a3c58f0f feat: update SystemSettingsPage to remove AuthSettingsTab and enhance AppSettingsTab with registration settings 2026-02-09 12:46:46 +08:00
shiyu
451e8555d5 feat: add permission decorator to enhance API access control 2026-02-09 12:32:25 +08:00
shiyu
f444ec46cc chore: remove migrate directory from .gitignore 2026-02-09 11:18:19 +08:00
shiyu
103beb7dad refactor: remove Permission model and update related code to use permission codes 2026-02-09 11:15:01 +08:00
shiyu
c5e4b3ef43 feat: add user and role management pages with authentication settings
- Implemented AuthSettingsTab for managing authentication settings including user registration and default roles.
- Created UsersPage for managing users and roles, including user creation, editing, and deletion functionalities.
- Added components for user and role management: UserEditorDrawer, RoleEditorDrawer, UsersTable, RolesTable, and PathRuleEditorDrawer.
- Introduced QuickCreateRoleModal for quick role creation within user management.
- Implemented permission management within roles, including path rules and user assignments.
- Enhanced user experience with loading states and error handling in API interactions.
2026-02-01 19:25:17 +08:00
dependabot[bot]
4014a4dd74 chore(deps): bump antd from 6.1.3 to 6.2.2 in /web (#102)
Bumps [antd](https://github.com/ant-design/ant-design) from 6.1.3 to 6.2.2.
- [Release notes](https://github.com/ant-design/ant-design/releases)
- [Changelog](https://github.com/ant-design/ant-design/blob/master/CHANGELOG.en-US.md)
- [Commits](https://github.com/ant-design/ant-design/compare/6.1.3...6.2.2)

---
updated-dependencies:
- dependency-name: antd
  dependency-version: 6.2.2
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 12:50:50 +08:00
dependabot[bot]
d0c6e1882f chore(deps): bump react-dom from 19.2.3 to 19.2.4 in /web (#98)
Bumps [react-dom](https://github.com/facebook/react/tree/HEAD/packages/react-dom) from 19.2.3 to 19.2.4.
- [Release notes](https://github.com/facebook/react/releases)
- [Changelog](https://github.com/facebook/react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/facebook/react/commits/v19.2.4/packages/react-dom)

---
updated-dependencies:
- dependency-name: react-dom
  dependency-version: 19.2.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 12:50:39 +08:00
dependabot[bot]
434715fc8b chore(deps): bump pillow from 12.0.0 to 12.1.0 (#99)
Bumps [pillow](https://github.com/python-pillow/Pillow) from 12.0.0 to 12.1.0.
- [Release notes](https://github.com/python-pillow/Pillow/releases)
- [Changelog](https://github.com/python-pillow/Pillow/blob/main/CHANGES.rst)
- [Commits](https://github.com/python-pillow/Pillow/compare/12.0.0...12.1.0)

---
updated-dependencies:
- dependency-name: pillow
  dependency-version: 12.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 11:21:31 +08:00
dependabot[bot]
a127987a3f chore(deps): bump react and @types/react in /web (#100)
Bumps [react](https://github.com/facebook/react/tree/HEAD/packages/react) and [@types/react](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/react). These dependencies needed to be updated together.

Updates `react` from 19.2.3 to 19.2.4
- [Release notes](https://github.com/facebook/react/releases)
- [Changelog](https://github.com/facebook/react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/facebook/react/commits/v19.2.4/packages/react)

Updates `@types/react` from 19.2.7 to 19.2.10
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/react)

---
updated-dependencies:
- dependency-name: react
  dependency-version: 19.2.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
- dependency-name: "@types/react"
  dependency-version: 19.2.10
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 11:21:19 +08:00
dependabot[bot]
edf95e897d chore(deps): bump react-router from 7.11.0 to 7.13.0 in /web (#101)
Bumps [react-router](https://github.com/remix-run/react-router/tree/HEAD/packages/react-router) from 7.11.0 to 7.13.0.
- [Release notes](https://github.com/remix-run/react-router/releases)
- [Changelog](https://github.com/remix-run/react-router/blob/main/packages/react-router/CHANGELOG.md)
- [Commits](https://github.com/remix-run/react-router/commits/react-router@7.13.0/packages/react-router)

---
updated-dependencies:
- dependency-name: react-router
  dependency-version: 7.13.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 11:21:13 +08:00
dependabot[bot]
b72f8152b6 chore(deps-dev): bump typescript-eslint from 8.51.0 to 8.54.0 in /web (#103)
Bumps [typescript-eslint](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/typescript-eslint) from 8.51.0 to 8.54.0.
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/typescript-eslint/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.54.0/packages/typescript-eslint)

---
updated-dependencies:
- dependency-name: typescript-eslint
  dependency-version: 8.54.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 11:20:55 +08:00
dependabot[bot]
aacddb1208 chore(deps): bump pyjwt from 2.10.1 to 2.11.0 (#97)
Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.10.1 to 2.11.0.
- [Release notes](https://github.com/jpadilla/pyjwt/releases)
- [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst)
- [Commits](https://github.com/jpadilla/pyjwt/compare/2.10.1...2.11.0)

---
updated-dependencies:
- dependency-name: pyjwt
  dependency-version: 2.11.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-01 11:20:34 +08:00
shiyu
1d6d793f7a feat: add destroyOnHidden property to SearchDialog modal #96 2026-01-31 23:34:08 +08:00
shiyu
d9d2ddf2d1 feat: update vector DB provider handling and improve setup page configuration 2026-01-31 21:59:18 +08:00
shiyu
e6ab01ef9d feat: add user and role management pages with API integration
- Implemented user management functionality in UsersPage including user creation, editing, deletion, and role assignment.
- Added role management functionality in RolesPage with role creation, editing, deletion, and path rule management.
- Created users API for handling user-related operations.
- Created roles API for handling role-related operations.
- Integrated permissions handling in both user and role management.
- Enhanced UI with Ant Design components for better user experience.
2026-01-30 15:59:22 +08:00
dependabot[bot]
4a2e01196d chore(deps): bump the uv group across 1 directory with 2 updates (#95)
Bumps the uv group with 2 updates in the / directory: [python-multipart](https://github.com/Kludex/python-multipart) and [protobuf](https://github.com/protocolbuffers/protobuf).


Updates `python-multipart` from 0.0.21 to 0.0.22
- [Release notes](https://github.com/Kludex/python-multipart/releases)
- [Changelog](https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Kludex/python-multipart/compare/0.0.21...0.0.22)

Updates `protobuf` from 6.33.2 to 7.34.0rc1
- [Release notes](https://github.com/protocolbuffers/protobuf/releases)
- [Commits](https://github.com/protocolbuffers/protobuf/commits)

---
updated-dependencies:
- dependency-name: python-multipart
  dependency-version: 0.0.22
  dependency-type: direct:production
  dependency-group: uv
- dependency-name: protobuf
  dependency-version: 7.34.0rc1
  dependency-type: indirect
  dependency-group: uv
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-27 10:58:19 +08:00
shiyu
f22ca62902 feat: enhance directory processing with task queuing and input handling 2026-01-20 11:34:09 +08:00
shiyu
a394ffa46b feat: implement double-click navigation and click timer for breadcrumb items 2026-01-20 10:34:14 +08:00
shiyu
d003e53a3a feat: add tools for web fetching 2026-01-20 10:17:39 +08:00
dependabot[bot]
060a427fe4 chore(deps): bump the uv group across 1 directory with 2 updates (#93)
---
updated-dependencies:
- dependency-name: aiohttp
  dependency-version: 3.13.3
  dependency-type: indirect
  dependency-group: uv
- dependency-name: urllib3
  dependency-version: 2.6.3
  dependency-type: indirect
  dependency-group: uv
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-19 19:59:46 +08:00
shiyu
f4c18f991f chore: update version to v1.7.4 2026-01-19 16:50:36 +08:00
shiyu
58c2cdd440 feat: enforce simultaneous username and password requirement for alist and openlist adapters 2026-01-19 15:58:12 +08:00
dependabot[bot]
7d861ca5f7 chore(deps): bump pyasn1 in the uv group across 1 directory (#92)
Bumps the uv group with 1 update in the / directory: [pyasn1](https://github.com/pyasn1/pyasn1).


Updates `pyasn1` from 0.6.1 to 0.6.2
- [Release notes](https://github.com/pyasn1/pyasn1/releases)
- [Changelog](https://github.com/pyasn1/pyasn1/blob/main/CHANGES.rst)
- [Commits](https://github.com/pyasn1/pyasn1/compare/v0.6.1...v0.6.2)

---
updated-dependencies:
- dependency-name: pyasn1
  dependency-version: 0.6.2
  dependency-type: indirect
  dependency-group: uv
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-19 10:06:53 +08:00
shiyu
52bac11760 feat: add create file functionality with modal and context menu integration 2026-01-18 21:31:01 +08:00
shiyu
c441d8776f feat: enhance backup functionality with section selection and import mode options 2026-01-18 21:01:59 +08:00
shiyu
45e0194465 chore: update version to v1.7.3 2026-01-18 18:18:51 +08:00
shiyu
540065f195 feat: implement write_upload_file method for various adapters to handle file uploads 2026-01-18 18:14:04 +08:00
shiyu
4f86e2da4d feat: enhance file upload handling and response normalization in virtual file system 2026-01-18 15:14:25 +08:00
shiyu
31d347d24f feat: add support for filename in public file access and update temp link generation 2026-01-16 20:55:03 +08:00
shiyu
7a9a20509c feat: update system prompt to adjust response language based on user input 2026-01-16 16:29:16 +08:00
shiyu
373b6410c2 feat: add time tool with offset support and update localization for weekday 2026-01-16 15:46:42 +08:00
shiyu
d6eb6e1605 feat: replace Drawer with Modal in AiAgentWidget and enhance styles for better UI 2026-01-16 15:05:53 +08:00
shiyu
1d66fb56c8 feat: update logo.svg 2026-01-16 14:52:53 +08:00
shiyu
bb9589fa62 chore: update version to v1.7.2 in service configuration 2026-01-16 11:20:19 +08:00
shiyu
ab89451b2d feat: implement cron-based automation task scheduling and update task configuration 2026-01-15 15:04:10 +08:00
shiyu
3e1b75d81a feat: add notices feature with modal and API integration 2026-01-14 22:01:29 +08:00
shiyu
1679b03d3a chore: update version to v1.7.1 in service configuration 2026-01-12 10:24:18 +08:00
shiyu
ab6562fc79 feat: add new AI models and improve UI layout in settings 2026-01-11 23:18:54 +08:00
shiyu
87770176b6 feat: expand AI provider support and update descriptions
- Updated AIProviderBase and AIProviderUpdate to support new API formats: 'anthropic' and 'ollama'.
- Added SVG icons for Anthropic, Azure, Ollama, and Z.ai providers.
- Updated AI provider payload interface to include new formats.
- Enhanced English and Chinese localization for new providers and updated descriptions for OpenAI and Anthropic.
- Added new provider templates for Azure OpenAI, Anthropic, Z.ai, and Ollama in the settings tab.
- Updated the API format selection in the settings tab to include new options.
2026-01-11 22:29:22 +08:00
shiyu
e7cf8dbdb8 chore: update version to v1.7.0 in service configuration 2026-01-11 14:09:29 +08:00
shiyu
e7eafdee97 feat: add session locking mechanism in Telegram adapter and improve SPA fallback handling 2026-01-11 14:08:52 +08:00
shiyu
051b49d3f6 feat: improve error handling in propfind function and enhance directory listing logic 2026-01-11 13:32:48 +08:00
shiyu
b059b0eb44 feat: enhance Telegram adapter to support parsing legacy session_string and fetching thumbnails 2026-01-11 11:20:10 +08:00
shiyu
59ad2cb622 feat: update AIProvider structure to include has_api_key and adjust API key handling in settings 2026-01-10 13:22:07 +08:00
shiyu
6b2ada0b42 refactor: reorganize imports and domain structure
- Updated import statements across multiple modules to use relative imports for better encapsulation.
- Consolidated and organized the `__init__.py` files in various domain packages to expose necessary classes and functions.
- Improved code readability and maintainability by grouping related imports and removing unused ones.
- Ensured consistent import patterns across the domain, enhancing the overall structure of the codebase.
2026-01-09 17:28:10 +08:00
时雨
a727e77341 feat: Implement AI Agent with enhanced tool processing capabilities (#89)
* feat: Implement AI Agent with tool processing capabilities

- Added tools for listing and running processors in the agent.
- Created data models for agent chat requests and tool calls.
- Developed API integration for agent chat and streaming responses.
- Built the AI Agent widget with a user interface for interaction.
- Styled the agent components for better user experience.

* feat: enhance the AI assistant tools, adding file operations and search, and update the UI display

* feat: update the AI assistant components

* feat: update AiAgentWidget component styles, adjusting background and margins for a more consistent UI
2026-01-09 16:19:20 +08:00
shiyu
4638356a45 chore: update version to v1.6.1 2026-01-08 12:20:26 +08:00
shiyu
e51344b43e feat: enhance plugin frame URL building and improve query handling for plugin styles and entry 2026-01-08 11:34:38 +08:00
shiyu
b7685db0e8 feat: add versioning support for plugin assets and improve loading status handling 2026-01-08 10:13:09 +08:00
shiyu
4e16de973c feat: add search functionality to fetchFoxelCoreApps and enhance PluginsPage with query handling 2026-01-06 21:18:26 +08:00
shiyu
4dd0a4b1d6 chore: update version to v1.6.0 2026-01-06 18:02:01 +08:00
shiyu
5703825c31 fix: adjust grid column size for better layout in PluginsPage and ai-settings 2026-01-06 17:07:30 +08:00
时雨
24255744df feat: enhance plugin functionality 2026-01-06 16:54:49 +08:00
201 changed files with 13395 additions and 4486 deletions

.gitignore (vendored, 1 change)

@@ -5,7 +5,6 @@ __pycache__/
.venv/
.vscode/
data/
migrate/
.env
AGENTS.md


@@ -11,10 +11,14 @@ from domain.processors import api as processors
from domain.share import api as share
from domain.tasks import api as tasks
from domain.ai import api as ai
from domain.agent import api as agent
from domain.virtual_fs import api as virtual_fs
from domain.virtual_fs.mapping import s3_api, webdav_api
from domain.virtual_fs.search import search_api
from domain.audit import router as audit
from domain.audit import api as audit
from domain.permission import api as permission
from domain.user import api as user
from domain.role import api as role
def include_routers(app: FastAPI):
@@ -30,9 +34,13 @@ def include_routers(app: FastAPI):
app.include_router(backup.router)
app.include_router(ai.router_vector_db)
app.include_router(ai.router_ai)
app.include_router(agent.router)
app.include_router(plugins.router)
app.include_router(webdav_api.router)
app.include_router(s3_api.router)
app.include_router(offline_downloads.router)
app.include_router(email.router)
app.include_router(audit)
app.include_router(audit.router)
app.include_router(permission.router)
app.include_router(user.router)
app.include_router(role.router)


@@ -1,6 +1,6 @@
from tortoise import Tortoise
from domain.adapters.registry import runtime_registry
from domain.adapters import runtime_registry
TORTOISE_ORM = {
"connections": {"default": "sqlite://data/db/db.sqlite3"},

domain/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
"""
domain: the business-domain layer.
Convention: cross-package code imports public API only from each subpackage's `__init__.py`.
"""
__all__: list[str] = []


@@ -1 +1,24 @@
from .providers import BaseAdapter
from .registry import (
RuntimeRegistry,
discover_adapters,
get_config_schema,
get_config_schemas,
normalize_adapter_type,
runtime_registry,
)
from .service import AdapterService
from .types import AdapterCreate, AdapterOut
__all__ = [
"BaseAdapter",
"RuntimeRegistry",
"discover_adapters",
"get_config_schema",
"get_config_schemas",
"normalize_adapter_type",
"runtime_registry",
"AdapterService",
"AdapterCreate",
"AdapterOut",
]


@@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.adapters.service import AdapterService
from domain.adapters.types import AdapterCreate
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.auth import User, get_current_active_user
from domain.permission import require_system_permission
from domain.permission.types import AdapterPermission
from .service import AdapterService
from .types import AdapterCreate
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
@@ -18,6 +19,7 @@ router = APIRouter(prefix="/api/adapters", tags=["adapters"])
description="创建存储适配器",
body_fields=["name", "type", "path", "sub_path", "enabled"],
)
@require_system_permission(AdapterPermission.CREATE)
async def create_adapter(
request: Request,
data: AdapterCreate,
@@ -29,6 +31,7 @@ async def create_adapter(
@router.get("")
@audit(action=AuditAction.READ, description="获取适配器列表")
@require_system_permission(AdapterPermission.LIST)
async def list_adapters(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)]
@@ -39,6 +42,7 @@ async def list_adapters(
@router.get("/available")
@audit(action=AuditAction.READ, description="获取可用适配器类型")
@require_system_permission(AdapterPermission.LIST)
async def available_adapter_types(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)]
@@ -49,6 +53,7 @@ async def available_adapter_types(
@router.get("/{adapter_id}")
@audit(action=AuditAction.READ, description="获取适配器详情")
@require_system_permission(AdapterPermission.LIST)
async def get_adapter(
request: Request,
adapter_id: int,
@@ -64,6 +69,7 @@ async def get_adapter(
description="更新存储适配器",
body_fields=["name", "type", "path", "sub_path", "enabled"],
)
@require_system_permission(AdapterPermission.EDIT)
async def update_adapter(
request: Request,
adapter_id: int,
@@ -76,6 +82,7 @@ async def update_adapter(
@router.delete("/{adapter_id}")
@audit(action=AuditAction.DELETE, description="删除存储适配器")
@require_system_permission(AdapterPermission.DELETE)
async def delete_adapter(
request: Request,
adapter_id: int,

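The require_system_permission decorator applied throughout this file is not defined in the hunks shown here. A minimal sketch of how such a guard could work, assuming the endpoint receives current_user as a keyword argument and that User exposes a has_permission(code) method (both are assumptions, not confirmed by this diff):

import functools

from fastapi import HTTPException

def require_system_permission(code):
    """Reject the wrapped endpoint with HTTP 403 unless the caller holds `code`."""
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            # current_user is injected by Depends(get_current_active_user)
            user = kwargs.get("current_user")
            if user is None or not user.has_permission(code):  # hypothetical User API
                raise HTTPException(status_code=403, detail=f"Missing permission: {code}")
            return await fn(*args, **kwargs)
        return wrapper
    return decorator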

@@ -81,8 +81,9 @@ class AListApiAdapterBase:
raise ValueError(f"{product_name} requires base_url http/https")
self.username: str = str(cfg.get("username") or "")
self.password: str = str(cfg.get("password") or "")
if not self.username or not self.password:
raise ValueError(f"{product_name} requires username and password")
if (self.username and not self.password) or (self.password and not self.username):
raise ValueError(f"{product_name} requires both username and password")
self.use_auth: bool = bool(self.username and self.password)
self.timeout: float = float(cfg.get("timeout", 30))
self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
@@ -98,6 +99,8 @@ class AListApiAdapterBase:
return base
async def _ensure_token(self) -> str:
if not self.use_auth:
return ""
if self._token:
return self._token
async with self._login_lock:
@@ -137,12 +140,14 @@ class AListApiAdapterBase:
) -> Any:
token = await self._ensure_token()
url = self.base_url + endpoint
req_headers: Dict[str, str] = {"Authorization": token}
req_headers: Dict[str, str] = {}
if token:
req_headers["Authorization"] = token
if headers:
req_headers.update(headers)
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
resp = await client.request(method, url, json=json, headers=req_headers, files=files)
if resp.status_code == 401 and retry:
if resp.status_code == 401 and retry and self.use_auth:
self._token = None
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
resp.raise_for_status()
@@ -153,7 +158,7 @@ class AListApiAdapterBase:
code = payload.get("code")
if code in (0, 200):
return payload.get("data")
if code in (401, 403) and retry:
if code in (401, 403) and retry and self.use_auth:
self._token = None
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
if code == 404:
@@ -349,10 +354,9 @@ class AListApiAdapterBase:
async def _upload_file(self, full_path: str, file_path: Path) -> Any:
token = await self._ensure_token()
headers = {
"Authorization": token,
"File-Path": quote(full_path, safe="/"),
}
headers = {"File-Path": quote(full_path, safe="/")}
if token:
headers["Authorization"] = token
with file_path.open("rb") as f:
files = {"file": (file_path.name, f, "application/octet-stream")}
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
@@ -381,6 +385,30 @@ class AListApiAdapterBase:
except Exception:
pass
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
full_path = _join_fs_path(root, rel)
token = await self._ensure_token()
headers = {"File-Path": quote(full_path, safe="/")}
if token:
headers["Authorization"] = token
name = filename or Path(rel).name or "file"
mime = content_type or "application/octet-stream"
files = {"file": (name, file_obj, mime)}
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
resp.raise_for_status()
payload = resp.json()
if not isinstance(payload, dict):
raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
code = payload.get("code")
if code not in (0, 200):
msg = payload.get("message") or payload.get("msg") or ""
raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
data = payload.get("data")
if isinstance(data, dict) and file_size is not None and "size" not in data:
data["size"] = file_size
return data
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
full_path = _join_fs_path(root, rel)
suffix = Path(rel).suffix
@@ -479,8 +507,8 @@ ADAPTER_TYPES = {"alist": AListAdapter, "openlist": OpenListAdapter}
CONFIG_SCHEMA = [
{"key": "base_url", "label": "基础地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:5244"},
{"key": "username", "label": "用户名", "type": "string", "required": True},
{"key": "password", "label": "密码", "type": "password", "required": True},
{"key": "username", "label": "用户名", "type": "string", "required": False, "placeholder": "留空则匿名访问"},
{"key": "password", "label": "密码", "type": "password", "required": False, "placeholder": "留空则匿名访问"},
{"key": "root", "label": "根目录", "type": "string", "required": False, "default": "/"},
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 30},
{"key": "enable_direct_download_307", "label": "启用 307 直链下载", "type": "boolean", "default": False},

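With the relaxed schema above, credentials are optional as a pair: both present (authenticated) or both absent (anonymous access). Illustrative configs using the CONFIG_SCHEMA field names (values are placeholders):

anonymous_cfg = {"base_url": "http://127.0.0.1:5244", "root": "/", "timeout": 30}
authed_cfg = {
    "base_url": "http://127.0.0.1:5244",
    "username": "admin",
    "password": "secret",  # placeholder credentials
}
# Supplying only one of username/password is rejected at validation time (HTTP 400).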

@@ -250,6 +250,30 @@ class FoxelAdapter:
return True
raise HTTPException(502, detail="Foxel 写入失败")
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
rel = (rel or "").lstrip("/")
full_path = _join_fs_path(root, rel)
url = self.base_url + self._file_path(full_path)
name = filename or Path(rel).name or "file"
mime = content_type or "application/octet-stream"
for attempt in range(2):
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
token = await self._ensure_token()
headers = {"Authorization": f"Bearer {token}"}
files = {"file": (name, file_obj, mime)}
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
resp = await client.post(url, headers=headers, files=files)
if resp.status_code == 401 and attempt == 0:
self._token = None
continue
resp.raise_for_status()
return {"size": file_size or 0}
raise HTTPException(502, detail="Foxel 上传失败")
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
rel = (rel or "").lstrip("/")
full_path = _join_fs_path(root, rel)


@@ -238,6 +238,39 @@ class FTPAdapter:
await asyncio.to_thread(_do_write)
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
path = _join_remote(root, rel)
def _ensure_dirs(ftp: FTP, dir_path: str):
parts = [p for p in dir_path.strip("/").split("/") if p]
cur = "/"
for p in parts:
cur = _join_remote(cur, p)
try:
ftp.mkd(cur)
except Exception:
pass
def _do_upload():
ftp = self._connect()
try:
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
_ensure_dirs(ftp, parent)
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
ftp.storbinary("STOR " + path, file_obj)
finally:
try:
ftp.quit()
except Exception:
pass
await asyncio.to_thread(_do_upload)
return {"size": file_size or 0}
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
# KISS: buffer everything, then write it out in one go
buf = bytearray()


@@ -114,6 +114,32 @@ class LocalAdapter:
if not pre_exists:
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
fp = _safe_join(root, rel)
pre_exists = fp.exists()
await asyncio.to_thread(os.makedirs, fp.parent, mode=DEFAULT_DIR_MODE, exist_ok=True)
def _copy():
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
with open(fp, "wb") as f:
shutil.copyfileobj(file_obj, f)
await asyncio.to_thread(_copy)
if not pre_exists:
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
size = file_size
if size is None:
try:
size = fp.stat().st_size
except Exception:
size = 0
return {"size": int(size or 0)}
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
fp = _safe_join(root, rel)
pre_exists = fp.exists()

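Each adapter now implements the same write_upload_file(root, rel, file_obj, filename, file_size=None, content_type=None) contract: rewind the incoming file object if it is seekable, stream it to the backend, and return a dict carrying at least a "size" key. A self-contained toy version illustrating the contract (paths and names are illustrative, not Foxel's actual wiring):

import asyncio
import io
import os
import shutil
import tempfile

class ToyAdapter:
    async def write_upload_file(self, root, rel, file_obj, filename=None,
                                file_size=None, content_type=None):
        fp = os.path.join(root, rel.lstrip("/"))
        os.makedirs(os.path.dirname(fp), exist_ok=True)

        def _copy():
            if callable(getattr(file_obj, "seek", None)):
                file_obj.seek(0)  # rewind before streaming, as the real adapters do
            with open(fp, "wb") as f:
                shutil.copyfileobj(file_obj, f)

        await asyncio.to_thread(_copy)  # keep blocking I/O off the event loop
        return {"size": file_size if file_size is not None else os.path.getsize(fp)}

async def demo():
    root = tempfile.mkdtemp()
    print(await ToyAdapter().write_upload_file(root, "docs/a.txt", io.BytesIO(b"hello")))  # {'size': 5}

asyncio.run(demo())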

@@ -453,6 +453,159 @@ class QuarkAdapter:
yield data
return await self.write_file_stream(root, rel, gen())
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
if not rel or rel.endswith("/"):
raise HTTPException(400, detail="Invalid file path")
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
name = filename or rel.rsplit("/", 1)[-1]
base_fid = root or self.root_fid
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
md5 = hashlib.md5()
sha1 = hashlib.sha1()
total = 0
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
while True:
chunk = file_obj.read(1024 * 1024)
if not chunk:
break
total += len(chunk)
md5.update(chunk)
sha1.update(chunk)
md5_hex = md5.hexdigest()
sha1_hex = sha1.hexdigest()
# Pre-upload request to obtain the upload metadata
pre_resp = await self._upload_pre(name, total, parent_fid)
pre_data = pre_resp.get("data", {})
# Hash-based instant upload (skip the transfer if the server already has the file)
hash_body = {"md5": md5_hex, "sha1": sha1_hex, "task_id": pre_data.get("task_id")}
hash_resp = await self._request("POST", "/file/update/hash", json=hash_body)
if (hash_resp.get("data") or {}).get("finish") is True:
self._invalidate_children_cache(parent_fid)
return {"size": total}
# Multipart upload
part_size = int((pre_resp.get("metadata") or {}).get("part_size") or 0)
if part_size <= 0:
raise HTTPException(502, detail="Invalid part_size from Quark")
bucket = pre_data.get("bucket")
obj_key = pre_data.get("obj_key")
upload_id = pre_data.get("upload_id")
upload_url = pre_data.get("upload_url")
if not (bucket and obj_key and upload_id and upload_url):
raise HTTPException(502, detail="Upload pre missing fields")
try:
upload_host = upload_url.split("://", 1)[1]
except Exception:
upload_host = upload_url
base_url = f"https://{bucket}.{upload_host}/{obj_key}"
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
etags: List[str] = []
oss_ua = "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit"
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
part_number = 1
left = total
while left > 0:
sz = min(part_size, left)
data_bytes = file_obj.read(sz)
if len(data_bytes) != sz:
raise IOError("Failed to read part bytes")
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
auth_meta = (
"PUT\n\n"
f"{self._guess_mime(name)}\n"
f"{now_str}\n"
f"x-oss-date:{now_str}\n"
f"x-oss-user-agent:{oss_ua}\n"
f"/{bucket}/{obj_key}?partNumber={part_number}&uploadId={upload_id}"
)
auth_req_body = {"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta, "task_id": pre_data.get("task_id")}
auth_resp = await self._request("POST", "/file/upload/auth", json=auth_req_body)
auth_key = (auth_resp.get("data") or {}).get("auth_key")
if not auth_key:
raise HTTPException(502, detail="upload/auth missing auth_key")
put_headers = {
"Authorization": auth_key,
"Content-Type": self._guess_mime(name),
"Referer": REFERER + "/",
"x-oss-date": now_str,
"x-oss-user-agent": oss_ua,
}
put_url = f"{base_url}?partNumber={part_number}&uploadId={upload_id}"
put_resp = await client.put(put_url, headers=put_headers, content=data_bytes)
if put_resp.status_code != 200:
raise HTTPException(502, detail=f"Upload part failed status={put_resp.status_code} text={put_resp.text}")
etag = put_resp.headers.get("Etag", "")
etags.append(etag)
left -= sz
part_number += 1
parts_xml = [f"<Part>\n<PartNumber>{i+1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n" for i in range(len(etags))]
body_xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n" + "".join(parts_xml) + "</CompleteMultipartUpload>"
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
callback = pre_data.get("callback") or {}
try:
import json as _json
callback_b64 = base64.b64encode(_json.dumps(callback).encode("utf-8")).decode("ascii")
except Exception:
callback_b64 = ""
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
auth_meta_commit = (
"POST\n"
f"{content_md5}\n"
"application/xml\n"
f"{now_str}\n"
f"x-oss-callback:{callback_b64}\n"
f"x-oss-date:{now_str}\n"
f"x-oss-user-agent:{oss_ua}\n"
f"/{bucket}/{obj_key}?uploadId={upload_id}"
)
auth_commit_resp = await self._request("POST", "/file/upload/auth", json={"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta_commit, "task_id": pre_data.get("task_id")})
auth_key_commit = (auth_commit_resp.get("data") or {}).get("auth_key")
if not auth_key_commit:
raise HTTPException(502, detail="upload/auth(commit) missing auth_key")
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
commit_headers = {
"Authorization": auth_key_commit,
"Content-MD5": content_md5,
"Content-Type": "application/xml",
"Referer": REFERER + "/",
"x-oss-callback": callback_b64,
"x-oss-date": now_str,
"x-oss-user-agent": oss_ua,
}
commit_url = f"{base_url}?uploadId={upload_id}"
r = await client.post(commit_url, headers=commit_headers, content=body_xml.encode("utf-8"))
if r.status_code != 200:
raise HTTPException(502, detail=f"Upload commit failed status={r.status_code} text={r.text}")
await self._request("POST", "/file/upload/finish", json={"obj_key": obj_key, "task_id": pre_data.get("task_id")})
try:
await asyncio.sleep(1.0)
except Exception:
pass
self._invalidate_children_cache(parent_fid)
return {"size": total}
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
if not rel or rel.endswith("/"):
raise HTTPException(400, detail="Invalid file path")

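The commit step above signs a CompleteMultipartUpload XML body and sends its MD5 as the Content-MD5 header. That body_xml/content_md5 computation can be re-run standalone with made-up ETags:

import base64
import hashlib

etags = ['"etag-part-1"', '"etag-part-2"']  # placeholder ETags
parts_xml = [
    f"<Part>\n<PartNumber>{i + 1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n"
    for i in range(len(etags))
]
body_xml = (
    "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n"
    + "".join(parts_xml)
    + "</CompleteMultipartUpload>"
)
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
print(content_md5)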

@@ -157,6 +157,41 @@ class SFTPAdapter:
await asyncio.to_thread(_do_write)
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
path = _join_remote(root, rel)
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
parts = [p for p in dir_path.strip("/").split("/") if p]
cur = "/"
for p in parts:
cur = _join_remote(cur, p)
try:
sftp.mkdir(cur)
except IOError:
pass
def _do_upload():
sftp = self._connect()
try:
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
_ensure_dirs(sftp, parent)
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
with sftp.open(path, "wb") as f:
import shutil
shutil.copyfileobj(file_obj, f)
finally:
try:
sftp.close()
except Exception:
pass
await asyncio.to_thread(_do_upload)
return {"size": file_size or 0}
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
buf = bytearray()
async for chunk in data_iter:


@@ -1,11 +1,50 @@
from typing import List, Dict, Tuple, AsyncIterator
import asyncio
import base64
import io
import os
import struct
from models import StorageAdapter
from telethon import TelegramClient
from telethon.crypto import AuthKey
from telethon.sessions import StringSession
from telethon.tl import types
import socks
_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
def _get_session_lock(session_string: str) -> asyncio.Lock:
lock = _SESSION_LOCKS.get(session_string)
if lock is None:
lock = asyncio.Lock()
_SESSION_LOCKS[session_string] = lock
return lock
class _NamedFile:
def __init__(self, file_obj, name: str):
self._file = file_obj
self.name = name
def read(self, *args, **kwargs):
return self._file.read(*args, **kwargs)
def seek(self, *args, **kwargs):
return self._file.seek(*args, **kwargs)
def tell(self):
return self._file.tell()
def seekable(self):
return self._file.seekable()
def close(self):
return self._file.close()
def __getattr__(self, name):
return getattr(self._file, name)
# Adapter type identifier
ADAPTER_TYPE = "telegram"
@@ -54,9 +93,93 @@ class TelegramAdapter:
if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")
@staticmethod
def _parse_legacy_session_string(value: str) -> StringSession:
"""
Compatibility with the legacy session_string format:
- version (1-byte char) + base64(data)
- data: dc_id (1B) + ip_len (2B) + ip (ASCII, ip_len bytes) + port (2B) + auth_key (256B)
"""
s = (value or "").strip()
if not s:
raise ValueError("session_string 为空")
body = s[1:] if s.startswith("1") else s
raw = base64.urlsafe_b64decode(body)
if len(raw) < 1 + 2 + 2 + 256:
raise ValueError("legacy session 数据长度不足")
dc_id = raw[0]
ip_len = struct.unpack(">H", raw[1:3])[0]
expected_len = 1 + 2 + ip_len + 2 + 256
if len(raw) != expected_len:
raise ValueError("legacy session 数据长度不匹配")
ip_start = 3
ip_end = ip_start + ip_len
ip = raw[ip_start:ip_end].decode("utf-8")
port = struct.unpack(">H", raw[ip_end : ip_end + 2])[0]
key = raw[ip_end + 2 : ip_end + 2 + 256]
sess = StringSession()
sess.set_dc(dc_id, ip, port)
sess.auth_key = AuthKey(key)
return sess
@staticmethod
def _pick_photo_thumb(thumbs: list | None):
if not thumbs:
return None
cached = []
others = []
for t in thumbs:
if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize)):
cached.append(t)
elif isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive)):
if not isinstance(t, types.PhotoSizeEmpty):
others.append(t)
if cached:
cached.sort(key=lambda x: len(getattr(x, "bytes", b"") or b""))
return cached[-1]
if others:
def _sz(x):
if isinstance(x, types.PhotoSizeProgressive):
return max(x.sizes or [0])
return int(getattr(x, "size", 0) or 0)
others.sort(key=_sz)
return others[-1]
return None
def _build_session(self) -> StringSession:
s = (self.session_string or "").strip()
if not s:
raise ValueError("Telegram 适配器 session_string 为空")
try:
return StringSession(s)
except Exception:
pass
# Some tools strip the version prefix; retry with it restored for compatibility
if not s.startswith("1"):
try:
return StringSession("1" + s)
except Exception:
pass
try:
return self._parse_legacy_session_string(s)
except Exception as exc:
raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc
def _get_client(self) -> TelegramClient:
"""Create a new TelegramClient instance."""
return TelegramClient(StringSession(self.session_string), self.api_id, self.api_hash, proxy=self.proxy)
return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)
def get_effective_root(self, sub_path: str | None) -> str:
return ""
@@ -164,7 +287,48 @@ class TelegramAdapter:
try:
await client.connect()
await client.send_file(self.chat_id, file_like, caption=file_like.name)
sent = await client.send_file(self.chat_id, file_like, caption=file_like.name)
message = sent[0] if isinstance(sent, list) and sent else sent
actual_rel = rel
if message:
stored_name = file_like.name
file_meta = getattr(message, "file", None)
if file_meta and getattr(file_meta, "name", None):
stored_name = file_meta.name
if getattr(message, "id", None) is not None:
actual_rel = f"{message.id}_{stored_name}"
return {"rel": actual_rel, "size": len(data)}
finally:
if client.is_connected():
await client.disconnect()
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
client = self._get_client()
name = filename or os.path.basename(rel) or "file"
file_like = _NamedFile(file_obj, name)
try:
await client.connect()
sent = await client.send_file(
self.chat_id,
file_like,
caption=file_like.name,
file_size=file_size,
mime_type=content_type,
)
message = sent[0] if isinstance(sent, list) and sent else sent
actual_rel = rel
size = file_size or 0
if message:
stored_name = file_like.name
file_meta = getattr(message, "file", None)
if file_meta and getattr(file_meta, "name", None):
stored_name = file_meta.name
if getattr(message, "id", None) is not None:
actual_rel = f"{message.id}_{stored_name}"
if file_meta and getattr(file_meta, "size", None):
size = int(file_meta.size)
return {"rel": actual_rel, "size": size}
finally:
if client.is_connected():
await client.disconnect()
@@ -174,8 +338,9 @@ class TelegramAdapter:
client = self._get_client()
filename = os.path.basename(rel) or "file"
import tempfile
temp_dir = tempfile.gettempdir()
temp_path = os.path.join(temp_dir, filename)
suffix = os.path.splitext(filename)[1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
temp_path = tf.name
total_size = 0
try:
@@ -186,18 +351,62 @@ class TelegramAdapter:
total_size += len(chunk)
await client.connect()
await client.send_file(self.chat_id, temp_path, caption=filename)
sent = await client.send_file(self.chat_id, temp_path, caption=filename)
message = sent[0] if isinstance(sent, list) and sent else sent
actual_rel = rel
if message:
stored_name = filename
file_meta = getattr(message, "file", None)
if file_meta and getattr(file_meta, "name", None):
stored_name = file_meta.name
if getattr(message, "id", None) is not None:
actual_rel = f"{message.id}_{stored_name}"
finally:
if os.path.exists(temp_path):
os.remove(temp_path)
if client.is_connected():
await client.disconnect()
return total_size
return {"rel": actual_rel, "size": total_size}
async def mkdir(self, root: str, rel: str):
raise NotImplementedError("Telegram 适配器不支持创建目录。")
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
try:
message_id_str, _ = rel.split('_', 1)
message_id = int(message_id_str)
except (ValueError, IndexError):
return None
client = self._get_client()
try:
await client.connect()
message = await client.get_messages(self.chat_id, ids=message_id)
if not message:
return None
doc = message.document or message.video
thumbs = None
if doc and getattr(doc, "thumbs", None):
thumbs = list(doc.thumbs or [])
elif message.photo and getattr(message.photo, "sizes", None):
thumbs = list(message.photo.sizes or [])
thumb = self._pick_photo_thumb(thumbs)
if not thumb:
return None
result = await client.download_media(message, bytes, thumb=thumb)
if isinstance(result, (bytes, bytearray)):
return bytes(result)
return None
except Exception:
return None
finally:
if client.is_connected():
await client.disconnect()
async def delete(self, root: str, rel: str):
"""Delete a file (i.e., a single message)."""
try:
@@ -236,6 +445,8 @@ class TelegramAdapter:
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
client = self._get_client()
lock = _get_session_lock(self.session_string)
await lock.acquire()
try:
await client.connect()
@@ -273,7 +484,6 @@ class TelegramAdapter:
headers = {
"Accept-Ranges": "bytes",
"Content-Type": mime_type,
"Content-Length": str(file_size),
}
if range_header:
@@ -285,7 +495,6 @@ class TelegramAdapter:
if start >= file_size or end >= file_size or start > end:
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
status = 206
headers["Content-Length"] = str(end - start + 1)
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
except ValueError:
raise HTTPException(status_code=400, detail="Invalid Range header")
@@ -304,18 +513,28 @@ class TelegramAdapter:
if downloaded >= limit:
break
finally:
if client.is_connected():
await client.disconnect()
try:
if client.is_connected():
await client.disconnect()
finally:
lock.release()
return StreamingResponse(iterator(), status_code=status, headers=headers)
except HTTPException:
if client.is_connected():
await client.disconnect()
lock.release()
raise
except FileNotFoundError as e:
if client.is_connected():
await client.disconnect()
lock.release()
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
if client.is_connected():
await client.disconnect()
lock.release()
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
async def stat_file(self, root: str, rel: str):

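The legacy byte layout handled by _parse_legacy_session_string can be exercised in isolation. A synthetic round-trip of the format (all values are made up; this touches neither Telegram nor Telethon):

import base64
import struct

dc_id, ip, port = 2, "149.154.167.51", 443
auth_key = bytes(256)  # placeholder 256-byte auth key

raw = (
    bytes([dc_id])
    + struct.pack(">H", len(ip)) + ip.encode("ascii")
    + struct.pack(">H", port)
    + auth_key
)
legacy = "1" + base64.urlsafe_b64encode(raw).decode("ascii")  # version prefix + base64(data)

# Decode it back the way _parse_legacy_session_string does:
data = base64.urlsafe_b64decode(legacy[1:])
ip_len = struct.unpack(">H", data[1:3])[0]
assert data[0] == dc_id
assert data[3:3 + ip_len].decode("utf-8") == ip
assert struct.unpack(">H", data[3 + ip_len:5 + ip_len])[0] == port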

@@ -4,7 +4,7 @@ from importlib import import_module
from typing import Callable, Dict
from models import StorageAdapter
from domain.adapters.providers.base import BaseAdapter
from .providers.base import BaseAdapter
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
@@ -21,7 +21,7 @@ def normalize_adapter_type(value: str | None) -> str | None:
def discover_adapters():
"""Scan the domain.adapters.providers package and auto-register adapter types, factories, and config schemas."""
from domain.adapters import providers as adapters_pkg
from . import providers as adapters_pkg
TYPE_MAP.clear()
CONFIG_SCHEMAS.clear()


@@ -2,13 +2,13 @@ from typing import Optional
from fastapi import HTTPException
from domain.adapters.registry import (
from domain.auth import User
from .registry import (
get_config_schemas,
normalize_adapter_type,
runtime_registry,
)
from domain.adapters.types import AdapterCreate, AdapterOut
from domain.auth.types import User
from .types import AdapterCreate, AdapterOut
from models import StorageAdapter
@@ -36,6 +36,11 @@ class AdapterService:
missing.append(k)
if missing:
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
if adapter_type in ("alist", "openlist"):
username = out.get("username")
password = out.get("password")
if (username and not password) or (password and not username):
raise HTTPException(400, detail="用户名和密码必须同时填写或同时留空")
return out
@classmethod

domain/agent/__init__.py (new file, 9 lines)

@@ -0,0 +1,9 @@
from .service import AgentService
from .types import AgentChatContext, AgentChatRequest, PendingToolCall
__all__ = [
"AgentService",
"AgentChatContext",
"AgentChatRequest",
"PendingToolCall",
]

domain/agent/api.py (new file, 38 lines)

@@ -0,0 +1,38 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth import User, get_current_active_user
from .service import AgentService
from .types import AgentChatRequest
router = APIRouter(prefix="/api/agent", tags=["agent"])
@router.post("/chat")
@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"])
async def chat(
request: Request,
payload: AgentChatRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
data = await AgentService.chat(payload, current_user)
return success(data)
@router.post("/chat/stream")
@audit(action=AuditAction.CREATE, description="Agent 对话SSE", body_fields=["auto_execute"])
async def chat_stream(
request: Request,
payload: AgentChatRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
):
return StreamingResponse(
AgentService.chat_stream(payload, current_user),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache"},
)

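A hypothetical client for the streaming endpoint above: POST the chat payload and split the text/event-stream reply into (event, data) pairs. The base URL, bearer token, and payload shape are illustrative assumptions:

import httpx

def stream_agent_chat(base_url, token, messages):
    payload = {"messages": messages, "auto_execute": False}
    headers = {"Authorization": f"Bearer {token}"}
    with httpx.stream("POST", f"{base_url}/api/agent/chat/stream",
                      json=payload, headers=headers, timeout=None) as resp:
        event = None
        for line in resp.iter_lines():
            if line.startswith("event: "):
                event = line[len("event: "):]
            elif line.startswith("data: "):
                yield event, line[len("data: "):]  # JSON payload as text

# for event, data in stream_agent_chat("http://127.0.0.1:8000", "<token>",
#                                      [{"role": "user", "content": "hi"}]):
#     print(event, data)
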
domain/agent/service.py (new file, 472 lines)

@@ -0,0 +1,472 @@
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple
import httpx
from fastapi import HTTPException
from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
from domain.auth import User
from .tools import get_tool, openai_tools, tool_result_to_content
from .types import AgentChatRequest, PendingToolCall
def _normalize_path(p: Optional[str]) -> Optional[str]:
if not p:
return None
s = str(p).strip()
if not s:
return None
s = s.replace("\\", "/")
if not s.startswith("/"):
s = "/" + s
s = s.rstrip("/") or "/"
return s
def _build_system_prompt(current_path: Optional[str]) -> str:
lines = [
"你是 Foxel 的 AI 助手。",
"你可以通过工具对文件/目录进行查询、读写、移动、复制、删除以及运行处理器processor",
"",
"可用工具:",
"- time获取服务器当前时间精确到秒英文星期支持 year/month/day/hour/minute/second 偏移。",
"- web_fetch抓取网页HTTP 请求),支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS返回状态/标题/正文/链接等。",
"- vfs_list_dir浏览目录列出 entries + pagination",
"- vfs_stat查看文件/目录信息。",
"- vfs_read_text读取文本文件内容不支持二进制",
"- vfs_search搜索文件vector/filename",
"- vfs_write_text写入文本文件内容覆盖",
"- vfs_mkdir创建目录。",
"- vfs_delete删除文件或目录。",
"- vfs_move移动路径。",
"- vfs_copy复制路径。",
"- vfs_rename重命名路径。",
"- processors_list获取可用处理器列表含 type/name/config_schema/produces_file/supports_directory",
"- processors_run运行处理器处理文件或目录会返回 task_id 或 task_ids",
"",
"规则:",
"1) 读操作web_fetch/vfs_list_dir/vfs_stat/vfs_read_text/vfs_search可直接调用工具。",
"2) 写/改/删操作vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run默认需要用户确认只有在开启自动执行时才应直接执行。",
"3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
"4) 修改文件内容先读取vfs_read_text→给出改动点→确认后再写入vfs_write_text",
"5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
"6) 回答语言跟随用户;用户用英文则用英文,用户用中文则用中文。回答尽量简洁。",
]
if current_path:
lines.append("")
lines.append(f"当前文件管理目录:{current_path}")
return "\n".join(lines)
def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
tool_calls = message.get("tool_calls")
if not isinstance(tool_calls, list):
return message
changed = False
for idx, call in enumerate(tool_calls):
if not isinstance(call, dict):
continue
call_id = call.get("id")
if isinstance(call_id, str) and call_id.strip():
continue
call["id"] = f"call_{idx}"
changed = True
if changed:
message["tool_calls"] = tool_calls
return message
def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall:
call_id = str(tool_call.get("id") or "")
fn = tool_call.get("function") or {}
name = str((fn.get("name") if isinstance(fn, dict) else None) or "")
raw_args = fn.get("arguments") if isinstance(fn, dict) else None
arguments: Dict[str, Any] = {}
if isinstance(raw_args, str) and raw_args.strip():
try:
parsed = json.loads(raw_args)
if isinstance(parsed, dict):
arguments = parsed
except json.JSONDecodeError:
arguments = {}
return PendingToolCall(
id=call_id,
name=name,
arguments=arguments,
requires_confirmation=requires_confirmation,
)
def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
for idx in range(len(messages) - 1, -1, -1):
msg = messages[idx]
if not isinstance(msg, dict):
continue
if msg.get("role") != "assistant":
continue
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list) and tool_calls:
return idx, msg
raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
ids: set[str] = set()
for msg in messages:
if not isinstance(msg, dict):
continue
if msg.get("role") != "tool":
continue
tool_call_id = msg.get("tool_call_id")
if isinstance(tool_call_id, str) and tool_call_id.strip():
ids.add(tool_call_id)
return ids
async def _choose_chat_ability() -> str:
tools_model = await AIProviderService.get_default_model("tools")
return "tools" if tools_model else "chat"
def _sse(event: str, data: Any) -> bytes:
payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
def _format_exc(exc: BaseException) -> str:
text = str(exc)
return text if text else exc.__class__.__name__
class AgentService:
@classmethod
async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
history: List[Dict[str, Any]] = list(req.messages or [])
current_path = _normalize_path(req.context.current_path if req.context else None)
system_prompt = _build_system_prompt(current_path)
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
new_messages: List[Dict[str, Any]] = []
pending: List[PendingToolCall] = []
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
if approved_ids or rejected_ids:
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
last_call_msg = _ensure_tool_call_ids(last_call_msg)
tool_calls = last_call_msg.get("tool_calls") or []
call_map: Dict[str, Dict[str, Any]] = {
str(c.get("id")): c
for c in tool_calls
if isinstance(c, dict) and isinstance(c.get("id"), str)
}
existing_ids = _existing_tool_result_ids(internal_messages)
for call_id in approved_ids | rejected_ids:
if call_id in existing_ids:
continue
tool_call = call_map.get(call_id)
if not tool_call:
continue
fn = tool_call.get("function") or {}
name = fn.get("name") if isinstance(fn, dict) else None
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
args: Dict[str, Any] = {}
if isinstance(args_raw, str) and args_raw.strip():
try:
parsed = json.loads(args_raw)
if isinstance(parsed, dict):
args = parsed
except json.JSONDecodeError:
args = {}
spec = get_tool(str(name or ""))
if call_id in rejected_ids:
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
continue
if not spec:
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
continue
try:
result = await spec.handler(args)
content = tool_result_to_content(result)
except Exception as exc: # noqa: BLE001
content = tool_result_to_content({"error": str(exc)})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
tools_schema = openai_tools()
ability = await _choose_chat_ability()
max_loops = 4
for _ in range(max_loops):
try:
assistant = await chat_completion(
internal_messages,
ability=ability,
tools=tools_schema,
tool_choice="auto",
timeout=60.0,
)
except MissingModelError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except httpx.HTTPStatusError as exc:
raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
except httpx.RequestError as exc:
raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc
assistant = _ensure_tool_call_ids(assistant)
internal_messages.append(assistant)
new_messages.append(assistant)
tool_calls = assistant.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
break
pending = []
for call in tool_calls:
if not isinstance(call, dict):
continue
call_id = str(call.get("id") or "")
fn = call.get("function") or {}
name = fn.get("name") if isinstance(fn, dict) else None
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
args: Dict[str, Any] = {}
if isinstance(args_raw, str) and args_raw.strip():
try:
parsed = json.loads(args_raw)
if isinstance(parsed, dict):
args = parsed
except json.JSONDecodeError:
args = {}
spec = get_tool(str(name or ""))
if not spec:
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
continue
if spec.requires_confirmation and not req.auto_execute:
pending.append(_extract_pending(call, True))
continue
try:
result = await spec.handler(args)
content = tool_result_to_content(result)
except Exception as exc: # noqa: BLE001
content = tool_result_to_content({"error": str(exc)})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
if pending:
break
payload: Dict[str, Any] = {"messages": new_messages}
if pending:
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
return payload
@classmethod
async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
history: List[Dict[str, Any]] = list(req.messages or [])
current_path = _normalize_path(req.context.current_path if req.context else None)
system_prompt = _build_system_prompt(current_path)
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
new_messages: List[Dict[str, Any]] = []
pending: List[PendingToolCall] = []
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
try:
if approved_ids or rejected_ids:
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
last_call_msg = _ensure_tool_call_ids(last_call_msg)
tool_calls = last_call_msg.get("tool_calls") or []
call_map: Dict[str, Dict[str, Any]] = {
str(c.get("id")): c
for c in tool_calls
if isinstance(c, dict) and isinstance(c.get("id"), str)
}
existing_ids = _existing_tool_result_ids(internal_messages)
for call_id in approved_ids | rejected_ids:
if call_id in existing_ids:
continue
tool_call = call_map.get(call_id)
if not tool_call:
continue
fn = tool_call.get("function") or {}
name = fn.get("name") if isinstance(fn, dict) else None
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
args: Dict[str, Any] = {}
if isinstance(args_raw, str) and args_raw.strip():
try:
parsed = json.loads(args_raw)
if isinstance(parsed, dict):
args = parsed
except json.JSONDecodeError:
args = {}
spec = get_tool(str(name or ""))
if call_id in rejected_ids:
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
continue
if not spec:
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
continue
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
try:
result = await spec.handler(args)
content = tool_result_to_content(result)
except Exception as exc: # noqa: BLE001
content = tool_result_to_content({"error": str(exc)})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
tools_schema = openai_tools()
ability = await _choose_chat_ability()
max_loops = 4
for _ in range(max_loops):
assistant_event_id = uuid.uuid4().hex
yield _sse("assistant_start", {"id": assistant_event_id})
assistant_message: Dict[str, Any] | None = None
try:
async for event in chat_completion_stream(
internal_messages,
ability=ability,
tools=tools_schema,
tool_choice="auto",
timeout=60.0,
):
if event.get("type") == "delta":
delta = event.get("delta")
if isinstance(delta, str) and delta:
yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
elif event.get("type") == "message":
msg = event.get("message")
if isinstance(msg, dict):
assistant_message = msg
except MissingModelError as exc:
raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
except httpx.HTTPStatusError as exc:
raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
except httpx.RequestError as exc:
raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc
if not assistant_message:
assistant_message = {"role": "assistant", "content": ""}
assistant_message = _ensure_tool_call_ids(assistant_message)
internal_messages.append(assistant_message)
new_messages.append(assistant_message)
yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})
tool_calls = assistant_message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
break
pending = []
for call in tool_calls:
if not isinstance(call, dict):
continue
call_id = str(call.get("id") or "")
fn = call.get("function") or {}
name = fn.get("name") if isinstance(fn, dict) else None
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
args: Dict[str, Any] = {}
if isinstance(args_raw, str) and args_raw.strip():
try:
parsed = json.loads(args_raw)
if isinstance(parsed, dict):
args = parsed
except json.JSONDecodeError:
args = {}
spec = get_tool(str(name or ""))
if not spec:
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
continue
if spec.requires_confirmation and not req.auto_execute:
pending.append(_extract_pending(call, True))
continue
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
try:
result = await spec.handler(args)
content = tool_result_to_content(result)
except Exception as exc: # noqa: BLE001
content = tool_result_to_content({"error": str(exc)})
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
internal_messages.append(tool_msg)
new_messages.append(tool_msg)
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
if pending:
yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]})
break
payload: Dict[str, Any] = {"messages": new_messages}
if pending:
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
yield _sse("done", payload)
except asyncio.CancelledError:
return
except HTTPException as exc:
detail = exc.detail
content = detail if isinstance(detail, str) else str(detail)
if not content.strip():
content = f"请求失败({exc.status_code})"
new_messages.append({"role": "assistant", "content": content})
payload: Dict[str, Any] = {"messages": new_messages}
if pending:
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
yield _sse("done", payload)
return
except Exception as exc: # noqa: BLE001
new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
payload: Dict[str, Any] = {"messages": new_messages}
if pending:
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
yield _sse("done", payload)
return
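Note: _sse is defined earlier in this file and is not shown in this hunk. A minimal sketch of the conventional Server-Sent Events framing it presumably produces (the exact field layout is an assumption), covering the event sequence emitted above (assistant_start, assistant_delta, assistant_end, tool_start/tool_end, pending, done):

import json

def sse(event: str, data) -> str:
    # Hypothetical mirror of the module's _sse helper: standard SSE framing,
    # one "event:" line plus one JSON "data:" line, terminated by a blank line.
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

print(sse("assistant_delta", {"id": "abc", "delta": "hello"}), end="")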


@@ -0,0 +1,37 @@
from typing import Any, Dict, List, Optional
from .base import ToolSpec, tool_result_to_content
from .processors import TOOLS as PROCESSOR_TOOLS
from .time import TOOLS as TIME_TOOLS
from .vfs import TOOLS as VFS_TOOLS
from .web_fetch import TOOLS as WEB_FETCH_TOOLS
TOOLS: Dict[str, ToolSpec] = {}
for group in (TIME_TOOLS, WEB_FETCH_TOOLS, PROCESSOR_TOOLS, VFS_TOOLS):
TOOLS.update(group)
def get_tool(name: str) -> Optional[ToolSpec]:
return TOOLS.get(name)
def openai_tools() -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
for spec in TOOLS.values():
out.append({
"type": "function",
"function": {
"name": spec.name,
"description": spec.description,
"parameters": spec.parameters,
},
})
return out
__all__ = [
"ToolSpec",
"get_tool",
"openai_tools",
"tool_result_to_content",
]
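A usage sketch (not part of the diff): a self-contained toy mirror of the registry pattern this module builds, using the ToolSpec shape from base.py below, to show what get_tool and openai_tools yield:

from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List

@dataclass(frozen=True)
class ToolSpec:
    name: str
    description: str
    parameters: Dict[str, Any]
    requires_confirmation: bool
    handler: Callable[[Dict[str, Any]], Awaitable[Any]]

async def _echo(args: Dict[str, Any]) -> Dict[str, Any]:
    return {"echo": args}  # hypothetical toy tool, not from the diff

TOOLS: Dict[str, ToolSpec] = {
    "echo": ToolSpec("echo", "Echo arguments back.", {"type": "object", "properties": {}}, False, _echo),
}

def openai_tools() -> List[Dict[str, Any]]:
    # Same export shape as the module above.
    return [
        {"type": "function", "function": {"name": s.name, "description": s.description, "parameters": s.parameters}}
        for s in TOOLS.values()
    ]

print(TOOLS.get("echo").name, openai_tools()[0]["type"])  # echo function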

domain/agent/tools/base.py

@@ -0,0 +1,149 @@
import json
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional
@dataclass(frozen=True)
class ToolSpec:
name: str
description: str
parameters: Dict[str, Any]
requires_confirmation: bool
handler: Callable[[Dict[str, Any]], Awaitable[Any]]
def _stringify_value(value: Any) -> str:
if value is None:
return ""
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, str):
return value
try:
return json.dumps(value, ensure_ascii=False)
except TypeError:
return str(value)
def _list_to_view_items(items: List[Any]) -> List[Any]:
normalized: List[Any] = []
for item in items:
if isinstance(item, dict):
normalized.append({str(k): _stringify_value(v) for k, v in item.items()})
else:
normalized.append(_stringify_value(item))
return normalized
def _dict_to_kv_items(data: Dict[str, Any]) -> List[Dict[str, str]]:
return [{"key": str(k), "value": _stringify_value(v)} for k, v in data.items()]
def _first_list_field(data: Dict[str, Any]) -> tuple[Optional[str], Optional[List[Any]]]:
for key, value in data.items():
if isinstance(value, list):
return str(key), value
return None, None
def _build_view(data: Any) -> Dict[str, Any]:
if data is None:
return {"type": "kv", "items": []}
if isinstance(data, str):
return {"type": "text", "text": data}
if isinstance(data, list):
return {"type": "list", "items": _list_to_view_items(data)}
if isinstance(data, dict):
content = data.get("content")
if isinstance(content, str):
meta = {k: _stringify_value(v) for k, v in data.items() if k != "content"}
view: Dict[str, Any] = {"type": "text", "text": content}
if meta:
view["meta"] = meta
return view
list_key, list_val = _first_list_field(data)
if list_key and isinstance(list_val, list):
meta = {k: _stringify_value(v) for k, v in data.items() if k != list_key}
view = {"type": "list", "title": list_key, "items": _list_to_view_items(list_val)}
if meta:
view["meta"] = meta
return view
return {"type": "kv", "items": _dict_to_kv_items(data)}
return {"type": "text", "text": _stringify_value(data)}
def _build_summary(view: Dict[str, Any]) -> str:
view_type = str(view.get("type") or "")
if view_type == "text":
text = view.get("text")
size = len(text) if isinstance(text, str) else 0
return f"chars: {size}" if size else "text"
if view_type == "list":
items = view.get("items")
count = len(items) if isinstance(items, list) else 0
title = str(view.get("title") or "items")
return f"{title}: {count}"
if view_type == "kv":
items = view.get("items")
count = len(items) if isinstance(items, list) else 0
return f"fields: {count}"
if view_type == "error":
return str(view.get("message") or "error")
return ""
def _build_error_payload(code: str, message: str, detail: Any = None) -> Dict[str, Any]:
summary = "Canceled" if code == "canceled" else message or "error"
view = {"type": "error", "message": summary}
payload: Dict[str, Any] = {
"ok": False,
"summary": summary,
"view": view,
"error": {
"code": code,
"message": message,
},
}
if detail is not None:
payload["error"]["detail"] = detail
return payload
def _normalize_tool_result(result: Any) -> Dict[str, Any]:
if isinstance(result, dict) and "ok" in result:
payload = dict(result)
if payload.get("ok") is False:
error = payload.get("error")
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
payload.setdefault("summary", message or "error")
payload.setdefault("view", {"type": "error", "message": payload["summary"]})
return payload
data = payload.get("data")
if payload.get("view") is None:
payload["view"] = _build_view(data)
if not payload.get("summary"):
payload["summary"] = _build_summary(payload["view"])
return payload
if isinstance(result, dict) and result.get("canceled"):
reason = _stringify_value(result.get("reason") or "canceled")
return _build_error_payload("canceled", reason, detail=result)
if isinstance(result, dict) and "error" in result:
error = result.get("error")
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
return _build_error_payload("error", message, detail=error)
view = _build_view(result)
summary = _build_summary(view)
return {"ok": True, "summary": summary, "view": view, "data": result}
def tool_result_to_content(result: Any) -> str:
payload = _normalize_tool_result(result)
try:
return json.dumps(payload, ensure_ascii=False, default=str)
except TypeError:
return json.dumps({"ok": False, "summary": "error", "view": {"type": "error", "message": "error"}}, ensure_ascii=False)
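One detail worth calling out in tool_result_to_content above: json.dumps runs with default=str, so values that are not JSON-native (datetimes, paths, and so on) degrade to strings instead of raising. A standalone illustration:

import datetime
import json

payload = {"ok": True, "data": {"mtime": datetime.datetime(2026, 2, 9, 13, 0)}}
# default=str stringifies anything json cannot encode natively, mirroring the call above.
print(json.dumps(payload, ensure_ascii=False, default=str))
# {"ok": true, "data": {"mtime": "2026-02-09 13:00:00"}}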


@@ -0,0 +1,96 @@
from typing import Any, Dict, Optional
from domain.processors import ProcessDirectoryRequest, ProcessRequest, ProcessorService
from domain.virtual_fs import VirtualFSService
from .base import ToolSpec
async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]:
return {"processors": ProcessorService.list_processors()}
async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]:
path = str(args.get("path") or "")
processor_type = str(args.get("processor_type") or "")
config = args.get("config")
if not isinstance(config, dict):
config = {}
save_to = args.get("save_to")
save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None
max_depth = args.get("max_depth")
max_depth_value: Optional[int] = None
if max_depth is not None:
try:
max_depth_value = int(max_depth)
except (TypeError, ValueError):
max_depth_value = None
suffix = args.get("suffix")
suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None
overwrite_value = args.get("overwrite")
overwrite = bool(overwrite_value) if overwrite_value is not None else None
is_dir = await VirtualFSService.path_is_directory(path)
if is_dir and (max_depth_value is not None or suffix_value is not None):
req = ProcessDirectoryRequest(
path=path,
processor_type=processor_type,
config=config,
overwrite=True if overwrite is None else overwrite,
max_depth=max_depth_value,
suffix=suffix_value,
)
result = await ProcessorService.process_directory(req)
return {"mode": "directory", **result}
req = ProcessRequest(
path=path,
processor_type=processor_type,
config=config,
save_to=save_to,
overwrite=False if overwrite is None else overwrite,
)
result = await ProcessorService.process_file(req)
return {"mode": "file", **result}
TOOLS: Dict[str, ToolSpec] = {
"processors_list": ToolSpec(
name="processors_list",
description="获取可用处理器列表type/name/config_schema 等)。",
parameters={
"type": "object",
"properties": {},
"additionalProperties": False,
},
requires_confirmation=False,
handler=_processors_list,
),
"processors_run": ToolSpec(
name="processors_run",
description=(
"运行处理器处理文件或目录。"
" 对目录可选 max_depth/suffix对文件可选 overwrite/save_to。"
" 返回任务 id去任务队列查看进度"
),
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "文件或目录路径(绝对路径,如 /foo/bar"},
"processor_type": {"type": "string", "description": "处理器类型(例如 image_watermark"},
"config": {"type": "object", "description": "处理器配置,按 processors_list 返回的 config_schema 填写"},
"overwrite": {"type": "boolean", "description": "是否覆盖原文件/目录内文件"},
"save_to": {"type": "string", "description": "保存到指定路径(仅文件模式,且 overwrite=false 时使用)"},
"max_depth": {"type": "integer", "description": "目录遍历深度(仅目录模式)"},
"suffix": {"type": "string", "description": "目录批处理时的输出后缀(仅 produces_file 且 overwrite=false"},
},
"required": ["path", "processor_type"],
},
requires_confirmation=True,
handler=_processors_run,
),
}
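For orientation, example argument payloads that satisfy the processors_run schema above; the paths and the image_watermark type are illustrative only (the type is taken from the parameter description). Note that _processors_run only takes the directory branch when the path is a directory and at least one of max_depth/suffix is set:

# File mode: overwrite=False plus save_to writes the result to a new path.
file_args = {
    "path": "/photos/a.jpg",
    "processor_type": "image_watermark",
    "config": {},
    "overwrite": False,
    "save_to": "/photos/a_wm.jpg",
}

# Directory mode: max_depth/suffix route into ProcessDirectoryRequest.
dir_args = {
    "path": "/photos",
    "processor_type": "image_watermark",
    "config": {},
    "max_depth": 2,
    "suffix": "_wm",
}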


@@ -0,0 +1,92 @@
import calendar
from datetime import datetime, timedelta
from typing import Any, Dict
from .base import ToolSpec
def _parse_offset(args: Dict[str, Any], key: str) -> int:
value = args.get(key)
if value is None:
return 0
try:
return int(value)
except (TypeError, ValueError):
return 0
def _add_months(dt: datetime, months: int) -> datetime:
if months == 0:
return dt
total = dt.year * 12 + (dt.month - 1) + months
year = total // 12
month = total % 12 + 1
last_day = calendar.monthrange(year, month)[1]
day = min(dt.day, last_day)
return dt.replace(year=year, month=month, day=day)
async def _time(args: Dict[str, Any]) -> Dict[str, Any]:
now = datetime.now()
year_offset = _parse_offset(args, "year")
month_offset = _parse_offset(args, "month")
day_offset = _parse_offset(args, "day")
hour_offset = _parse_offset(args, "hour")
minute_offset = _parse_offset(args, "minute")
second_offset = _parse_offset(args, "second")
dt = _add_months(now, year_offset * 12 + month_offset)
dt = dt + timedelta(days=day_offset, hours=hour_offset, minutes=minute_offset, seconds=second_offset)
weekday_names = [
"Monday",
"Tuesday",
"Wednesday",
"Thursday",
"Friday",
"Saturday",
"Sunday",
]
weekday = weekday_names[dt.weekday()]
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
return {
"ok": True,
"summary": f"{dt_str} · {weekday}",
"data": {
"datetime": dt_str,
"weekday": weekday,
"offset": {
"year": year_offset,
"month": month_offset,
"day": day_offset,
"hour": hour_offset,
"minute": minute_offset,
"second": second_offset,
},
},
}
TOOLS: Dict[str, ToolSpec] = {
"time": ToolSpec(
name="time",
description=(
"获取服务器当前时间(精确到秒,含英文星期)。"
" 支持 year/month/day/hour/minute/second 偏移(可为负数)。"
),
parameters={
"type": "object",
"properties": {
"year": {"type": "integer", "description": "年偏移(可为负数)"},
"month": {"type": "integer", "description": "月偏移(可为负数)"},
"day": {"type": "integer", "description": "日偏移(可为负数)"},
"hour": {"type": "integer", "description": "时偏移(可为负数)"},
"minute": {"type": "integer", "description": "分偏移(可为负数)"},
"second": {"type": "integer", "description": "秒偏移(可为负数)"},
},
"additionalProperties": False,
},
requires_confirmation=False,
handler=_time,
),
}
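The month arithmetic in _add_months clamps the day to the target month's length, so an offset never spills into the following month. A standalone copy demonstrating the edge case:

import calendar
from datetime import datetime

def add_months(dt: datetime, months: int) -> datetime:
    # Same clamping logic as _add_months above.
    total = dt.year * 12 + (dt.month - 1) + months
    year, month = total // 12, total % 12 + 1
    day = min(dt.day, calendar.monthrange(year, month)[1])
    return dt.replace(year=year, month=month, day=day)

print(add_months(datetime(2024, 1, 31), 1))  # 2024-02-29: Jan 31 + 1 month clamps to the leap day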

domain/agent/tools/vfs.py

@@ -0,0 +1,287 @@
from typing import Any, Dict, Optional
from domain.virtual_fs import VirtualFSService
from domain.virtual_fs.search import VirtualFSSearchService
from .base import ToolSpec
def _normalize_vfs_path(value: Any) -> str:
s = str(value or "").strip().replace("\\", "/")
if not s:
return ""
if not s.startswith("/"):
s = "/" + s
s = s.rstrip("/") or "/"
return s
def _require_vfs_path(value: Any, field: str) -> str:
path = _normalize_vfs_path(value)
if not path:
raise ValueError(f"missing_{field}")
return path
async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]:
path = _normalize_vfs_path(args.get("path") or "/") or "/"
page = int(args.get("page") or 1)
page_size = int(args.get("page_size") or 50)
sort_by = str(args.get("sort_by") or "name")
sort_order = str(args.get("sort_order") or "asc")
return await VirtualFSService.list_directory(path, page, page_size, sort_by, sort_order)
async def _vfs_stat(args: Dict[str, Any]) -> Any:
path = _require_vfs_path(args.get("path"), "path")
return await VirtualFSService.stat(path)
async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]:
path = _require_vfs_path(args.get("path"), "path")
encoding = str(args.get("encoding") or "utf-8")
max_chars = int(args.get("max_chars") or 8000)
data = await VirtualFSService.read_file(path)
if isinstance(data, (bytes, bytearray)):
try:
text = bytes(data).decode(encoding)
except UnicodeDecodeError:
return {"error": "binary_or_invalid_text", "path": path}
elif isinstance(data, str):
text = data
else:
text = str(data)
original_len = len(text)
truncated = original_len > max_chars
if truncated:
text = text[:max_chars]
return {
"path": path,
"encoding": encoding,
"content": text,
"truncated": truncated,
"length": original_len,
}
async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]:
path = _require_vfs_path(args.get("path"), "path")
if path == "/":
raise ValueError("invalid_path")
encoding = str(args.get("encoding") or "utf-8")
content = str(args.get("content") or "")
data = content.encode(encoding)
await VirtualFSService.write_file(path, data)
return {"written": True, "path": path, "encoding": encoding, "bytes": len(data)}
async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]:
path = _require_vfs_path(args.get("path"), "path")
return await VirtualFSService.mkdir(path)
async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]:
path = _require_vfs_path(args.get("path"), "path")
return await VirtualFSService.delete(path)
async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]:
src = _require_vfs_path(args.get("src"), "src")
dst = _require_vfs_path(args.get("dst"), "dst")
if src == "/" or dst == "/":
raise ValueError("invalid_path")
overwrite = bool(args.get("overwrite") or False)
return await VirtualFSService.move(src, dst, overwrite)
async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]:
src = _require_vfs_path(args.get("src"), "src")
dst = _require_vfs_path(args.get("dst"), "dst")
if src == "/" or dst == "/":
raise ValueError("invalid_path")
overwrite = bool(args.get("overwrite") or False)
return await VirtualFSService.copy(src, dst, overwrite)
async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]:
src = _require_vfs_path(args.get("src"), "src")
dst = _require_vfs_path(args.get("dst"), "dst")
if src == "/" or dst == "/":
raise ValueError("invalid_path")
overwrite = bool(args.get("overwrite") or False)
return await VirtualFSService.rename(src, dst, overwrite)
async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]:
q = str(args.get("q") or "").strip()
if not q:
raise ValueError("missing_q")
mode = str(args.get("mode") or "vector")
top_k = int(args.get("top_k") or 10)
page = int(args.get("page") or 1)
page_size = int(args.get("page_size") or 10)
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
TOOLS: Dict[str, ToolSpec] = {
"vfs_list_dir": ToolSpec(
name="vfs_list_dir",
description="浏览目录(列出 entries + pagination",
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar"},
"page": {"type": "integer", "description": "页码(从 1 开始)"},
"page_size": {"type": "integer", "description": "每页条数"},
"sort_by": {"type": "string", "description": "排序字段name/size/mtime"},
"sort_order": {"type": "string", "description": "排序顺序asc/desc"},
},
"required": ["path"],
"additionalProperties": False,
},
requires_confirmation=False,
handler=_vfs_list_dir,
),
"vfs_stat": ToolSpec(
name="vfs_stat",
description="查看文件/目录信息size/mtime/is_dir/has_thumbnail/vector_index 等)。",
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar.txt"},
},
"required": ["path"],
"additionalProperties": False,
},
requires_confirmation=False,
handler=_vfs_stat,
),
"vfs_read_text": ToolSpec(
name="vfs_read_text",
description="读取文本文件内容(解码失败视为二进制,返回 error",
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md"},
"encoding": {"type": "string", "description": "文本编码(默认 utf-8"},
"max_chars": {"type": "integer", "description": "最多返回的字符数(默认 8000"},
},
"required": ["path"],
"additionalProperties": False,
},
requires_confirmation=False,
handler=_vfs_read_text,
),
"vfs_write_text": ToolSpec(
name="vfs_write_text",
description="写入文本文件内容(会覆盖目标文件)。",
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md"},
"content": {"type": "string", "description": "要写入的文本内容"},
"encoding": {"type": "string", "description": "文本编码(默认 utf-8"},
},
"required": ["path", "content"],
"additionalProperties": False,
},
requires_confirmation=True,
handler=_vfs_write_text,
),
"vfs_mkdir": ToolSpec(
name="vfs_mkdir",
description="创建目录。",
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar"},
},
"required": ["path"],
"additionalProperties": False,
},
requires_confirmation=True,
handler=_vfs_mkdir,
),
"vfs_delete": ToolSpec(
name="vfs_delete",
description="删除文件或目录(由底层适配器决定是否递归)。",
parameters={
"type": "object",
"properties": {
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar 或 /foo/bar.txt"},
},
"required": ["path"],
"additionalProperties": False,
},
requires_confirmation=True,
handler=_vfs_delete,
),
"vfs_move": ToolSpec(
name="vfs_move",
description="移动路径(可能进入任务队列)。",
parameters={
"type": "object",
"properties": {
"src": {"type": "string", "description": "源路径(绝对路径)"},
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false"},
},
"required": ["src", "dst"],
"additionalProperties": False,
},
requires_confirmation=True,
handler=_vfs_move,
),
"vfs_copy": ToolSpec(
name="vfs_copy",
description="复制路径(可能进入任务队列)。",
parameters={
"type": "object",
"properties": {
"src": {"type": "string", "description": "源路径(绝对路径)"},
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
"overwrite": {"type": "boolean", "description": "是否覆盖已存在目标(默认 false"},
},
"required": ["src", "dst"],
"additionalProperties": False,
},
requires_confirmation=True,
handler=_vfs_copy,
),
"vfs_rename": ToolSpec(
name="vfs_rename",
description="重命名路径(本质是同目录 move",
parameters={
"type": "object",
"properties": {
"src": {"type": "string", "description": "源路径(绝对路径)"},
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false"},
},
"required": ["src", "dst"],
"additionalProperties": False,
},
requires_confirmation=True,
handler=_vfs_rename,
),
"vfs_search": ToolSpec(
name="vfs_search",
description="搜索文件mode=vector 或 filename",
parameters={
"type": "object",
"properties": {
"q": {"type": "string", "description": "搜索关键词"},
"mode": {"type": "string", "description": "搜索模式vector/filename默认 vector"},
"top_k": {"type": "integer", "description": "返回数量vector 模式使用,默认 10"},
"page": {"type": "integer", "description": "页码filename 模式使用,默认 1"},
"page_size": {"type": "integer", "description": "分页大小filename 模式使用,默认 10"},
},
"required": ["q"],
"additionalProperties": False,
},
requires_confirmation=False,
handler=_vfs_search,
),
}
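Every handler in this file funnels through the normalizer at the top, so its edge cases matter; a standalone copy with a few spot checks:

def normalize_vfs_path(value) -> str:
    # Same rules as _normalize_vfs_path above: backslashes become slashes,
    # a leading slash is forced, trailing slashes are stripped (bare "/" survives).
    s = str(value or "").strip().replace("\\", "/")
    if not s:
        return ""
    if not s.startswith("/"):
        s = "/" + s
    return s.rstrip("/") or "/"

assert normalize_vfs_path("foo\\bar/") == "/foo/bar"
assert normalize_vfs_path("/") == "/"
assert normalize_vfs_path(None) == ""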


@@ -0,0 +1,182 @@
from html.parser import HTMLParser
from typing import Any, Dict, List
from urllib.parse import urljoin
import httpx
from .base import ToolSpec
class _HtmlTextExtractor(HTMLParser):
def __init__(self, base_url: str):
super().__init__()
self.base_url = base_url
self.links: List[str] = []
self._link_set: set[str] = set()
self._title_parts: List[str] = []
self._text_parts: List[str] = []
self._in_title = False
self._skip_text = False
def handle_starttag(self, tag: str, attrs: List[tuple[str, str | None]]):
tag = tag.lower()
if tag == "title":
self._in_title = True
if tag in ("script", "style", "noscript"):
self._skip_text = True
if tag != "a":
return
href = ""
for key, value in attrs:
if key.lower() == "href":
href = str(value or "").strip()
break
if not href or href.startswith("#"):
return
lower = href.lower()
if lower.startswith(("javascript:", "mailto:", "tel:", "data:")):
return
resolved = urljoin(self.base_url, href)
if resolved in self._link_set:
return
self._link_set.add(resolved)
self.links.append(resolved)
def handle_endtag(self, tag: str):
tag = tag.lower()
if tag == "title":
self._in_title = False
if tag in ("script", "style", "noscript"):
self._skip_text = False
def handle_data(self, data: str):
if not data:
return
if self._in_title:
self._title_parts.append(data)
if self._skip_text:
return
if data.strip():
self._text_parts.append(data)
@property
def title(self) -> str:
return " ".join(part.strip() for part in self._title_parts if part and part.strip()).strip()
@property
def text(self) -> str:
if not self._text_parts:
return ""
text = " ".join(part.strip() for part in self._text_parts if part and part.strip())
return " ".join(text.split())
async def _web_fetch(args: Dict[str, Any]) -> Dict[str, Any]:
url = str(args.get("url") or "").strip()
if not url:
raise ValueError("missing_url")
method = str(args.get("method") or "GET").upper()
allowed_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
if method not in allowed_methods:
raise ValueError("invalid_method")
headers_raw = args.get("headers")
headers = {str(k): str(v) for k, v in headers_raw.items() if v is not None} if isinstance(headers_raw, dict) else None
params_raw = args.get("params")
params = {str(k): str(v) for k, v in params_raw.items() if v is not None} if isinstance(params_raw, dict) else None
json_body = args.get("json") if "json" in args else None
body = args.get("body")
request_kwargs: Dict[str, Any] = {}
if headers:
request_kwargs["headers"] = headers
if params:
request_kwargs["params"] = params
if json_body is not None:
request_kwargs["json"] = json_body
elif body is not None:
request_kwargs["content"] = str(body)
async with httpx.AsyncClient(timeout=20.0, follow_redirects=True) as client:
resp = await client.request(method, url, **request_kwargs)
content_type = resp.headers.get("content-type") or ""
text = resp.text or ""
is_html = "html" in content_type.lower()
if not is_html:
probe = text.lstrip()[:200].lower()
if "<html" in probe or "<!doctype html" in probe:
is_html = True
html = ""
title = ""
links: List[str] = []
extracted_text = text
if is_html and text:
html = text
parser = _HtmlTextExtractor(str(resp.url))
parser.feed(text)
title = parser.title
links = parser.links
extracted_text = parser.text
data = {
"url": url,
"method": method,
"final_url": str(resp.url),
"status_code": resp.status_code,
"content_type": content_type,
"title": title,
"html": html,
"text": extracted_text,
"links": links,
}
summary_parts = [method, str(resp.status_code)]
if title:
summary_parts.append(title)
summary_parts.append(f"{len(links)} links")
summary = " · ".join(summary_parts)
view = {
"type": "text",
"text": extracted_text,
"meta": {
"url": url,
"final_url": str(resp.url),
"status_code": resp.status_code,
"content_type": content_type,
"title": title,
"method": method,
"links": len(links),
},
}
return {"ok": True, "summary": summary, "view": view, "data": data}
TOOLS: Dict[str, ToolSpec] = {
"web_fetch": ToolSpec(
name="web_fetch",
description=(
"抓取网页内容返回状态、标题、正文、HTML、链接等信息。"
" 支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS。"
),
parameters={
"type": "object",
"properties": {
"url": {"type": "string", "description": "目标 URL"},
"method": {"type": "string", "description": "请求方法(默认 GET"},
"headers": {"type": "object", "description": "请求头", "additionalProperties": {"type": "string"}},
"params": {"type": "object", "description": "查询参数", "additionalProperties": {"type": "string"}},
"json": {"type": "object", "description": "JSON 请求体"},
"body": {"type": "string", "description": "原始请求体"},
},
"required": ["url"],
"additionalProperties": False,
},
requires_confirmation=False,
handler=_web_fetch,
),
}
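A compact runnable mirror of the link-extraction rules in _HtmlTextExtractor above: fragment-only hrefs and javascript:/mailto:/tel:/data: schemes are dropped, everything else is resolved against the base URL and de-duplicated in order:

from html.parser import HTMLParser
from urllib.parse import urljoin

class LinkCollector(HTMLParser):
    def __init__(self, base_url: str):
        super().__init__()
        self.base_url = base_url
        self.links: list[str] = []
        self._seen: set[str] = set()

    def handle_starttag(self, tag, attrs):
        if tag.lower() != "a":
            return
        href = next((v or "" for k, v in attrs if k.lower() == "href"), "").strip()
        if not href or href.startswith("#"):
            return
        if href.lower().startswith(("javascript:", "mailto:", "tel:", "data:")):
            return
        resolved = urljoin(self.base_url, href)
        if resolved not in self._seen:
            self._seen.add(resolved)
            self.links.append(resolved)

collector = LinkCollector("https://example.com/docs/")
collector.feed('<a href="#top">x</a><a href="mailto:a@b">m</a><a href="page.html">p</a>')
print(collector.links)  # ['https://example.com/docs/page.html']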

domain/agent/types.py

@@ -0,0 +1,23 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class AgentChatContext(BaseModel):
current_path: Optional[str] = None
class AgentChatRequest(BaseModel):
messages: List[Dict[str, Any]] = Field(default_factory=list)
auto_execute: bool = False
approved_tool_call_ids: List[str] = Field(default_factory=list)
rejected_tool_call_ids: List[str] = Field(default_factory=list)
context: Optional[AgentChatContext] = None
class PendingToolCall(BaseModel):
id: str
name: str
arguments: Dict[str, Any] = Field(default_factory=dict)
requires_confirmation: bool = True
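Tying these models to the confirmation flow in the service code above: when the model proposes a tool call that requires confirmation, the stream ends with pending_tool_calls, and the client re-posts the transcript with the relevant ids filled in. A sketch of such a follow-up request body (all values illustrative):

confirm_request = {
    "messages": [
        {"role": "user", "content": "delete /tmp/a.txt"},
        {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_1", "type": "function",
             "function": {"name": "vfs_delete", "arguments": "{\"path\": \"/tmp/a.txt\"}"}},
        ]},
    ],
    "auto_execute": False,
    "approved_tool_call_ids": ["call_1"],  # executed on the next turn
    "rejected_tool_call_ids": [],          # rejected ids get a canceled tool result instead
    "context": {"current_path": "/tmp"},
}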


@@ -1,28 +1,61 @@
from .api import router_ai, router_vector_db
from .inference import (
MissingModelError,
chat_completion,
chat_completion_stream,
describe_image_base64,
get_text_embedding,
provider_service,
rerank_texts,
)
from .service import (
AIProviderService,
FILE_COLLECTION_NAME,
VECTOR_COLLECTION_NAME,
DEFAULT_VECTOR_DIMENSION,
VectorDBConfigManager,
VectorDBService,
DEFAULT_VECTOR_DIMENSION,
ABILITIES,
normalize_capabilities,
)
from .types import (
ABILITIES,
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
AIProviderCreate,
AIProviderUpdate,
VectorDBConfigPayload,
normalize_capabilities,
)
from .vector_providers import (
BaseVectorProvider,
MilvusLiteProvider,
MilvusServerProvider,
QdrantProvider,
get_provider_class,
get_provider_entry,
list_providers,
)
__all__ = [
"router_ai",
"router_vector_db",
"MissingModelError",
"chat_completion",
"chat_completion_stream",
"describe_image_base64",
"get_text_embedding",
"provider_service",
"rerank_texts",
"AIProviderService",
"VectorDBService",
"VectorDBConfigManager",
"DEFAULT_VECTOR_DIMENSION",
"VECTOR_COLLECTION_NAME",
"FILE_COLLECTION_NAME",
"BaseVectorProvider",
"MilvusLiteProvider",
"MilvusServerProvider",
"QdrantProvider",
"list_providers",
"get_provider_entry",
"get_provider_class",
"ABILITIES",
"normalize_capabilities",
"AIDefaultsUpdate",


@@ -5,8 +5,9 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
from domain.ai.types import (
from domain.auth import User, get_current_active_user
from .service import AIProviderService, VectorDBConfigManager, VectorDBService
from .types import (
AIDefaultsUpdate,
AIModelCreate,
AIModelUpdate,
@@ -14,9 +15,7 @@ from domain.ai.types import (
AIProviderUpdate,
VectorDBConfigPayload,
)
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from .vector_providers import get_provider_class, get_provider_entry, list_providers
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
@@ -251,7 +250,7 @@ async def get_vector_db_stats(request: Request, user: User = Depends(get_current
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
async def list_vector_providers(request: Request):
return success(list_providers())

File diff suppressed because it is too large


@@ -7,7 +7,7 @@ import httpx
from tortoise.exceptions import DoesNotExist
from tortoise.transactions import in_transaction
from domain.config.service import ConfigService
from domain.config import ConfigService
from models.database import AIDefaultModel, AIModel, AIProvider
from .types import ABILITIES, normalize_capabilities
@@ -140,7 +140,7 @@ def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
"provider_type": provider.provider_type,
"api_format": provider.api_format,
"base_url": provider.base_url,
"api_key": provider.api_key,
"has_api_key": bool(provider.api_key),
"logo_url": provider.logo_url,
"extra_config": provider.extra_config or {},
"created_at": provider.created_at,


@@ -30,8 +30,8 @@ class AIProviderBase(BaseModel):
@classmethod
def normalize_format(cls, value: str) -> str:
fmt = value.lower()
if fmt not in {"openai", "gemini"}:
raise ValueError("api_format must be 'openai' or 'gemini'")
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
return fmt
@@ -54,8 +54,8 @@ class AIProviderUpdate(BaseModel):
if value is None:
return value
fmt = value.lower()
if fmt not in {"openai", "gemini"}:
raise ValueError("api_format must be 'openai' or 'gemini'")
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
return fmt


@@ -1,5 +1,4 @@
from domain.audit.decorator import audit
from domain.audit.types import AuditAction
from domain.audit.api import router
from .decorator import audit
from .types import AuditAction
__all__ = ["audit", "AuditAction", "router"]
__all__ = ["audit", "AuditAction"]


@@ -4,10 +4,11 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from api import response
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.auth import User, get_current_active_user
from domain.permission import require_system_permission
from domain.permission.types import SystemPermission
from .service import AuditService
from .types import AuditAction
CurrentUser = Annotated[User, Depends(get_current_active_user)]
@@ -28,6 +29,7 @@ def _parse_iso(value: Optional[str], field: str):
@router.get("/logs")
@require_system_permission(SystemPermission.AUDIT_VIEW)
async def list_audit_logs(
current_user: CurrentUser,
page_num: int = Query(1, ge=1, alias="page", description="页码"),
@@ -55,6 +57,7 @@ async def list_audit_logs(
@router.delete("/logs")
@require_system_permission(SystemPermission.AUDIT_VIEW)
async def clear_audit_logs(
current_user: CurrentUser,
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),


@@ -7,11 +7,11 @@ import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import ALGORITHM
from domain.config.service import ConfigService
from domain.auth import ALGORITHM
from domain.config import ConfigService
from models.database import UserAccount
from .service import AuditService
from .types import AuditAction
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:


@@ -2,7 +2,7 @@ from typing import Any, Dict, Optional
from models.database import AuditLog
from domain.audit.types import AuditAction
from .types import AuditAction
class AuditService:

domain/auth/__init__.py

@@ -0,0 +1,49 @@
from .service import (
ALGORITHM,
AuthService,
authenticate_user_db,
create_access_token,
get_current_active_user,
get_current_user,
get_password_hash,
has_users,
register_user,
request_password_reset,
reset_password_with_token,
verify_password,
verify_password_reset_token,
)
from .types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
Token,
TokenData,
UpdateMeRequest,
User,
UserInDB,
)
__all__ = [
"ALGORITHM",
"AuthService",
"authenticate_user_db",
"create_access_token",
"get_current_active_user",
"get_current_user",
"get_password_hash",
"has_users",
"register_user",
"request_password_reset",
"reset_password_with_token",
"verify_password",
"verify_password_reset_token",
"PasswordResetConfirm",
"PasswordResetRequest",
"RegisterRequest",
"Token",
"TokenData",
"UpdateMeRequest",
"User",
"UserInDB",
]


@@ -5,8 +5,8 @@ from fastapi.security import OAuth2PasswordRequestForm
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import AuthService, get_current_active_user
from domain.auth.types import (
from .service import AuthService, get_current_active_user
from .types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
@@ -18,16 +18,16 @@ from domain.auth.types import (
router = APIRouter(prefix="/api/auth", tags=["auth"])
@router.post("/register", summary="注册第一个管理员用户")
@router.post("/register", summary="注册用户(首个用户为管理员")
@audit(
action=AuditAction.REGISTER,
description="注册管理员",
description="注册用户",
body_fields=["username", "email", "full_name"],
redact_fields=["password"],
)
async def register(request: Request, data: RegisterRequest):
user = await AuthService.register_user(data)
return success({"username": user.username}, msg="初始用户注册成功")
return success({"username": user.username}, msg="注册成功")
@router.post("/login")


@@ -11,7 +11,9 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from domain.auth.types import (
from domain.config import ConfigService
from models.database import Role, UserAccount, UserRole
from .types import (
PasswordResetConfirm,
PasswordResetRequest,
RegisterRequest,
@@ -21,8 +23,6 @@ from domain.auth.types import (
User,
UserInDB,
)
from models.database import UserAccount
from domain.config.service import ConfigService
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
@@ -140,6 +140,7 @@ class AuthService:
email=user.email,
full_name=user.full_name,
disabled=user.disabled,
is_admin=user.is_admin,
hashed_password=user.hashed_password,
)
return None
@@ -160,19 +161,60 @@ class AuthService:
@classmethod
async def register_user(cls, payload: RegisterRequest):
if await cls.has_users():
raise HTTPException(status_code=403, detail="系统已初始化,不允许注册新用户")
has_users = await cls.has_users()
normalized_email = cls._normalize_email(payload.email)
if not normalized_email:
raise HTTPException(status_code=400, detail="邮箱不能为空")
if has_users:
allow_register = str(await ConfigService.get("AUTH_ALLOW_REGISTER", "false") or "").strip().lower()
if allow_register not in ("1", "true", "yes", "on"):
raise HTTPException(status_code=403, detail="系统未开放注册")
default_role_id_raw = str(await ConfigService.get("AUTH_DEFAULT_REGISTER_ROLE_ID", "") or "").strip()
if not default_role_id_raw:
raise HTTPException(status_code=400, detail="未配置默认注册角色")
try:
default_role_id = int(default_role_id_raw)
except ValueError as exc:
raise HTTPException(status_code=400, detail="默认注册角色配置错误") from exc
role = await Role.get_or_none(id=default_role_id)
if not role:
raise HTTPException(status_code=400, detail="默认注册角色不存在")
exists = await UserAccount.get_or_none(username=payload.username)
if exists:
raise HTTPException(status_code=400, detail="用户名已存在")
existing_email = await UserAccount.get_or_none(email=normalized_email)
if existing_email:
raise HTTPException(status_code=400, detail="邮箱已被使用")
hashed = cls.get_password_hash(payload.password)
# 第一个用户自动成为超级管理员(不受开放注册开关影响)
if not has_users:
user = await UserAccount.create(
username=payload.username,
email=normalized_email,
full_name=payload.full_name,
hashed_password=hashed,
disabled=False,
is_admin=True,
)
return user
# 系统已初始化:按默认角色创建普通用户
user = await UserAccount.create(
username=payload.username,
email=payload.email,
email=normalized_email,
full_name=payload.full_name,
hashed_password=hashed,
disabled=False,
is_admin=False,
)
await UserRole.create(user_id=user.id, role_id=default_role_id)
return user
@classmethod
@@ -195,6 +237,13 @@ class AuthService:
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 更新最后登录时间
db_user = await UserAccount.get_or_none(id=user.id)
if db_user:
db_user.last_login = _now()
await db_user.save(update_fields=["last_login"])
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
access_token = await cls.create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
@@ -212,6 +261,7 @@ class AuthService:
"email": getattr(user, "email", None),
"full_name": getattr(user, "full_name", None),
"gravatar_url": gravatar_url,
"is_admin": getattr(user, "is_admin", False),
}
@classmethod
@@ -324,7 +374,7 @@ class AuthService:
@classmethod
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
from domain.email.service import EmailService
from domain.email import EmailService
app_domain = await ConfigService.get("APP_DOMAIN", None)
base_url = (app_domain or "http://localhost:5173").rstrip("/")
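The registration gate added above treats AUTH_ALLOW_REGISTER as a loose boolean; a tiny standalone check of the accepted spellings:

def flag_enabled(raw) -> bool:
    # Same truthiness rule as the AUTH_ALLOW_REGISTER check above.
    return str(raw or "").strip().lower() in ("1", "true", "yes", "on")

assert flag_enabled("True") and flag_enabled(" on ") and flag_enabled(1)
assert not flag_enabled(None) and not flag_enabled("off")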


@@ -16,6 +16,7 @@ class User(BaseModel):
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
is_admin: bool = False
class UserInDB(User):
@@ -25,7 +26,7 @@ class UserInDB(User):
class RegisterRequest(BaseModel):
username: str
password: str
email: str | None = None
email: str
full_name: str | None = None


@@ -1 +1,7 @@
from .service import BackupService
from .types import BackupData
__all__ = [
"BackupService",
"BackupData",
]


@@ -1,11 +1,14 @@
import datetime
from typing import Annotated
from fastapi import APIRouter, Depends, File, Request, UploadFile
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
from fastapi.responses import JSONResponse
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.backup.service import BackupService
from domain.auth import User, get_current_active_user
from domain.permission import require_system_permission
from domain.permission.types import SystemPermission
from .service import BackupService
router = APIRouter(
prefix="/api/backup",
@@ -16,8 +19,13 @@ router = APIRouter(
@router.get("/export", summary="导出全站数据")
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
async def export_backup(request: Request):
data = await BackupService.export_data()
@require_system_permission(SystemPermission.CONFIG_EDIT)
async def export_backup(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
sections: list[str] | None = Query(default=None),
):
data = await BackupService.export_data(sections=sections)
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
return JSONResponse(content=data.model_dump(), headers=headers)
@@ -25,6 +33,12 @@ async def export_backup(request: Request):
@router.post("/import", summary="导入数据")
@audit(action=AuditAction.UPLOAD, description="导入备份")
async def import_backup(request: Request, file: UploadFile = File(...)):
await BackupService.import_from_bytes(file.filename, await file.read())
@require_system_permission(SystemPermission.CONFIG_EDIT)
async def import_backup(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
file: UploadFile = File(...),
mode: str = Form("replace"),
):
await BackupService.import_from_bytes(file.filename, await file.read(), mode=mode)
return {"message": "数据导入成功。"}


@@ -4,8 +4,8 @@ from datetime import datetime
from fastapi import HTTPException
from tortoise.transactions import in_transaction
from domain.backup.types import BackupData
from domain.config.service import VERSION
from domain.config import VERSION
from .types import BackupData
from models.database import (
AIDefaultModel,
AIModel,
@@ -20,18 +20,64 @@ from models.database import (
class BackupService:
ALL_SECTIONS = (
"storage_adapters",
"user_accounts",
"automation_tasks",
"share_links",
"configurations",
"ai_providers",
"ai_models",
"ai_default_models",
"plugins",
)
@classmethod
async def export_data(cls) -> BackupData:
async def export_data(cls, sections: list[str] | None = None) -> BackupData:
sections = cls._normalize_sections(sections)
section_set = set(sections)
async with in_transaction():
adapters = await StorageAdapter.all().values()
users = await UserAccount.all().values()
tasks = await AutomationTask.all().values()
shares = await ShareLink.all().values()
configs = await Configuration.all().values()
providers = await AIProvider.all().values()
models = await AIModel.all().values()
default_models = await AIDefaultModel.all().values()
plugins = await Plugin.all().values()
adapters = (
await StorageAdapter.all().values()
if "storage_adapters" in section_set
else []
)
users = (
await UserAccount.all().values()
if "user_accounts" in section_set
else []
)
tasks = (
await AutomationTask.all().values()
if "automation_tasks" in section_set
else []
)
shares = (
await ShareLink.all().values()
if "share_links" in section_set
else []
)
configs = (
await Configuration.all().values()
if "configurations" in section_set
else []
)
providers = (
await AIProvider.all().values()
if "ai_providers" in section_set
else []
)
models = (
await AIModel.all().values() if "ai_models" in section_set else []
)
default_models = (
await AIDefaultModel.all().values()
if "ai_default_models" in section_set
else []
)
plugins = (
await Plugin.all().values() if "plugins" in section_set else []
)
share_links = cls._serialize_datetime_fields(
shares, ["created_at", "expires_at"]
@@ -51,6 +97,7 @@ class BackupService:
return BackupData(
version=VERSION,
sections=sections,
storage_adapters=list(adapters),
user_accounts=list(users),
automation_tasks=list(tasks),
@@ -63,106 +110,195 @@ class BackupService:
)
@classmethod
async def import_from_bytes(cls, filename: str, content: bytes) -> None:
async def import_from_bytes(
cls, filename: str, content: bytes, mode: str = "replace"
) -> None:
if not filename.endswith(".json"):
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
try:
raw_data = json.loads(content)
except Exception:
raise HTTPException(status_code=400, detail="无法解析JSON文件")
await cls.import_data(BackupData(**raw_data))
await cls.import_data(BackupData(**raw_data), mode=mode)
@classmethod
async def import_data(cls, payload: BackupData) -> None:
async def import_data(cls, payload: BackupData, mode: str = "replace") -> None:
sections = cls._normalize_sections(payload.sections)
if mode not in {"replace", "merge"}:
raise HTTPException(status_code=400, detail="无效的导入模式")
share_links = (
cls._parse_datetime_fields(payload.share_links, ["created_at", "expires_at"])
if payload.share_links
else []
)
ai_providers = (
cls._parse_datetime_fields(payload.ai_providers, ["created_at", "updated_at"])
if payload.ai_providers
else []
)
ai_models = (
cls._parse_datetime_fields(payload.ai_models, ["created_at", "updated_at"])
if payload.ai_models
else []
)
ai_default_models = (
cls._parse_datetime_fields(
payload.ai_default_models, ["created_at", "updated_at"]
)
if payload.ai_default_models
else []
)
plugins = (
cls._parse_datetime_fields(payload.plugins, ["created_at", "updated_at"])
if payload.plugins
else []
)
async with in_transaction() as conn:
await ShareLink.all().using_db(conn).delete()
await AutomationTask.all().using_db(conn).delete()
await StorageAdapter.all().using_db(conn).delete()
await UserAccount.all().using_db(conn).delete()
await Configuration.all().using_db(conn).delete()
await AIDefaultModel.all().using_db(conn).delete()
await AIModel.all().using_db(conn).delete()
await AIProvider.all().using_db(conn).delete()
await Plugin.all().using_db(conn).delete()
if mode == "replace":
if "share_links" in sections:
await ShareLink.all().using_db(conn).delete()
if "automation_tasks" in sections:
await AutomationTask.all().using_db(conn).delete()
if "storage_adapters" in sections:
await StorageAdapter.all().using_db(conn).delete()
if "user_accounts" in sections:
await UserAccount.all().using_db(conn).delete()
if "configurations" in sections:
await Configuration.all().using_db(conn).delete()
if "ai_default_models" in sections:
await AIDefaultModel.all().using_db(conn).delete()
if "ai_models" in sections:
await AIModel.all().using_db(conn).delete()
if "ai_providers" in sections:
await AIProvider.all().using_db(conn).delete()
if "plugins" in sections:
await Plugin.all().using_db(conn).delete()
if payload.configurations:
await Configuration.bulk_create(
[Configuration(**config) for config in payload.configurations],
using_db=conn,
)
if "configurations" in sections and payload.configurations:
if mode == "merge":
await cls._merge_records(
Configuration, payload.configurations, conn
)
else:
await Configuration.bulk_create(
[Configuration(**config) for config in payload.configurations],
using_db=conn,
)
if payload.user_accounts:
await UserAccount.bulk_create(
[UserAccount(**user) for user in payload.user_accounts],
using_db=conn,
)
if "user_accounts" in sections and payload.user_accounts:
if mode == "merge":
await cls._merge_records(UserAccount, payload.user_accounts, conn)
else:
await UserAccount.bulk_create(
[UserAccount(**user) for user in payload.user_accounts],
using_db=conn,
)
if payload.storage_adapters:
await StorageAdapter.bulk_create(
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
using_db=conn,
)
if "storage_adapters" in sections and payload.storage_adapters:
if mode == "merge":
await cls._merge_records(
StorageAdapter, payload.storage_adapters, conn
)
else:
await StorageAdapter.bulk_create(
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
using_db=conn,
)
if payload.automation_tasks:
await AutomationTask.bulk_create(
[AutomationTask(**task) for task in payload.automation_tasks],
using_db=conn,
)
if "automation_tasks" in sections and payload.automation_tasks:
if mode == "merge":
await cls._merge_records(
AutomationTask, payload.automation_tasks, conn
)
else:
await AutomationTask.bulk_create(
[AutomationTask(**task) for task in payload.automation_tasks],
using_db=conn,
)
if payload.share_links:
await ShareLink.bulk_create(
[
ShareLink(**share)
for share in cls._parse_datetime_fields(
payload.share_links, ["created_at", "expires_at"]
)
],
using_db=conn,
)
if "share_links" in sections and share_links:
if mode == "merge":
await cls._merge_records(ShareLink, share_links, conn)
else:
await ShareLink.bulk_create(
[ShareLink(**share) for share in share_links],
using_db=conn,
)
if payload.ai_providers:
await AIProvider.bulk_create(
[
AIProvider(**item)
for item in cls._parse_datetime_fields(
payload.ai_providers, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if "ai_providers" in sections and ai_providers:
if mode == "merge":
await cls._merge_records(AIProvider, ai_providers, conn)
else:
await AIProvider.bulk_create(
[AIProvider(**item) for item in ai_providers],
using_db=conn,
)
if payload.ai_models:
await AIModel.bulk_create(
[
AIModel(**item)
for item in cls._parse_datetime_fields(
payload.ai_models, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if "ai_models" in sections and ai_models:
if mode == "merge":
await cls._merge_records(AIModel, ai_models, conn)
else:
await AIModel.bulk_create(
[AIModel(**item) for item in ai_models],
using_db=conn,
)
if payload.ai_default_models:
await AIDefaultModel.bulk_create(
[
AIDefaultModel(**item)
for item in cls._parse_datetime_fields(
payload.ai_default_models, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if "ai_default_models" in sections and ai_default_models:
if mode == "merge":
await cls._merge_records(
AIDefaultModel, ai_default_models, conn
)
else:
await AIDefaultModel.bulk_create(
[AIDefaultModel(**item) for item in ai_default_models],
using_db=conn,
)
if payload.plugins:
await Plugin.bulk_create(
[
Plugin(**item)
for item in cls._parse_datetime_fields(
payload.plugins, ["created_at", "updated_at"]
)
],
using_db=conn,
)
if "plugins" in sections and plugins:
if mode == "merge":
await cls._merge_records(Plugin, plugins, conn)
else:
await Plugin.bulk_create(
[Plugin(**item) for item in plugins],
using_db=conn,
)
@classmethod
def _normalize_sections(cls, sections: list[str] | None) -> list[str]:
if not sections:
return list(cls.ALL_SECTIONS)
normalized = [item for item in sections if item]
invalid = [item for item in normalized if item not in cls.ALL_SECTIONS]
if invalid:
raise HTTPException(
status_code=400, detail=f"无效的备份分区: {', '.join(invalid)}"
)
result: list[str] = []
seen = set()
for item in normalized:
if item in seen:
continue
seen.add(item)
result.append(item)
return result
@staticmethod
async def _merge_records(model, records: list[dict], using_db) -> None:
for record in records:
data = dict(record)
record_id = data.pop("id", None)
if record_id is None:
await model.create(using_db=using_db, **data)
continue
updated = (
await model.filter(id=record_id)
.using_db(using_db)
.update(**data)
)
if updated == 0:
await model.create(using_db=using_db, id=record_id, **data)
@staticmethod
def _serialize_datetime_fields(

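The merge mode introduced above is a per-record upsert: update by primary key, create on a miss. A minimal in-memory sketch of those semantics, with a dict standing in for the ORM table:

def merge_records(table: dict[int, dict], records: list[dict]) -> None:
    # Mirrors _merge_records: unknown or missing ids are created, known ids updated.
    for record in records:
        data = dict(record)
        rid = data.pop("id", None)
        if rid is None:
            table[max(table, default=0) + 1] = data  # autoincrement stand-in
        elif rid in table:
            table[rid].update(data)
        else:
            table[rid] = data

db = {1: {"key": "THEME_MODE", "value": "dark"}}
merge_records(db, [{"id": 1, "value": "light"}, {"key": "NEW_KEY", "value": "x"}])
print(db)  # id 1 updated in place; the id-less record is created as id 2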

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
class BackupData(BaseModel):
version: str | None = None
sections: list[str] = Field(default_factory=list)
storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
user_accounts: list[dict[str, Any]] = Field(default_factory=list)
automation_tasks: list[dict[str, Any]] = Field(default_factory=list)

domain/config/__init__.py

@@ -0,0 +1,10 @@
from .service import ConfigService, VERSION
from .types import ConfigItem, LatestVersionInfo, SystemStatus
__all__ = [
"ConfigService",
"VERSION",
"ConfigItem",
"LatestVersionInfo",
"SystemStatus",
]


@@ -4,16 +4,26 @@ from fastapi import APIRouter, Depends, Form, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.config.types import ConfigItem
from domain.auth import User, get_current_active_user
from domain.permission import require_system_permission
from domain.permission.types import SystemPermission
from .service import ConfigService
from .types import ConfigItem
router = APIRouter(prefix="/api/config", tags=["config"])
PUBLIC_CONFIG_KEYS = [
"THEME_MODE",
"THEME_PRIMARY_COLOR",
"THEME_BORDER_RADIUS",
"THEME_CUSTOM_TOKENS",
"THEME_CUSTOM_CSS",
]
@router.get("/")
@audit(action=AuditAction.READ, description="获取配置")
@require_system_permission(SystemPermission.CONFIG_EDIT)
async def get_config(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -25,6 +35,7 @@ async def get_config(
@router.post("/")
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
@require_system_permission(SystemPermission.CONFIG_EDIT)
async def set_config(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -37,6 +48,7 @@ async def set_config(
@router.get("/all")
@audit(action=AuditAction.READ, description="获取全部配置")
@require_system_permission(SystemPermission.CONFIG_EDIT)
async def get_all_config(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -44,6 +56,18 @@ async def get_all_config(
configs = await ConfigService.get_all()
return success(configs)
@router.get("/public")
@audit(action=AuditAction.READ, description="获取公开配置")
async def get_public_config(
request: Request,
):
data = {}
for key in PUBLIC_CONFIG_KEYS:
value = await ConfigService.get(key)
if value is not None:
data[key] = value
return success(data)
@router.get("/status")
@audit(action=AuditAction.READ, description="获取系统状态")


@@ -5,12 +5,12 @@ from typing import Any, Dict, Optional
import httpx
from dotenv import load_dotenv
from domain.config.types import LatestVersionInfo, SystemStatus
from .types import LatestVersionInfo, SystemStatus
from models.database import Configuration, UserAccount
load_dotenv(dotenv_path=".env")
VERSION = "v1.5.5"
VERSION = "v1.7.4"
class ConfigService:

domain/email/__init__.py

@@ -0,0 +1,20 @@
from .service import EmailService, EmailTemplateRenderer
from .types import (
EmailConfig,
EmailSecurity,
EmailSendPayload,
EmailTemplatePreviewPayload,
EmailTemplateUpdate,
EmailTestRequest,
)
__all__ = [
"EmailService",
"EmailTemplateRenderer",
"EmailConfig",
"EmailSecurity",
"EmailSendPayload",
"EmailTemplatePreviewPayload",
"EmailTemplateUpdate",
"EmailTestRequest",
]


@@ -2,10 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.email.service import EmailService, EmailTemplateRenderer
from domain.email.types import (
from domain.auth import User, get_current_active_user
from .service import EmailService, EmailTemplateRenderer
from .types import (
EmailTemplatePreviewPayload,
EmailTemplateUpdate,
EmailTestRequest,


@@ -7,8 +7,8 @@ from pathlib import Path
from string import Template
from typing import Any, Dict, List, Optional
from domain.config.service import ConfigService
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload
from domain.config import ConfigService
from .types import EmailConfig, EmailSecurity, EmailSendPayload
class EmailTemplateRenderer:
@@ -104,7 +104,7 @@ class EmailService:
template: str,
context: Optional[Dict[str, Any]] = None,
):
from domain.tasks.task_queue import TaskProgress, task_queue_service
from domain.tasks import TaskProgress, task_queue_service
payload = EmailSendPayload(
recipients=recipients,
@@ -126,7 +126,7 @@ class EmailService:
@classmethod
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
from domain.tasks.task_queue import TaskProgress, task_queue_service
from domain.tasks import TaskProgress, task_queue_service
payload = EmailSendPayload(**data)


@@ -0,0 +1,7 @@
from .service import OfflineDownloadService
from .types import OfflineDownloadCreate
__all__ = [
"OfflineDownloadService",
"OfflineDownloadCreate",
]


@@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.service import OfflineDownloadService
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.auth import User, get_current_active_user
from domain.permission import require_path_permission
from domain.permission.types import PathAction
from .service import OfflineDownloadService
from .types import OfflineDownloadCreate
CurrentUser = Annotated[User, Depends(get_current_active_user)]
@@ -23,6 +24,7 @@ router = APIRouter(
description="创建离线下载任务",
body_fields=["url", "dest_dir", "filename"],
)
@require_path_permission(PathAction.WRITE, "payload.dest_dir")
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
data = await OfflineDownloadService.create_download(payload, current_user)
return success(data)


@@ -7,11 +7,10 @@ import aiofiles
import aiohttp
from fastapi import Depends, HTTPException
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service
from domain.auth import User, get_current_active_user
from domain.tasks import Task, TaskProgress, task_queue_service
from domain.virtual_fs import VirtualFSService
from .types import OfflineDownloadCreate
class OfflineDownloadService:


@@ -0,0 +1,10 @@
from .service import PermissionService
from .matcher import PathMatcher
from .decorator import require_path_permission, require_system_permission
__all__ = [
"PermissionService",
"PathMatcher",
"require_system_permission",
"require_path_permission",
]

domain/permission/api.py (new file, 41 lines)

@@ -0,0 +1,41 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from .service import PermissionService
from .types import (
PathPermissionCheck,
PathPermissionResult,
UserPermissions,
PermissionInfo,
)
router = APIRouter(prefix="/api", tags=["permissions"])
@router.get("/permissions", response_model=list[PermissionInfo])
async def get_all_permissions(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> list[PermissionInfo]:
"""获取所有权限定义"""
return await PermissionService.get_all_permissions()
@router.get("/me/permissions", response_model=UserPermissions)
async def get_my_permissions(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> UserPermissions:
"""获取当前用户的有效权限"""
return await PermissionService.get_user_permissions(current_user.id)
@router.post("/me/check-path", response_model=PathPermissionResult)
async def check_path_permission(
data: PathPermissionCheck,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> PathPermissionResult:
"""检查当前用户对某路径的权限"""
return await PermissionService.check_path_permission_detailed(
current_user.id, data.path, data.action
)
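A minimal client-side sketch of the new self-service check endpoint; the bearer token and base URL are placeholders, while the path and body shape follow PathPermissionCheck above.

import httpx

resp = httpx.post(
    "http://localhost:8000/api/me/check-path",
    json={"path": "/docs/readme.md", "action": "read"},  # PathPermissionCheck fields
    headers={"Authorization": "Bearer <token>"},          # placeholder token
)
print(resp.json())  # PathPermissionResult: path, action, allowed, matched_rule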


@@ -0,0 +1,103 @@
import inspect
from functools import wraps
from typing import Any, Iterable, Mapping
from fastapi import HTTPException
from .service import PermissionService
def _get_user_id(user: Any) -> int | None:
if user is None:
return None
if isinstance(user, Mapping):
raw = user.get("id") or user.get("user_id")
return int(raw) if isinstance(raw, int) else None
value = getattr(user, "id", None) or getattr(user, "user_id", None)
return int(value) if isinstance(value, int) else None
def _resolve_expr(bound_args: Mapping[str, Any], expr: str) -> Any:
parts = [p for p in (expr or "").split(".") if p]
if not parts:
return None
cur: Any = bound_args.get(parts[0])
for part in parts[1:]:
if cur is None:
return None
if isinstance(cur, Mapping):
cur = cur.get(part)
else:
cur = getattr(cur, part, None)
return cur
def require_system_permission(permission_code: str, *, user_kw: str = "current_user"):
"""
Run the system/adapter permission check inside the endpoint body.
Design goals:
- Keep the behavior identical to today's hand-written require_* calls inside the function body: a failure is still caught and recorded by the outer @audit
- Avoid FastAPI dependencies (so a permission failure never occurs outside the endpoint)
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
bound = inspect.signature(func).bind_partial(*args, **kwargs)
bound.apply_defaults()
user_id = _get_user_id(bound.arguments.get(user_kw))
if user_id is None:
raise HTTPException(status_code=401, detail="Unauthorized")
await PermissionService.require_system_permission(user_id, permission_code)
result = func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
return wrapper
return decorator
def require_path_permission(action: str, path_expr: str, *, user_kw: str = "current_user"):
"""
Run the path permission check inside the endpoint body.
path_expr supports:
- "full_path"
- "body.src" / "body.dst"
- "payload.paths" (a list[str] is checked item by item)
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
bound = inspect.signature(func).bind_partial(*args, **kwargs)
bound.apply_defaults()
user_id = _get_user_id(bound.arguments.get(user_kw))
if user_id is None:
raise HTTPException(status_code=401, detail="Unauthorized")
value = _resolve_expr(bound.arguments, path_expr)
paths: Iterable[Any]
if isinstance(value, (list, tuple, set)):
paths = value
else:
paths = [value]
for path in paths:
if path is None:
raise HTTPException(status_code=400, detail="Missing path")
await PermissionService.require_path_permission(user_id, str(path), action)
result = func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
return wrapper
return decorator
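A standalone check of the helpers above, assuming the module is importable as domain.permission.decorator outside the app: _resolve_expr walks dotted key/attribute paths over the endpoint's bound arguments, and _get_user_id accepts both mappings and ORM-style objects.

from types import SimpleNamespace

from domain.permission.decorator import _get_user_id, _resolve_expr

bound = {
    "payload": SimpleNamespace(dest_dir="/downloads", paths=["/a", "/b"]),
    "current_user": {"id": 42},
}
assert _resolve_expr(bound, "payload.dest_dir") == "/downloads"  # attribute hop
assert _resolve_expr(bound, "payload.paths") == ["/a", "/b"]     # list value, checked per item
assert _get_user_id(bound["current_user"]) == 42                 # mapping form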


@@ -0,0 +1,158 @@
import re
import fnmatch
from functools import lru_cache
class PathMatcher:
"""路径匹配器,支持精确匹配、通配符匹配和正则匹配"""
@classmethod
def normalize_path(cls, path: str) -> str:
"""规范化路径"""
if not path:
return "/"
# Ensure a leading /
if not path.startswith("/"):
path = "/" + path
# Strip the trailing / (except for the root path)
if path != "/" and path.endswith("/"):
path = path.rstrip("/")
return path
@classmethod
def get_parent_path(cls, path: str) -> str | None:
"""获取父目录路径"""
path = cls.normalize_path(path)
if path == "/":
return None
parent = "/".join(path.rsplit("/", 1)[:-1])
return parent if parent else "/"
@classmethod
def match_pattern(cls, path: str, pattern: str, is_regex: bool = False) -> bool:
"""
Match a path against a pattern.
Args:
path: the path to match
pattern: the pattern to match against
is_regex: whether the pattern is a regular expression
Returns:
whether the path matches
"""
path = cls.normalize_path(path)
pattern = cls.normalize_path(pattern)
if is_regex:
return cls._match_regex(path, pattern)
else:
return cls._match_glob(path, pattern)
@classmethod
def _match_regex(cls, path: str, pattern: str) -> bool:
"""正则表达式匹配"""
try:
# Cap pattern length to bound regex complexity and mitigate ReDoS
if len(pattern) > 500:
return False
regex = re.compile(pattern)
return bool(regex.match(path))
except re.error:
return False
@classmethod
def _match_glob(cls, path: str, pattern: str) -> bool:
"""
Glob matching.
Supported syntax:
- *  : matches any characters within a single directory level
- ** : matches any number of directory levels
- ?  : matches a single character
"""
# Exact match
if pattern == path:
return True
# Handle the ** wildcard
if "**" in pattern:
return cls._match_double_star(path, pattern)
# Fall back to fnmatch for standard glob matching
return fnmatch.fnmatch(path, pattern)
@classmethod
def _match_double_star(cls, path: str, pattern: str) -> bool:
"""处理 ** 通配符匹配"""
# Split the pattern on the ** wildcard
parts = pattern.split("**")
if len(parts) == 2:
prefix, suffix = parts
# Strip the trailing / from prefix and the leading / from suffix
prefix = prefix.rstrip("/") if prefix else ""
suffix = suffix.lstrip("/") if suffix else ""
# Check the prefix
if prefix and not path.startswith(prefix):
return False
# Without a suffix, a prefix match is enough
if not suffix:
return True
# Check the suffix
remaining = path[len(prefix):].lstrip("/") if prefix else path.lstrip("/")
# The suffix may appear at any depth
if "*" in suffix or "?" in suffix:
# Suffix contains wildcards: compare the trailing segments
path_parts = remaining.split("/")
suffix_parts = suffix.split("/")
# Simplification: check whether the last path segments match the suffix
if len(path_parts) >= len(suffix_parts):
tail = "/".join(path_parts[-len(suffix_parts):])
return fnmatch.fnmatch(tail, suffix)
return False
else:
# Suffix is a literal string
return remaining.endswith(suffix) or ("/" + suffix) in remaining or remaining == suffix
# Multiple ** wildcards: fall back to a simplified regex
regex_pattern = pattern.replace("**", ".*").replace("*", "[^/]*").replace("?", ".")
try:
return bool(re.match(f"^{regex_pattern}$", path))
except re.error:
return False
@classmethod
def get_pattern_specificity(cls, pattern: str, is_regex: bool = False) -> int:
"""
Compute how specific a pattern is (used for priority sorting).
A larger return value means a more specific pattern.
"""
pattern = cls.normalize_path(pattern)
if is_regex:
# Regex patterns rank as less specific
return len(pattern) // 2
# An exact path is the most specific
if "*" not in pattern and "?" not in pattern:
return len(pattern) * 10
# Score the non-wildcard parts
specificity = 0
parts = pattern.split("/")
for part in parts:
if part == "**":
specificity += 1
elif "*" in part or "?" in part:
specificity += 5
else:
specificity += 10
return specificity
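A few standalone examples of the matcher semantics, assuming the module is importable as domain.permission.matcher. Note that the single-star case is delegated to fnmatch, whose * is not slash-aware.

from domain.permission.matcher import PathMatcher

assert PathMatcher.normalize_path("docs/") == "/docs"
assert PathMatcher.get_parent_path("/docs/readme.md") == "/docs"
assert PathMatcher.match_pattern("/docs/readme.md", "/docs/*")   # fnmatch glob
assert PathMatcher.match_pattern("/docs/a/b/c.txt", "/docs/**")  # any depth
assert not PathMatcher.match_pattern("/docs", "/docs/*")         # the directory itself
# Exact patterns outrank globs when rules are sorted by specificity:
assert (PathMatcher.get_pattern_specificity("/docs/readme.md")
        > PathMatcher.get_pattern_specificity("/docs/**"))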


@@ -0,0 +1,340 @@
from typing import List, Optional
from fastapi import HTTPException
from models.database import (
UserAccount,
UserRole,
RolePermission,
PathRule,
)
from .matcher import PathMatcher
from .types import (
PathAction,
PathRuleInfo,
PathPermissionResult,
UserPermissions,
PermissionInfo,
PERMISSION_DEFINITIONS,
)
class PermissionService:
"""权限检查服务"""
# Permission check result cache (simple in-memory cache)
_cache: dict[str, tuple[bool, float]] = {}
_cache_ttl = 300  # cache entries live for 5 minutes
@classmethod
async def check_path_permission(
cls, user_id: int, path: str, action: str
) -> bool:
"""
Check whether a user may perform an action on a path.
Args:
user_id: the user ID
path: the path to check
action: the operation type (read/write/delete/share)
Returns:
whether the user has permission
"""
import time
# Check the cache
cache_key = f"{user_id}:{path}:{action}"
if cache_key in cls._cache:
result, timestamp = cls._cache[cache_key]
if time.time() - timestamp < cls._cache_ttl:
return result
# Fetch the user
user = await UserAccount.get_or_none(id=user_id)
if not user:
return False
# Super admins always pass
if user.is_admin:
cls._cache[cache_key] = (True, time.time())
return True
# Fetch all of the user's roles
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
role_ids = [ur.role_id for ur in user_roles]
if not role_ids:
cls._cache[cache_key] = (False, time.time())
return False
# Fetch the path rules of all roles
path_rules = await PathRule.filter(role_id__in=role_ids).order_by("-priority")
# Normalize the path
normalized_path = PathMatcher.normalize_path(path)
# Match by priority and specificity
result = cls._match_path_rules(normalized_path, action, list(path_rules))
# If no rule matched, fall back to the parent directory (inheritance)
if result is None:
parent_path = PathMatcher.get_parent_path(normalized_path)
if parent_path:
result = await cls.check_path_permission(user_id, parent_path, action)
else:
result = False  # deny by default
cls._cache[cache_key] = (result, time.time())
return result
@classmethod
def _match_path_rules(
cls, path: str, action: str, rules: List[PathRule]
) -> Optional[bool]:
"""
Match path rules.
Returns:
True/False for an explicit permission decision; None when no rule matched
"""
# Sort by priority and specificity
sorted_rules = sorted(
rules,
key=lambda r: (
r.priority,
PathMatcher.get_pattern_specificity(r.path_pattern, r.is_regex),
),
reverse=True,
)
for rule in sorted_rules:
if PathMatcher.match_pattern(path, rule.path_pattern, rule.is_regex):
# Rule matched: check the specific action
if action == PathAction.READ:
return rule.can_read
elif action == PathAction.WRITE:
return rule.can_write
elif action == PathAction.DELETE:
return rule.can_delete
elif action == PathAction.SHARE:
return rule.can_share
else:
return False
return None
@classmethod
async def check_system_permission(cls, user_id: int, permission_code: str) -> bool:
"""检查用户的系统/适配器权限"""
# Fetch the user
user = await UserAccount.get_or_none(id=user_id)
if not user:
return False
# Super admins always pass
if user.is_admin:
return True
# Fetch all of the user's roles
user_roles = await UserRole.filter(user_id=user_id)
role_ids = [ur.role_id for ur in user_roles]
if not role_ids:
return False
role_permission = await RolePermission.filter(
role_id__in=role_ids, permission_code=permission_code
).first()
return role_permission is not None
@classmethod
async def require_path_permission(
cls, user_id: int, path: str, action: str
) -> None:
"""要求用户具有路径权限,否则抛出 403"""
if not await cls.check_path_permission(user_id, path, action):
raise HTTPException(403, detail=f"没有权限执行此操作: {action}")
@classmethod
async def require_system_permission(
cls, user_id: int, permission_code: str
) -> None:
"""要求用户具有系统权限,否则抛出 403"""
if not await cls.check_system_permission(user_id, permission_code):
raise HTTPException(403, detail=f"没有权限: {permission_code}")
@classmethod
async def get_user_permissions(cls, user_id: int) -> UserPermissions:
"""获取用户的所有权限"""
user = await UserAccount.get_or_none(id=user_id)
if not user:
raise HTTPException(404, detail="用户不存在")
# Super admins hold every permission
if user.is_admin:
all_permission_codes = [item["code"] for item in PERMISSION_DEFINITIONS]
all_path_rules = await PathRule.all()
return UserPermissions(
user_id=user_id,
is_admin=True,
permissions=all_permission_codes,
path_rules=[
PathRuleInfo(
id=r.id,
role_id=r.role_id,
path_pattern=r.path_pattern,
is_regex=r.is_regex,
can_read=r.can_read,
can_write=r.can_write,
can_delete=r.can_delete,
can_share=r.can_share,
priority=r.priority,
created_at=r.created_at,
)
for r in all_path_rules
],
)
# Fetch the user's roles
user_roles = await UserRole.filter(user_id=user_id)
role_ids = [ur.role_id for ur in user_roles]
# Fetch permissions
permissions = []
if role_ids:
role_permissions = await RolePermission.filter(role_id__in=role_ids)
permissions = sorted(set(rp.permission_code for rp in role_permissions))
# Fetch path rules
path_rules = []
if role_ids:
rules = await PathRule.filter(role_id__in=role_ids)
path_rules = [
PathRuleInfo(
id=r.id,
role_id=r.role_id,
path_pattern=r.path_pattern,
is_regex=r.is_regex,
can_read=r.can_read,
can_write=r.can_write,
can_delete=r.can_delete,
can_share=r.can_share,
priority=r.priority,
created_at=r.created_at,
)
for r in rules
]
return UserPermissions(
user_id=user_id,
is_admin=False,
permissions=permissions,
path_rules=path_rules,
)
@classmethod
async def get_all_permissions(cls) -> List[PermissionInfo]:
"""获取所有权限定义"""
return [
PermissionInfo(
code=item["code"],
name=item["name"],
category=item["category"],
description=item.get("description"),
)
for item in PERMISSION_DEFINITIONS
]
@classmethod
async def check_path_permission_detailed(
cls, user_id: int, path: str, action: str
) -> PathPermissionResult:
"""检查路径权限并返回详细结果"""
user = await UserAccount.get_or_none(id=user_id)
if not user:
return PathPermissionResult(path=path, action=action, allowed=False)
# Super admin
if user.is_admin:
return PathPermissionResult(path=path, action=action, allowed=True)
# Fetch the user's roles
user_roles = await UserRole.filter(user_id=user_id)
role_ids = [ur.role_id for ur in user_roles]
if not role_ids:
return PathPermissionResult(path=path, action=action, allowed=False)
# Fetch path rules
path_rules = await PathRule.filter(role_id__in=role_ids).order_by("-priority")
normalized_path = PathMatcher.normalize_path(path)
# Find the matching rule
matched_rule = None
for rule in sorted(
path_rules,
key=lambda r: (
r.priority,
PathMatcher.get_pattern_specificity(r.path_pattern, r.is_regex),
),
reverse=True,
):
if PathMatcher.match_pattern(
normalized_path, rule.path_pattern, rule.is_regex
):
matched_rule = rule
break
# Evaluate the permission
allowed = False
if matched_rule:
if action == PathAction.READ:
allowed = matched_rule.can_read
elif action == PathAction.WRITE:
allowed = matched_rule.can_write
elif action == PathAction.DELETE:
allowed = matched_rule.can_delete
elif action == PathAction.SHARE:
allowed = matched_rule.can_share
rule_info = None
if matched_rule:
rule_info = PathRuleInfo(
id=matched_rule.id,
role_id=matched_rule.role_id,
path_pattern=matched_rule.path_pattern,
is_regex=matched_rule.is_regex,
can_read=matched_rule.can_read,
can_write=matched_rule.can_write,
can_delete=matched_rule.can_delete,
can_share=matched_rule.can_share,
priority=matched_rule.priority,
created_at=matched_rule.created_at,
)
return PathPermissionResult(
path=path, action=action, allowed=allowed, matched_rule=rule_info
)
@classmethod
def clear_cache(cls, user_id: int | None = None) -> None:
"""清除权限缓存"""
if user_id is None:
cls._cache.clear()
else:
# Clear cache entries for a specific user
keys_to_delete = [k for k in cls._cache if k.startswith(f"{user_id}:")]
for k in keys_to_delete:
del cls._cache[k]
@classmethod
async def filter_paths_by_permission(
cls, user_id: int, paths: List[str], action: str
) -> List[str]:
"""过滤出用户有权限的路径列表"""
result = []
for path in paths:
if await cls.check_path_permission(user_id, path, action):
result.append(path)
return result
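A sketch of the precedence rule implemented by _match_path_rules: higher priority first, then the more specific pattern, so a narrow deny rule overrides a broad allow. SimpleNamespace objects stand in for PathRule rows; this assumes the package imports cleanly outside the running app.

from types import SimpleNamespace

from domain.permission.service import PermissionService

def rule(pattern, priority=0, can_write=False):
    # Stub carrying only the attributes _match_path_rules reads
    return SimpleNamespace(
        path_pattern=pattern, is_regex=False, priority=priority,
        can_read=True, can_write=can_write, can_delete=False, can_share=False,
    )

rules = [rule("/data/**", can_write=True), rule("/data/secrets/**", priority=10)]
assert PermissionService._match_path_rules("/data/a.txt", "write", rules) is True
assert PermissionService._match_path_rules("/data/secrets/key.pem", "write", rules) is False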

domain/permission/types.py (new file, 107 lines)

@@ -0,0 +1,107 @@
from pydantic import BaseModel
from datetime import datetime
# Path action types
class PathAction:
READ = "read"
WRITE = "write"
DELETE = "delete"
SHARE = "share"
# System permission codes
class SystemPermission:
USER_CREATE = "system.user.create"
USER_EDIT = "system.user.edit"
USER_DELETE = "system.user.delete"
USER_LIST = "system.user.list"
ROLE_MANAGE = "system.role.manage"
CONFIG_EDIT = "system.config.edit"
AUDIT_VIEW = "system.audit.view"
# Adapter permission codes
class AdapterPermission:
CREATE = "adapter.create"
EDIT = "adapter.edit"
DELETE = "adapter.delete"
LIST = "adapter.list"
# All permission definitions
PERMISSION_DEFINITIONS = [
# System permissions
{"code": SystemPermission.USER_CREATE, "name": "创建用户", "category": "system", "description": "允许创建新用户"},
{"code": SystemPermission.USER_EDIT, "name": "编辑用户", "category": "system", "description": "允许编辑用户信息"},
{"code": SystemPermission.USER_DELETE, "name": "删除用户", "category": "system", "description": "允许删除用户"},
{"code": SystemPermission.USER_LIST, "name": "查看用户列表", "category": "system", "description": "允许查看用户列表"},
{"code": SystemPermission.ROLE_MANAGE, "name": "管理角色和权限", "category": "system", "description": "允许管理角色和权限配置"},
{"code": SystemPermission.CONFIG_EDIT, "name": "修改系统配置", "category": "system", "description": "允许修改系统配置"},
{"code": SystemPermission.AUDIT_VIEW, "name": "查看审计日志", "category": "system", "description": "允许查看审计日志"},
# Adapter permissions
{"code": AdapterPermission.CREATE, "name": "创建存储适配器", "category": "adapter", "description": "允许创建存储适配器"},
{"code": AdapterPermission.EDIT, "name": "编辑存储适配器", "category": "adapter", "description": "允许编辑存储适配器"},
{"code": AdapterPermission.DELETE, "name": "删除存储适配器", "category": "adapter", "description": "允许删除存储适配器"},
{"code": AdapterPermission.LIST, "name": "查看存储适配器列表", "category": "adapter", "description": "允许查看存储适配器列表"},
]
# Pydantic models
class PermissionInfo(BaseModel):
code: str
name: str
category: str
description: str | None = None
class PathRuleInfo(BaseModel):
id: int
role_id: int
path_pattern: str
is_regex: bool
can_read: bool
can_write: bool
can_delete: bool
can_share: bool
priority: int
created_at: datetime
class PathRuleCreate(BaseModel):
path_pattern: str
is_regex: bool = False
can_read: bool = True
can_write: bool = False
can_delete: bool = False
can_share: bool = False
priority: int = 0
class PathRuleUpdate(BaseModel):
path_pattern: str | None = None
is_regex: bool | None = None
can_read: bool | None = None
can_write: bool | None = None
can_delete: bool | None = None
can_share: bool | None = None
priority: int | None = None
class PathPermissionCheck(BaseModel):
path: str
action: str
class PathPermissionResult(BaseModel):
path: str
action: str
allowed: bool
matched_rule: PathRuleInfo | None = None
class UserPermissions(BaseModel):
user_id: int
is_admin: bool
permissions: list[str]  # system/adapter permission codes
path_rules: list[PathRuleInfo]  # path permission rules


@@ -1 +1,17 @@
"""
Foxel plugin system
Provides installation, management, and runtime loading of .foxpkg plugin packages.
"""
from .loader import PluginLoadError, PluginLoader
from .service import PluginService
from .startup import init_plugins, load_installed_plugins
__all__ = [
"PluginLoader",
"PluginLoadError",
"PluginService",
"init_plugins",
"load_installed_plugins",
]


@@ -1,76 +1,131 @@
from typing import List
"""
Plugin management API routes
"""
from fastapi import APIRouter, Body, Request
from typing import Annotated, List
from fastapi import APIRouter, Depends, File, Request, UploadFile
from fastapi.responses import FileResponse
from domain.audit import AuditAction, audit
from domain.plugins.service import PluginService
from domain.plugins.routes import video_player as video_player_routes
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from domain.auth import User, get_current_active_user
from domain.permission import require_system_permission
from domain.permission.types import SystemPermission
from .service import PluginService
from .types import (
PluginInstallResult,
PluginOut,
)
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
router.include_router(video_player_routes.router)
@router.post("", response_model=PluginOut)
@audit(
action=AuditAction.CREATE,
description="创建插件",
body_fields=["url", "enabled"],
)
async def create_plugin(request: Request, payload: PluginCreate):
return await PluginService.create(payload)
# ========== Install ==========
@router.post("/install", response_model=PluginInstallResult)
@audit(action=AuditAction.CREATE, description="安装插件包")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def install_plugin(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
file: UploadFile = File(...),
):
"""
Install a .foxpkg plugin package.
Upload a .foxpkg file to install it.
"""
content = await file.read()
return await PluginService.install_package(content, file.filename or "plugin.foxpkg")
# ========== Plugin list and details ==========
@router.get("", response_model=List[PluginOut])
@audit(action=AuditAction.READ, description="获取插件列表")
async def list_plugins(request: Request):
async def list_plugins(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
):
"""获取已安装的插件列表"""
return await PluginService.list_plugins()
@router.delete("/{plugin_id}")
@audit(action=AuditAction.DELETE, description="删除插件")
async def delete_plugin(request: Request, plugin_id: int):
await PluginService.delete(plugin_id)
@router.get("/{key_or_id}", response_model=PluginOut)
@audit(action=AuditAction.READ, description="获取插件详情")
async def get_plugin(
request: Request,
key_or_id: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
"""获取单个插件详情"""
return await PluginService.get_plugin(key_or_id)
# ========== Plugin management ==========
@router.delete("/{key_or_id}")
@audit(action=AuditAction.DELETE, description="卸载插件")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def delete_plugin(
request: Request,
key_or_id: str,
current_user: Annotated[User, Depends(get_current_active_user)],
):
"""卸载插件"""
await PluginService.delete(key_or_id)
return {"code": 0, "msg": "ok"}
@router.put("/{plugin_id}", response_model=PluginOut)
@audit(
action=AuditAction.UPDATE,
description="更新插件",
body_fields=["url", "enabled"],
)
async def update_plugin(request: Request, plugin_id: int, payload: PluginCreate):
return await PluginService.update(plugin_id, payload)
# ========== Plugin assets ==========
@router.post("/{plugin_id}/metadata", response_model=PluginOut)
@audit(
action=AuditAction.UPDATE,
description="更新插件 manifest",
body_fields=[
"key",
"name",
"version",
"open_app",
"supported_exts",
"default_bounds",
"default_maximized",
"icon",
"description",
"author",
"website",
"github",
],
)
async def update_manifest(
request: Request, plugin_id: int, manifest: PluginManifestUpdate = Body(...)
):
return await PluginService.update_manifest(plugin_id, manifest)
@router.get("/{key_or_id}/bundle.js")
async def get_bundle(request: Request, key_or_id: str):
"""获取插件前端 bundle"""
path = await PluginService.get_bundle_path(key_or_id)
v = (request.query_params.get("v") or "").strip()
cache_control = "public, max-age=31536000, immutable" if v else "no-cache"
return FileResponse(
path,
media_type="application/javascript",
headers={"Cache-Control": cache_control},
)
@router.get("/{plugin_id}/bundle.js")
async def get_bundle(request: Request, plugin_id: int):
path = await PluginService.get_bundle_path(plugin_id)
return FileResponse(path, media_type="application/javascript", headers={"Cache-Control": "no-store"})
@router.get("/{key}/assets/{asset_path:path}")
async def get_asset(request: Request, key: str, asset_path: str):
"""获取插件静态资源"""
path = await PluginService.get_asset_path(key, asset_path)
# Determine the MIME type from the file extension
ext = path.suffix.lower()
media_types = {
".js": "application/javascript",
".css": "text/css",
".json": "application/json",
".svg": "image/svg+xml",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
".ico": "image/x-icon",
".woff": "font/woff",
".woff2": "font/woff2",
".ttf": "font/ttf",
".eot": "application/vnd.ms-fontobject",
".html": "text/html",
".txt": "text/plain",
".md": "text/markdown",
}
media_type = media_types.get(ext, "application/octet-stream")
return FileResponse(
path,
media_type=media_type,
headers={"Cache-Control": "public, max-age=3600"},
)
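From a client's point of view, installation is a single multipart upload to the install endpoint above; the caller needs the ROLE_MANAGE permission. Base URL, file name, and token are placeholders.

import httpx

with open("hello.foxpkg", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/api/plugins/install",
        files={"file": ("hello.foxpkg", f, "application/zip")},  # field name matches the UploadFile parameter
        headers={"Authorization": "Bearer <token>"},
    )
print(resp.json())  # PluginInstallResult: success, plugin, message, errors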

domain/plugins/loader.py (new file, 449 lines)

@@ -0,0 +1,449 @@
"""
Plugin loader module
Responsible for:
1. Unpacking and validating .foxpkg packages
2. Deploying plugin files
3. Dynamically loading backend routes
4. Dynamically registering processors
"""
import io
import json
import shutil
import sys
import zipfile
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple
from fastapi import APIRouter
from .types import (
ManifestProcessorConfig,
ManifestRouteConfig,
PluginManifest,
)
class PluginLoadError(Exception):
"""插件加载错误"""
pass
class PluginLoader:
"""插件加载器"""
PLUGINS_ROOT = Path("data/plugins")
# Cache of loaded plugin modules
_loaded_modules: Dict[str, ModuleType] = {}
# Tracks mounted routers
_mounted_routers: Dict[str, List[APIRouter]] = {}
@classmethod
def get_plugin_dir(cls, plugin_key: str) -> Path:
"""获取插件目录"""
return cls.PLUGINS_ROOT / plugin_key
@classmethod
def get_manifest_path(cls, plugin_key: str) -> Path:
"""获取插件 manifest.json 路径"""
return cls.get_plugin_dir(plugin_key) / "manifest.json"
@classmethod
def get_frontend_bundle_path(cls, plugin_key: str, entry: Optional[str] = None) -> Path:
"""获取前端 bundle 路径"""
plugin_dir = cls.get_plugin_dir(plugin_key)
if entry:
return plugin_dir / entry
# Default location
return plugin_dir / "frontend" / "index.js"
@classmethod
def get_asset_path(cls, plugin_key: str, asset_path: str) -> Path:
"""获取静态资源路径"""
return cls.get_plugin_dir(plugin_key) / asset_path
# ========== Unpack and validate ==========
@classmethod
def validate_manifest(cls, manifest_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
"""验证 manifest 数据"""
errors: List[str] = []
# Required field checks
if not manifest_data.get("key"):
errors.append("manifest 缺少必需字段: key")
if not manifest_data.get("name"):
errors.append("manifest 缺少必需字段: name")
# key format check (Java-style namespace format)
key = manifest_data.get("key", "")
if key:
import re
# Format: com.example.plugin (at least two segments, each starting with a lowercase letter and containing only lowercase letters and digits)
if not re.match(r"^[a-z][a-z0-9]*(\.[a-z][a-z0-9]*)+$", key):
errors.append(
"key 格式无效:必须使用命名空间格式(如 com.example.plugin"
"每个部分以小写字母开头,只能包含小写字母和数字,至少两级"
)
# Version format check (shallow)
version = manifest_data.get("version", "")
if version and not isinstance(version, str):
errors.append("version 必须是字符串")
# Validate the frontend config
frontend = manifest_data.get("frontend")
if frontend and isinstance(frontend, dict):
if frontend.get("entry") and not isinstance(frontend["entry"], str):
errors.append("frontend.entry 必须是字符串")
if frontend.get("styles") is not None:
if not isinstance(frontend["styles"], list) or not all(
isinstance(x, str) for x in frontend["styles"]
):
errors.append("frontend.styles 必须是字符串数组")
supported_exts = frontend.get("supportedExts") or frontend.get("supported_exts")
if supported_exts and not isinstance(supported_exts, list):
errors.append("frontend.supportedExts 必须是数组")
use_system_window = frontend.get("useSystemWindow") or frontend.get("use_system_window")
if use_system_window is not None and not isinstance(use_system_window, bool):
errors.append("frontend.useSystemWindow 必须是布尔值")
# Validate the backend config
backend = manifest_data.get("backend")
if backend and isinstance(backend, dict):
routes = backend.get("routes", [])
if routes:
for i, route in enumerate(routes):
if not route.get("module"):
errors.append(f"backend.routes[{i}] 缺少 module")
if not route.get("prefix"):
errors.append(f"backend.routes[{i}] 缺少 prefix")
processors = backend.get("processors", [])
if processors:
for i, proc in enumerate(processors):
if not proc.get("module"):
errors.append(f"backend.processors[{i}] 缺少 module")
if not proc.get("type"):
errors.append(f"backend.processors[{i}] 缺少 type")
return len(errors) == 0, errors
@classmethod
def unpack_foxpkg(
cls, file_content: bytes, target_key: Optional[str] = None
) -> Tuple[PluginManifest, Path]:
"""
Unpack a .foxpkg file.
Args:
file_content: the .foxpkg file content
target_key: optional; install under this plugin key (overrides the key in the manifest)
Returns:
a (manifest, plugin_dir) tuple
Raises:
PluginLoadError: unpacking or validation failed
"""
try:
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
# Read manifest.json
try:
manifest_bytes = zf.read("manifest.json")
except KeyError:
raise PluginLoadError("插件包缺少 manifest.json")
try:
manifest_data = json.loads(manifest_bytes.decode("utf-8"))
except json.JSONDecodeError as e:
raise PluginLoadError(f"manifest.json 解析失败: {e}")
# Validate the manifest
valid, errors = cls.validate_manifest(manifest_data)
if not valid:
raise PluginLoadError(f"manifest 验证失败: {'; '.join(errors)}")
# Parse the manifest
try:
manifest = PluginManifest.model_validate(manifest_data)
except Exception as e:
raise PluginLoadError(f"manifest 解析失败: {e}")
# Determine the plugin key
plugin_key = target_key or manifest.key
# Validate the files inside the package
cls._validate_package_files(zf, manifest)
# Deploy the files
target_dir = cls.PLUGINS_ROOT / plugin_key
if target_dir.exists():
# Back up the old version
backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
if backup_dir.exists():
shutil.rmtree(backup_dir)
shutil.move(str(target_dir), str(backup_dir))
target_dir.mkdir(parents=True, exist_ok=True)
try:
zf.extractall(target_dir)
except Exception as e:
# Restore the backup
if (cls.PLUGINS_ROOT / f"{plugin_key}.backup").exists():
shutil.rmtree(target_dir, ignore_errors=True)
shutil.move(str(cls.PLUGINS_ROOT / f"{plugin_key}.backup"), str(target_dir))
raise PluginLoadError(f"文件解压失败: {e}")
# Clean up the backup
backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
if backup_dir.exists():
shutil.rmtree(backup_dir, ignore_errors=True)
return manifest, target_dir
except zipfile.BadZipFile:
raise PluginLoadError("无效的插件包格式(非 ZIP 文件)")
@classmethod
def _validate_package_files(cls, zf: zipfile.ZipFile, manifest: PluginManifest) -> None:
"""验证包内文件是否完整"""
file_list = zf.namelist()
# Check the frontend entry
if manifest.frontend and manifest.frontend.entry:
if manifest.frontend.entry not in file_list:
raise PluginLoadError(f"前端入口文件不存在: {manifest.frontend.entry}")
# Check the backend modules
if manifest.backend:
if manifest.backend.routes:
for route in manifest.backend.routes:
if route.module not in file_list:
raise PluginLoadError(f"路由模块不存在: {route.module}")
if manifest.backend.processors:
for proc in manifest.backend.processors:
if proc.module not in file_list:
raise PluginLoadError(f"处理器模块不存在: {proc.module}")
# ========== Dynamic route loading ==========
@classmethod
def load_route_module(cls, plugin_key: str, route_config: ManifestRouteConfig) -> APIRouter:
"""
Dynamically load a plugin route module.
Args:
plugin_key: the plugin identifier
route_config: the route configuration
Returns:
the loaded APIRouter
"""
module_path = cls.get_plugin_dir(plugin_key) / route_config.module
if not module_path.exists():
raise PluginLoadError(f"路由模块不存在: {module_path}")
module_name = f"foxel_plugin_{plugin_key}_route_{module_path.stem}"
try:
spec = spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise PluginLoadError(f"无法加载路由模块: {module_path}")
module = module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
# Cache the module
cls._loaded_modules[f"{plugin_key}:route:{route_config.module}"] = module
# Grab the router
router = getattr(module, "router", None)
if router is None:
raise PluginLoadError(f"路由模块缺少 'router' 对象: {module_path}")
if not isinstance(router, APIRouter):
raise PluginLoadError(f"'router' 不是有效的 APIRouter 实例: {module_path}")
# Create a wrapper router to add the prefix
wrapper = APIRouter(prefix=route_config.prefix, tags=route_config.tags or [])
wrapper.include_router(router)
return wrapper
except PluginLoadError:
raise
except Exception as e:
raise PluginLoadError(f"加载路由模块失败 [{module_path}]: {e}")
@classmethod
def load_all_routes(cls, plugin_key: str, manifest: PluginManifest) -> List[APIRouter]:
"""加载插件的所有路由"""
routers: List[APIRouter] = []
if not manifest.backend or not manifest.backend.routes:
return routers
for route_config in manifest.backend.routes:
router = cls.load_route_module(plugin_key, route_config)
routers.append(router)
cls._mounted_routers[plugin_key] = routers
return routers
# ========== Dynamic processor registration ==========
@classmethod
def load_processor_module(
cls, plugin_key: str, processor_config: ManifestProcessorConfig
) -> None:
"""
Dynamically load and register a processor module.
Args:
plugin_key: the plugin identifier
processor_config: the processor configuration
"""
module_path = cls.get_plugin_dir(plugin_key) / processor_config.module
if not module_path.exists():
raise PluginLoadError(f"处理器模块不存在: {module_path}")
module_name = f"foxel_plugin_{plugin_key}_processor_{module_path.stem}"
try:
spec = spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise PluginLoadError(f"无法加载处理器模块: {module_path}")
module = module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
# Cache the module
cls._loaded_modules[f"{plugin_key}:processor:{processor_config.module}"] = module
# Grab the processor factory
factory = getattr(module, "PROCESSOR_FACTORY", None)
if factory is None:
raise PluginLoadError(f"处理器模块缺少 'PROCESSOR_FACTORY': {module_path}")
# Grab the config schema
config_schema = getattr(module, "CONFIG_SCHEMA", [])
processor_name = getattr(module, "PROCESSOR_NAME", processor_config.name or processor_config.type)
supported_exts = getattr(module, "SUPPORTED_EXTS", [])
# Register into the processor registry
from domain.processors import CONFIG_SCHEMAS, TYPE_MAP
processor_type = processor_config.type
TYPE_MAP[processor_type] = factory
# Instantiate once to read its attributes
try:
sample = factory()
produces_file = getattr(sample, "produces_file", False)
supports_directory = getattr(sample, "supports_directory", False)
except Exception:
produces_file = False
supports_directory = False
CONFIG_SCHEMAS[processor_type] = {
"type": processor_type,
"name": processor_name,
"supported_exts": supported_exts,
"config_schema": config_schema,
"produces_file": produces_file,
"supports_directory": supports_directory,
"plugin": plugin_key, # 标记来源插件
"module_path": str(module_path),
}
except PluginLoadError:
raise
except Exception as e:
raise PluginLoadError(f"加载处理器模块失败 [{module_path}]: {e}")
@classmethod
def load_all_processors(cls, plugin_key: str, manifest: PluginManifest) -> List[str]:
"""加载插件的所有处理器,返回处理器类型列表"""
processor_types: List[str] = []
if not manifest.backend or not manifest.backend.processors:
return processor_types
for proc_config in manifest.backend.processors:
cls.load_processor_module(plugin_key, proc_config)
processor_types.append(proc_config.type)
return processor_types
# ========== Unload ==========
@classmethod
def unload_plugin(cls, plugin_key: str, manifest: Optional[PluginManifest] = None) -> None:
"""
Unload a plugin's backend components.
Args:
plugin_key: the plugin identifier
manifest: optional manifest, used to determine which components to unload
"""
# Unload processors
if manifest and manifest.backend and manifest.backend.processors:
from domain.processors import CONFIG_SCHEMAS, TYPE_MAP
for proc_config in manifest.backend.processors:
proc_type = proc_config.type
if proc_type in TYPE_MAP:
del TYPE_MAP[proc_type]
if proc_type in CONFIG_SCHEMAS:
del CONFIG_SCHEMAS[proc_type]
# Clean up cached modules
keys_to_remove = [k for k in cls._loaded_modules if k.startswith(f"{plugin_key}:")]
for key in keys_to_remove:
module = cls._loaded_modules.pop(key, None)
if module and module.__name__ in sys.modules:
del sys.modules[module.__name__]
# Clean up route tracking (note: FastAPI cannot remove routes dynamically; an app restart is required)
cls._mounted_routers.pop(plugin_key, None)
@classmethod
def delete_plugin_files(cls, plugin_key: str) -> None:
"""删除插件文件"""
plugin_dir = cls.get_plugin_dir(plugin_key)
if plugin_dir.exists():
shutil.rmtree(plugin_dir)
# Also delete any backup
backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
if backup_dir.exists():
shutil.rmtree(backup_dir)
# ========== Read manifest ==========
@classmethod
def read_manifest(cls, plugin_key: str) -> Optional[PluginManifest]:
"""从文件系统读取插件 manifest"""
manifest_path = cls.get_manifest_path(plugin_key)
if not manifest_path.exists():
return None
try:
with open(manifest_path, "r", encoding="utf-8") as f:
data = json.load(f)
return PluginManifest.model_validate(data)
except Exception:
return None
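A minimal sketch of building a valid package in memory and pushing it through the loader above; the key, names, and bundle content are illustrative, and PLUGINS_ROOT is pointed at a temp directory so nothing lands under data/plugins.

import io
import json
import tempfile
import zipfile
from pathlib import Path

from domain.plugins.loader import PluginLoader

manifest = {
    "key": "com.example.hello",  # namespaced, lowercase: passes the key regex
    "name": "Hello",
    "version": "1.0.0",
    "frontend": {"entry": "frontend/index.js"},
}
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("manifest.json", json.dumps(manifest))
    zf.writestr("frontend/index.js", "export default {};")

ok, errors = PluginLoader.validate_manifest(manifest)
assert ok, errors

PluginLoader.PLUGINS_ROOT = Path(tempfile.mkdtemp())  # keep the demo self-contained
parsed, plugin_dir = PluginLoader.unpack_foxpkg(buf.getvalue())
assert parsed.key == "com.example.hello"
assert (plugin_dir / "frontend" / "index.js").exists()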


@@ -1,2 +0,0 @@
"""插件专属服务端路由集合。"""


@@ -1,142 +0,0 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from api.response import success
from domain.auth.service import get_current_active_user
router = APIRouter(
prefix="/video-player",
tags=["plugins"],
dependencies=[Depends(get_current_active_user)],
)
DATA_ROOT = Path("data/.video")
def _read_json(path: Path) -> Dict[str, Any]:
return json.loads(path.read_text(encoding="utf-8"))
def _file_mtime_iso(path: Path) -> str:
try:
ts = path.stat().st_mtime
except FileNotFoundError:
return ""
return datetime.fromtimestamp(ts, tz=UTC).isoformat()
def _extract_title(payload: Dict[str, Any]) -> str:
detail = (payload.get("tmdb") or {}).get("detail") or {}
if payload.get("type") == "tv":
return str(detail.get("name") or detail.get("original_name") or "")
return str(detail.get("title") or detail.get("original_title") or "")
def _extract_year(payload: Dict[str, Any]) -> Optional[str]:
detail = (payload.get("tmdb") or {}).get("detail") or {}
value = detail.get("first_air_date") if payload.get("type") == "tv" else detail.get("release_date")
if not value or not isinstance(value, str):
return None
return value[:4] if len(value) >= 4 else value
def _extract_genres(payload: Dict[str, Any]) -> List[str]:
detail = (payload.get("tmdb") or {}).get("detail") or {}
genres = detail.get("genres") or []
out: List[str] = []
if isinstance(genres, list):
for g in genres:
if isinstance(g, dict) and g.get("name"):
out.append(str(g["name"]))
return out
def _summarize(item_id: str, payload: Dict[str, Any], mtime_iso: str) -> Dict[str, Any]:
detail = (payload.get("tmdb") or {}).get("detail") or {}
media_type = payload.get("type") or "unknown"
episodes = payload.get("episodes") or []
seasons = {e.get("season") for e in episodes if isinstance(e, dict) and e.get("season") is not None}
return {
"id": item_id,
"type": media_type,
"title": _extract_title(payload),
"year": _extract_year(payload),
"overview": detail.get("overview"),
"poster_path": detail.get("poster_path"),
"backdrop_path": detail.get("backdrop_path"),
"genres": _extract_genres(payload),
"tmdb_id": (payload.get("tmdb") or {}).get("id"),
"source_path": payload.get("source_path"),
"scraped_at": payload.get("scraped_at"),
"updated_at": mtime_iso,
"episodes_count": len(episodes) if isinstance(episodes, list) else 0,
"seasons_count": len(seasons),
"vote_average": detail.get("vote_average"),
"vote_count": detail.get("vote_count"),
}
def _iter_library_files() -> List[tuple[str, Path]]:
files: List[tuple[str, Path]] = []
for sub in ("tv", "movie"):
folder = DATA_ROOT / sub
if not folder.exists():
continue
for p in folder.glob("*.json"):
if not p.is_file():
continue
files.append((sub, p))
return files
@router.get("/library")
async def list_library(
q: str | None = Query(None, description="搜索关键字(标题/简介)"),
media_type: str | None = Query(None, alias="type", description="tv 或 movie"),
):
items: List[Dict[str, Any]] = []
keyword = (q or "").strip().lower()
type_filter = (media_type or "").strip().lower()
if type_filter and type_filter not in {"tv", "movie"}:
raise HTTPException(status_code=400, detail="type must be tv or movie")
for _sub, path in _iter_library_files():
item_id = path.stem
try:
payload = _read_json(path)
except Exception:
continue
if type_filter and str(payload.get("type") or "").lower() != type_filter:
continue
summary = _summarize(item_id, payload, _file_mtime_iso(path))
if keyword:
haystack = f"{summary.get('title') or ''} {summary.get('overview') or ''}".lower()
if keyword not in haystack:
continue
items.append(summary)
items.sort(key=lambda x: x.get("updated_at") or "", reverse=True)
return success(items)
@router.get("/library/{item_id}")
async def get_library_item(item_id: str):
candidates = [
DATA_ROOT / "tv" / f"{item_id}.json",
DATA_ROOT / "movie" / f"{item_id}.json",
]
path = next((p for p in candidates if p.exists()), None)
if not path:
raise HTTPException(status_code=404, detail="Item not found")
payload = _read_json(path)
payload["id"] = item_id
payload["updated_at"] = _file_mtime_iso(path)
return success(payload)


@@ -1,138 +1,273 @@
"""
Plugin service module
Handles installation, uninstallation, and other plugin management operations
"""
import contextlib
import re
import logging
import shutil
from pathlib import Path
from typing import List, Optional, Union
import aiofiles
import httpx
from fastapi import HTTPException
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from .loader import PluginLoadError, PluginLoader
from .types import (
PluginInstallResult,
PluginManifest,
PluginOut,
)
from models.database import Plugin
logger = logging.getLogger(__name__)
class PluginService:
"""插件服务"""
_plugins_root = Path("data/plugins")
@classmethod
def _folder_name(cls, rec: Plugin) -> str:
if rec.key:
safe = re.sub(r"[^A-Za-z0-9_.-]", "_", rec.key)
return safe or str(rec.id)
return str(rec.id)
# ========== Utilities ==========
@classmethod
def _bundle_dir_from_rec(cls, rec: Plugin) -> Path:
return cls._plugins_root / cls._folder_name(rec) / "current"
def _get_plugin_dir(cls, plugin_key: str) -> Path:
"""获取插件目录"""
return cls._plugins_root / plugin_key
@classmethod
def _bundle_path_from_rec(cls, rec: Plugin) -> Path:
return cls._bundle_dir_from_rec(rec) / "index.js"
def _get_bundle_path(cls, rec: Plugin) -> Path:
"""获取前端 bundle 路径"""
plugin_dir = cls._get_plugin_dir(rec.key)
# Read from the manifest
if rec.manifest:
frontend = rec.manifest.get("frontend", {})
entry = frontend.get("entry")
if entry:
return plugin_dir / entry
# Default location
return plugin_dir / "frontend" / "index.js"
@classmethod
async def _download_bundle(cls, rec: Plugin, url: str) -> None:
dest_dir = cls._bundle_dir_from_rec(rec)
dest_dir.mkdir(parents=True, exist_ok=True)
dest_path = cls._bundle_path_from_rec(rec)
tmp_path = dest_path.with_suffix(".tmp")
try:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
async with client.stream("GET", url) as resp:
resp.raise_for_status()
async with aiofiles.open(tmp_path, "wb") as f:
async for chunk in resp.aiter_bytes(chunk_size=65536):
if not chunk:
continue
await f.write(chunk)
tmp_path.replace(dest_path)
except Exception:
with contextlib.suppress(Exception):
if tmp_path.exists():
tmp_path.unlink()
raise
@classmethod
async def _ensure_bundle(cls, plugin_id: int) -> Path:
rec = await cls._get_or_404(plugin_id)
bundle_path = cls._bundle_path_from_rec(rec)
if bundle_path.exists():
return bundle_path
legacy = cls._plugins_root / str(rec.id) / "current" / "index.js"
if legacy.exists():
return legacy
raise HTTPException(status_code=404, detail="Plugin bundle not found")
@classmethod
async def get_bundle_path(cls, plugin_id: int) -> Path:
return await cls._ensure_bundle(plugin_id)
@classmethod
async def create(cls, payload: PluginCreate) -> PluginOut:
rec = await Plugin.create(**payload.model_dump())
try:
await cls._download_bundle(rec, rec.url)
except Exception as exc:
with contextlib.suppress(Exception):
await rec.delete()
raise HTTPException(status_code=400, detail=f"Failed to fetch plugin: {exc}")
return PluginOut.model_validate(rec)
@classmethod
async def list_plugins(cls) -> list[PluginOut]:
rows = await Plugin.all().order_by("-id")
return [PluginOut.model_validate(r) for r in rows]
@classmethod
async def _get_or_404(cls, plugin_id: int) -> Plugin:
rec = await Plugin.get_or_none(id=plugin_id)
async def _get_by_key_or_404(cls, key: str) -> Plugin:
"""通过 key 获取插件,不存在则返回 404"""
rec = await Plugin.get_or_none(key=key)
if not rec:
raise HTTPException(status_code=404, detail="Plugin not found")
return rec
@classmethod
async def delete(cls, plugin_id: int) -> None:
rec = await cls._get_or_404(plugin_id)
await rec.delete()
with contextlib.suppress(Exception):
dirs = {cls._bundle_dir_from_rec(rec).parent, cls._plugins_root / str(rec.id)}
for plugin_dir in dirs:
if plugin_dir.exists():
shutil.rmtree(plugin_dir)
async def _get_by_key_or_id(cls, key_or_id: Union[str, int]) -> Plugin:
"""通过 key 或 ID 获取插件"""
# Try as an ID first
if isinstance(key_or_id, int) or (isinstance(key_or_id, str) and key_or_id.isdigit()):
plugin_id = int(key_or_id)
rec = await Plugin.get_or_none(id=plugin_id)
if rec:
return rec
# Then try as a key
if isinstance(key_or_id, str):
rec = await Plugin.get_or_none(key=key_or_id)
if rec:
return rec
raise HTTPException(status_code=404, detail="Plugin not found")
# ========== Install ==========
@classmethod
async def update(cls, plugin_id: int, payload: PluginCreate) -> PluginOut:
rec = await cls._get_or_404(plugin_id)
url_changed = rec.url != payload.url
if url_changed:
try:
await cls._download_bundle(rec, payload.url)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Failed to fetch plugin: {exc}")
rec.url = payload.url
rec.enabled = payload.enabled
await rec.save()
return PluginOut.model_validate(rec)
async def install_package(cls, file_content: bytes, filename: str) -> PluginInstallResult:
"""
Install a .foxpkg plugin package.
Args:
file_content: the package content
filename: the file name
Returns:
the installation result
"""
errors: List[str] = []
try:
# Unpack
manifest, plugin_dir = PluginLoader.unpack_foxpkg(file_content)
plugin_key = manifest.key
# Check whether it already exists
existing = await Plugin.get_or_none(key=plugin_key)
if existing:
# Update the existing plugin
logger.info(f"更新插件: {plugin_key}")
rec = existing
else:
# Create a new plugin
logger.info(f"安装新插件: {plugin_key}")
rec = Plugin(key=plugin_key)
# Update fields
rec.name = manifest.name
rec.version = manifest.version
rec.description = manifest.description
rec.author = manifest.author
rec.website = manifest.website
rec.github = manifest.github
rec.license = manifest.license
rec.manifest = manifest.model_dump(mode="json")
# Extract frontend config from manifest.frontend
if manifest.frontend:
rec.open_app = manifest.frontend.open_app or False
rec.supported_exts = manifest.frontend.supported_exts
rec.default_bounds = manifest.frontend.default_bounds
rec.default_maximized = manifest.frontend.default_maximized
rec.icon = manifest.frontend.icon
@classmethod
async def update_manifest(
cls, plugin_id: int, manifest: PluginManifestUpdate
) -> PluginOut:
rec = await cls._get_or_404(plugin_id)
old_dir = cls._bundle_dir_from_rec(rec).parent
updates = manifest.model_dump(exclude_none=True)
if updates:
for key, value in updates.items():
setattr(rec, key, value)
await rec.save()
new_dir = cls._bundle_dir_from_rec(rec).parent
if rec.key and new_dir != old_dir:
candidate_dir = old_dir if old_dir.exists() else (cls._plugins_root / str(rec.id))
if candidate_dir.exists():
new_dir.parent.mkdir(parents=True, exist_ok=True)
with contextlib.suppress(Exception):
if new_dir.exists():
shutil.rmtree(new_dir)
shutil.move(str(candidate_dir), str(new_dir))
# Load backend components (if any)
loaded_routes: List[str] = []
loaded_processors: List[str] = []
if manifest.backend:
# Load routes
if manifest.backend.routes:
try:
from main import app
routers = PluginLoader.load_all_routes(plugin_key, manifest)
for router in routers:
app.include_router(router)
loaded_routes.append(router.prefix)
except PluginLoadError as e:
errors.append(f"路由加载失败: {e}")
logger.error(f"插件 {plugin_key} 路由加载失败: {e}")
except Exception as e:
errors.append(f"路由加载失败: {e}")
logger.exception(f"插件 {plugin_key} 路由加载异常")
# Load processors
if manifest.backend.processors:
try:
processor_types = PluginLoader.load_all_processors(plugin_key, manifest)
loaded_processors = processor_types
except PluginLoadError as e:
errors.append(f"处理器加载失败: {e}")
logger.error(f"插件 {plugin_key} 处理器加载失败: {e}")
except Exception as e:
errors.append(f"处理器加载失败: {e}")
logger.exception(f"插件 {plugin_key} 处理器加载异常")
# Update the load state
rec.loaded_routes = loaded_routes if loaded_routes else None
rec.loaded_processors = loaded_processors if loaded_processors else None
await rec.save()
return PluginInstallResult(
success=True,
plugin=PluginOut.model_validate(rec),
message="安装成功" if not errors else "安装完成,但有部分组件加载失败",
errors=errors if errors else None,
)
except PluginLoadError as e:
logger.error(f"插件安装失败: {e}")
return PluginInstallResult(
success=False,
message=str(e),
errors=[str(e)],
)
except Exception as e:
logger.exception("插件安装异常")
return PluginInstallResult(
success=False,
message=f"安装失败: {e}",
errors=[str(e)],
)
# ========== Queries ==========
@classmethod
async def list_plugins(cls) -> List[PluginOut]:
"""获取所有插件列表"""
rows = await Plugin.all().order_by("-id")
for rec in rows:
try:
manifest = PluginLoader.read_manifest(rec.key)
if manifest:
rec.manifest = manifest.model_dump(mode="json")
except Exception:
continue
return [PluginOut.model_validate(r) for r in rows]
@classmethod
async def get_plugin(cls, key_or_id: Union[str, int]) -> PluginOut:
"""获取单个插件详情"""
rec = await cls._get_by_key_or_id(key_or_id)
try:
manifest = PluginLoader.read_manifest(rec.key)
if manifest:
rec.manifest = manifest.model_dump(mode="json")
except Exception:
pass
return PluginOut.model_validate(rec)
@classmethod
async def get_bundle_path(cls, key_or_id: Union[str, int]) -> Path:
"""获取插件前端 bundle 路径"""
rec = await cls._get_by_key_or_id(key_or_id)
bundle_path = cls._get_bundle_path(rec)
if not bundle_path.exists():
raise HTTPException(status_code=404, detail="Plugin bundle not found")
return bundle_path
@classmethod
async def get_asset_path(cls, key: str, asset_path: str) -> Path:
"""获取插件静态资源路径"""
rec = await cls._get_by_key_or_404(key)
plugin_dir = cls._get_plugin_dir(rec.key)
# Security check: prevent path traversal
asset_path = asset_path.lstrip("/")
if ".." in asset_path:
raise HTTPException(status_code=400, detail="Invalid asset path")
full_path = plugin_dir / asset_path
if not full_path.exists():
raise HTTPException(status_code=404, detail="Asset not found")
# Make sure the path stays inside the plugin directory
try:
full_path.resolve().relative_to(plugin_dir.resolve())
except ValueError:
raise HTTPException(status_code=400, detail="Invalid asset path")
return full_path
# ========== Management operations ==========
@classmethod
async def delete(cls, key_or_id: Union[str, int]) -> None:
"""删除/卸载插件"""
rec = await cls._get_by_key_or_id(key_or_id)
# Load the manifest so backend components can be unloaded
manifest: Optional[PluginManifest] = None
if rec.manifest:
try:
manifest = PluginManifest.model_validate(rec.manifest)
except Exception:
pass
# Unload backend components
if manifest:
PluginLoader.unload_plugin(rec.key, manifest)
# Delete the database record
await rec.delete()
# Delete the files
with contextlib.suppress(Exception):
plugin_dir = cls._get_plugin_dir(rec.key)
if plugin_dir.exists():
shutil.rmtree(plugin_dir)
logger.info(f"插件 {rec.key} 已卸载")

domain/plugins/startup.py (new file, 115 lines)

@@ -0,0 +1,115 @@
"""
Plugin startup loading module
Loads all installed plugins when the application starts
"""
import logging
from typing import TYPE_CHECKING, List, Tuple
from .loader import PluginLoadError, PluginLoader
from .types import PluginManifest
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
async def load_installed_plugins(app: "FastAPI") -> Tuple[int, List[str]]:
"""
Load all installed plugins.
Args:
app: the FastAPI application instance
Returns:
(number of plugins loaded, list of errors)
"""
from models.database import Plugin
errors: List[str] = []
loaded_count = 0
try:
plugins = await Plugin.all()
except Exception as e:
logger.error(f"查询插件列表失败: {e}")
return 0, [f"查询插件列表失败: {e}"]
for plugin in plugins:
if not plugin.key:
continue
try:
# Fetch the manifest
manifest = None
if plugin.manifest:
try:
manifest = PluginManifest.model_validate(plugin.manifest)
except Exception:
# Fall back to reading from the filesystem
manifest = PluginLoader.read_manifest(plugin.key)
else:
manifest = PluginLoader.read_manifest(plugin.key)
if not manifest:
logger.warning(f"插件 {plugin.key} 缺少 manifest跳过加载")
continue
# Load backend routes
loaded_routes: List[str] = []
if manifest.backend and manifest.backend.routes:
try:
routers = PluginLoader.load_all_routes(plugin.key, manifest)
for router in routers:
app.include_router(router)
loaded_routes.append(router.prefix)
logger.info(f"插件 {plugin.key} 加载了 {len(routers)} 个路由")
except PluginLoadError as e:
errors.append(f"插件 {plugin.key} 路由加载失败: {e}")
logger.error(f"插件 {plugin.key} 路由加载失败: {e}")
# Load processors
loaded_processors: List[str] = []
if manifest.backend and manifest.backend.processors:
try:
processor_types = PluginLoader.load_all_processors(plugin.key, manifest)
loaded_processors = processor_types
logger.info(f"插件 {plugin.key} 注册了 {len(processor_types)} 个处理器")
except PluginLoadError as e:
errors.append(f"插件 {plugin.key} 处理器加载失败: {e}")
logger.error(f"插件 {plugin.key} 处理器加载失败: {e}")
# Update the database record
plugin.loaded_routes = loaded_routes if loaded_routes else None
plugin.loaded_processors = loaded_processors if loaded_processors else None
await plugin.save()
loaded_count += 1
logger.info(f"插件 {plugin.key} 加载完成")
except Exception as e:
error_msg = f"插件 {plugin.key} 加载异常: {e}"
errors.append(error_msg)
logger.exception(error_msg)
return loaded_count, errors
async def init_plugins(app: "FastAPI") -> None:
"""
Initialize the plugin system.
Called on application startup.
"""
logger.info("开始加载已安装插件...")
loaded_count, errors = await load_installed_plugins(app)
if errors:
logger.warning(f"插件加载完成,共 {loaded_count} 个成功,{len(errors)} 个错误")
for error in errors:
logger.warning(f" - {error}")
else:
logger.info(f"插件加载完成,共 {loaded_count} 个插件")


@@ -1,48 +1,119 @@
from typing import Any, Dict, List, Optional
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field
class PluginCreate(BaseModel):
url: str = Field(min_length=1)
enabled: bool = True
# ========== Manifest types ==========
class PluginManifestUpdate(BaseModel):
class ManifestFrontend(BaseModel):
"""manifest.json 中的 frontend 配置"""
model_config = ConfigDict(populate_by_name=True, extra="ignore")
key: Optional[str] = None
name: Optional[str] = None
version: Optional[str] = None
entry: Optional[str] = Field(default=None, description="前端入口文件路径")
styles: Optional[List[str]] = Field(default=None, description="前端样式文件路径列表(相对插件根目录)")
open_app: Optional[bool] = Field(
default=None,
validation_alias=AliasChoices("open_app", "openApp"),
alias="openApp",
description="是否支持独立打开",
)
supported_exts: Optional[List[str]] = Field(
default=None,
validation_alias=AliasChoices("supported_exts", "supportedExts"),
alias="supportedExts",
description="支持的文件扩展名列表",
)
default_bounds: Optional[Dict[str, Any]] = Field(
default=None,
validation_alias=AliasChoices("default_bounds", "defaultBounds"),
alias="defaultBounds",
description="默认窗口尺寸",
)
default_maximized: Optional[bool] = Field(
default=None,
validation_alias=AliasChoices("default_maximized", "defaultMaximized"),
alias="defaultMaximized",
description="是否默认最大化",
)
icon: Optional[str] = None
description: Optional[str] = None
author: Optional[str] = None
website: Optional[str] = None
github: Optional[str] = None
icon: Optional[str] = Field(default=None, description="图标路径")
use_system_window: Optional[bool] = Field(
default=None,
alias="useSystemWindow",
description="是否使用系统窗口",
)
class ManifestRouteConfig(BaseModel):
"""manifest.json 中的路由配置"""
model_config = ConfigDict(extra="ignore")
module: str = Field(..., description="路由模块路径")
prefix: str = Field(..., description="路由前缀")
tags: Optional[List[str]] = Field(default=None, description="API 标签")
class ManifestProcessorConfig(BaseModel):
"""manifest.json 中的处理器配置"""
model_config = ConfigDict(extra="ignore")
module: str = Field(..., description="处理器模块路径")
type: str = Field(..., description="处理器类型标识")
name: Optional[str] = Field(default=None, description="处理器显示名称")
class ManifestBackend(BaseModel):
"""manifest.json 中的 backend 配置"""
model_config = ConfigDict(extra="ignore")
routes: Optional[List[ManifestRouteConfig]] = Field(default=None, description="路由列表")
processors: Optional[List[ManifestProcessorConfig]] = Field(
default=None, description="处理器列表"
)
class ManifestDependencies(BaseModel):
"""manifest.json 中的依赖配置"""
model_config = ConfigDict(extra="ignore")
python: Optional[str] = Field(default=None, description="Python 版本要求")
packages: Optional[List[str]] = Field(default=None, description="Python 包依赖列表")
class PluginManifest(BaseModel):
"""完整的 manifest.json 结构"""
model_config = ConfigDict(populate_by_name=True, extra="ignore")
foxpkg: str = Field(default="1.0", description="foxpkg 格式版本")
key: str = Field(..., min_length=1, description="插件唯一标识")
name: str = Field(..., min_length=1, description="插件名称")
version: str = Field(default="1.0.0", description="插件版本")
description: Optional[str] = Field(default=None, description="插件描述")
i18n: Optional[Dict[str, Dict[str, str]]] = Field(
default=None,
description="多语言信息name/description例如{'en': {'name': '...', 'description': '...'}}",
)
author: Optional[str] = Field(default=None, description="作者")
website: Optional[str] = Field(default=None, description="网站")
github: Optional[str] = Field(default=None, description="GitHub 地址")
license: Optional[str] = Field(default=None, description="许可证")
frontend: Optional[ManifestFrontend] = Field(default=None, description="前端配置")
backend: Optional[ManifestBackend] = Field(default=None, description="后端配置")
dependencies: Optional[ManifestDependencies] = Field(default=None, description="依赖配置")
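Since ManifestFrontend sets populate_by_name alongside camelCase aliases, either key style validates; a quick sketch of what that looks like in practice (the manifest payload itself is made up):

    # Illustrative only: validating a minimal manifest against the models above.
    manifest = PluginManifest.model_validate(
        {
            "key": "hello",
            "name": "Hello Plugin",
            "frontend": {
                "entry": "dist/index.js",
                "openApp": True,              # camelCase accepted via AliasChoices
                "supportedExts": ["txt"],
            },
        }
    )
    assert manifest.frontend is not None and manifest.frontend.open_app is True
    # Dumping by alias round-trips back to the camelCase manifest keys.
    print(manifest.frontend.model_dump(by_alias=True, exclude_none=True))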
# ========== API request/response types ==========
class PluginOut(BaseModel):
"""Plugin output model"""
id: int
url: str
enabled: bool
key: str
open_app: bool = False
key: Optional[str] = None
name: Optional[str] = None
version: Optional[str] = None
supported_exts: Optional[List[str]] = None
@@ -53,5 +124,20 @@ class PluginOut(BaseModel):
author: Optional[str] = None
website: Optional[str] = None
github: Optional[str] = None
license: Optional[str] = None
# Newly added fields
manifest: Optional[Dict[str, Any]] = None
loaded_routes: Optional[List[str]] = None
loaded_processors: Optional[List[str]] = None
model_config = ConfigDict(from_attributes=True)
class PluginInstallResult(BaseModel):
"""安装结果"""
success: bool
plugin: Optional[PluginOut] = None
message: Optional[str] = None
errors: Optional[List[str]] = None


@@ -0,0 +1,35 @@
from .base import BaseProcessor
from .registry import (
CONFIG_SCHEMAS,
TYPE_MAP,
get_config_schema,
get_config_schemas,
get_last_discovery_errors,
get_module_path,
reload_processors,
)
from .service import (
ProcessorService,
get_processor,
list_processors,
reload_processor_modules,
)
from .types import ProcessDirectoryRequest, ProcessRequest, UpdateSourceRequest
__all__ = [
"BaseProcessor",
"CONFIG_SCHEMAS",
"TYPE_MAP",
"get_config_schema",
"get_config_schemas",
"get_last_discovery_errors",
"get_module_path",
"reload_processors",
"ProcessorService",
"get_processor",
"list_processors",
"reload_processor_modules",
"ProcessDirectoryRequest",
"ProcessRequest",
"UpdateSourceRequest",
]


@@ -4,10 +4,15 @@ from fastapi import APIRouter, Body, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.processors.service import ProcessorService
from domain.processors.types import (
from domain.auth import User, get_current_active_user
from domain.permission import require_path_permission
from domain.permission import require_system_permission
from domain.permission.service import PermissionService
from domain.permission.types import PathAction
from domain.permission.types import SystemPermission
from domain.processors.registry import get_config_schema
from .service import ProcessorService
from .types import (
ProcessDirectoryRequest,
ProcessRequest,
UpdateSourceRequest,
@@ -32,11 +37,18 @@ async def list_processors(
description="处理单个文件",
body_fields=["path", "processor_type", "save_to", "overwrite"],
)
@require_path_permission(PathAction.READ, "req.path")
async def process_file_with_processor(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessRequest = Body(...),
):
meta = get_config_schema(req.processor_type) or {}
if meta.get("produces_file"):
if req.overwrite:
await PermissionService.require_path_permission(current_user.id, req.path, PathAction.WRITE)
elif req.save_to:
await PermissionService.require_path_permission(current_user.id, req.save_to, PathAction.WRITE)
data = await ProcessorService.process_file(req)
return success(data)
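How require_path_permission resolves specs like "req.path" is not shown in this diff; one plausible shape, offered purely as an assumption, is a decorator that walks the handler's keyword arguments by dotted path before delegating to PermissionService (list-valued specs such as "payload.paths" would need an extra loop, omitted here):

    # Assumed sketch; the real decorator in domain.permission may differ.
    import functools


    def require_path_permission(action, spec: str):
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                # "req.path" -> kwargs["req"].path; "full_path" -> kwargs["full_path"]
                head, *attrs = spec.split(".")
                value = kwargs[head]
                for attr in attrs:
                    value = getattr(value, attr)
                user = kwargs["current_user"]
                await PermissionService.require_path_permission(user.id, value, action)
                return await func(*args, **kwargs)
            return wrapper
        return decorator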
@@ -47,17 +59,22 @@ async def process_file_with_processor(
description="批量处理目录",
body_fields=["path", "processor_type", "overwrite", "max_depth", "suffix"],
)
@require_path_permission(PathAction.READ, "req.path")
async def process_directory_with_processor(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
req: ProcessDirectoryRequest = Body(...),
):
meta = get_config_schema(req.processor_type) or {}
if meta.get("produces_file"):
await PermissionService.require_path_permission(current_user.id, req.path, PathAction.WRITE)
data = await ProcessorService.process_directory(req)
return success(data)
@router.get("/source/{processor_type}")
@audit(action=AuditAction.READ, description="获取处理器源码")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def get_processor_source(
request: Request,
processor_type: str,
@@ -69,6 +86,7 @@ async def get_processor_source(
@router.put("/source/{processor_type}")
@audit(action=AuditAction.UPDATE, description="更新处理器源码")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def update_processor_source(
request: Request,
processor_type: str,
@@ -81,6 +99,7 @@ async def update_processor_source(
@router.post("/reload")
@audit(action=AuditAction.UPDATE, description="重载处理器模块")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def reload_processor_modules(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],


@@ -8,12 +8,14 @@ from fastapi.responses import Response
from PIL import Image
from ..base import BaseProcessor
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
from domain.ai.service import (
VectorDBService,
from domain.ai import (
DEFAULT_VECTOR_DIMENSION,
VECTOR_COLLECTION_NAME,
FILE_COLLECTION_NAME,
VECTOR_COLLECTION_NAME,
VectorDBService,
describe_image_base64,
get_text_embedding,
provider_service,
)
@@ -112,8 +114,15 @@ class VectorIndexProcessor:
}
]
produces_file = False
requires_input_bytes = False
async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
async def ensure_input_bytes() -> bytes:
if input_bytes:
return input_bytes
from domain.virtual_fs import VirtualFSService
return await VirtualFSService.read_file(path)
action = config.get("action", "create")
index_type = config.get("index_type", "vector")
vector_db = VectorDBService()
@@ -157,7 +166,8 @@ class VectorIndexProcessor:
await vector_db.delete_vector(vector_collection, path)
if file_ext in ["jpg", "jpeg", "png", "bmp"]:
processed_bytes, compression = _compress_image_for_embedding(input_bytes)
file_bytes = await ensure_input_bytes()
processed_bytes, compression = _compress_image_for_embedding(file_bytes)
base64_image = base64.b64encode(processed_bytes).decode("utf-8")
description = await describe_image_base64(base64_image)
embedding = await get_text_embedding(description)
@@ -178,7 +188,8 @@ class VectorIndexProcessor:
if file_ext in ["txt", "md"]:
try:
text = input_bytes.decode("utf-8")
file_bytes = await ensure_input_bytes()
text = file_bytes.decode("utf-8")
except UnicodeDecodeError:
return Response(content="文本文件解码失败", status_code=400)


@@ -1,396 +0,0 @@
import hashlib
import json
import os
import re
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import httpx
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.thumbnail import VIDEO_EXT, is_video_filename
DATA_ROOT = Path("data/.video")
TMDB_BASE_URL = "https://api.themoviedb.org/3"
def _sha1(text: str) -> str:
return hashlib.sha1(text.encode("utf-8")).hexdigest()
def _store_path(media_type: str, source_path: str) -> Path:
subdir = "tv" if media_type == "tv" else "movie"
return DATA_ROOT / subdir / f"{_sha1(source_path)}.json"
def _write_json(path: Path, payload: dict) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
_CLEAN_TAGS_RE = re.compile(
r"\b("
r"2160p|1080p|720p|480p|4k|hdr|dv|dolby|atmos|"
r"x264|x265|h264|h265|hevc|av1|aac|dts|flac|"
r"bluray|bdrip|web[- ]?dl|webrip|dvdrip|remux|proper|repack"
r")\b",
re.IGNORECASE,
)
def _clean_query_name(raw: str) -> str:
name = raw
name = name.replace(".", " ").replace("_", " ")
name = re.sub(r"\[[^\]]*\]", " ", name)
name = re.sub(r"\([^\)]*\)", " ", name)
name = _CLEAN_TAGS_RE.sub(" ", name)
name = re.sub(r"\s+", " ", name).strip()
return name
def _guess_name_from_path(path: str, is_dir: bool) -> str:
norm = path.rstrip("/") if is_dir else path
p = Path(norm)
raw = p.name if is_dir else p.stem
return _clean_query_name(raw)
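A couple of worked inputs for the name-cleaning helpers above (this module is removed by the diff, so these only document the old behavior; note that release years are deliberately kept):

    # Expected results, derived from the regexes above (illustrative).
    assert _clean_query_name("The.Matrix.1999.1080p.BluRay.x264") == "The Matrix 1999"
    assert _guess_name_from_path(
        "/media/Movies/Blade.Runner.2049.2160p.WEB-DL.mkv", is_dir=False
    ) == "Blade Runner 2049"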
def _as_bool(value: Any, default: bool) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
if isinstance(value, str):
v = value.strip().lower()
if v in {"1", "true", "yes", "y", "on"}:
return True
if v in {"0", "false", "no", "n", "off"}:
return False
return default
_SXXEYY_RE = re.compile(r"[Ss](\d{1,2})\s*[.\-_ ]*\s*[Ee](\d{1,3})")
_X_RE = re.compile(r"(\d{1,2})x(\d{1,3})", re.IGNORECASE)
_CN_EP_RE = re.compile(r"\s*(\d{1,3})\s*[集话]")
_CN_SEASON_RE = re.compile(r"\s*(\d{1,2})\s*季")
_SEASON_WORD_RE = re.compile(r"Season\s*(\d{1,2})", re.IGNORECASE)
_S_RE = re.compile(r"[Ss](\d{1,2})")
def _parse_season_episode(rel_path: str) -> Tuple[Optional[int], Optional[int]]:
stem = Path(rel_path).stem
m = _SXXEYY_RE.search(stem) or _SXXEYY_RE.search(rel_path)
if m:
return int(m.group(1)), int(m.group(2))
m = _X_RE.search(stem)
if m:
return int(m.group(1)), int(m.group(2))
m = _CN_EP_RE.search(stem)
if m:
episode = int(m.group(1))
season = None
for part in reversed(Path(rel_path).parts[:-1]):
sm = _CN_SEASON_RE.search(part) or _SEASON_WORD_RE.search(part) or _S_RE.search(part)
if sm:
season = int(sm.group(1))
break
return season or 1, episode
m = re.match(r"^(\d{1,3})(?!\d)", stem)
if m:
episode = int(m.group(1))
season = None
for part in reversed(Path(rel_path).parts[:-1]):
sm = _CN_SEASON_RE.search(part) or _SEASON_WORD_RE.search(part) or _S_RE.search(part)
if sm:
season = int(sm.group(1))
break
return season or 1, episode
return None, None
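And a few worked inputs for _parse_season_episode, again based only on the patterns above:

    # Expected results given the regexes defined above (illustrative).
    assert _parse_season_episode("Show.S01E02.1080p.mkv") == (1, 2)   # SxxEyy
    assert _parse_season_episode("show 3x07.mkv") == (3, 7)           # NxM form
    assert _parse_season_episode("Season 2/05.mkv") == (2, 5)         # bare number + season dir
    assert _parse_season_episode("random-name.mkv") == (None, None)   # nothing matched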
class TMDBClient:
def __init__(self, access_token: str | None, api_key: str | None):
self._access_token = access_token
self._api_key = api_key
@classmethod
def from_env(cls) -> "TMDBClient":
access_token = os.getenv("TMDB_ACCESS_TOKEN")
api_key = os.getenv("TMDB_API_KEY")
if not access_token and not api_key:
raise RuntimeError("缺少 TMDB_ACCESS_TOKEN 或 TMDB_API_KEY")
return cls(access_token=access_token, api_key=api_key)
def _headers(self) -> dict:
headers = {"Accept": "application/json"}
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
def _merge_params(self, params: dict) -> dict:
merged = dict(params or {})
if self._api_key:
merged.setdefault("api_key", self._api_key)
return merged
async def get(self, path: str, params: dict) -> dict:
url = f"{TMDB_BASE_URL}{path}"
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.get(url, headers=self._headers(), params=self._merge_params(params))
resp.raise_for_status()
return resp.json()
class VideoLibraryProcessor:
name = "影视入库"
supported_exts = sorted(VIDEO_EXT)
config_schema = [
{
"key": "name",
"label": "Manual name (optional)",
"type": "string",
"required": False,
"placeholder": "Leave empty to derive the name from the path",
},
{
"key": "language",
"label": "Language",
"type": "string",
"required": False,
"default": "zh-CN",
},
{
"key": "include_episodes",
"label": "TV shows: save each episode",
"type": "select",
"required": False,
"default": 1,
"options": [
{"label": "Yes", "value": 1},
{"label": "No", "value": 0},
],
},
]
produces_file = False
supports_directory = True
requires_input_bytes = False
async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Dict[str, Any]:
tmdb = TMDBClient.from_env()
is_dir = await VirtualFSService.path_is_directory(path)
language = str(config.get("language") or "zh-CN")
manual_name = str(config.get("name") or "").strip()
query_name = manual_name or _guess_name_from_path(path, is_dir=is_dir)
scraped_at = datetime.now(UTC).isoformat()
if is_dir:
payload, saved_to = await self._process_tv_dir(tmdb, path, query_name, language, scraped_at, config)
return {
"ok": True,
"type": "tv",
"path": path,
"tmdb_id": payload.get("tmdb", {}).get("id"),
"saved_to": str(saved_to),
}
payload, saved_to = await self._process_movie_file(tmdb, path, query_name, language, scraped_at)
return {
"ok": True,
"type": "movie",
"path": path,
"tmdb_id": payload.get("tmdb", {}).get("id"),
"saved_to": str(saved_to),
}
async def _process_movie_file(
self,
tmdb: TMDBClient,
path: str,
query_name: str,
language: str,
scraped_at: str,
) -> Tuple[dict, Path]:
search = await tmdb.get("/search/movie", {"query": query_name, "language": language})
results = search.get("results") or []
if not results:
raise RuntimeError(f"未找到电影条目:{query_name}")
chosen = results[0] or {}
movie_id = chosen.get("id")
if not movie_id:
raise RuntimeError("TMDB 搜索结果缺少 id")
detail = await tmdb.get(
f"/movie/{movie_id}",
{
"language": language,
"append_to_response": "credits,images,external_ids,videos",
},
)
payload = {
"type": "movie",
"source_path": path,
"query": {"name": query_name, "language": language},
"scraped_at": scraped_at,
"tmdb": {
"id": movie_id,
"search": {"page": search.get("page"), "total_results": search.get("total_results"), "results": results[:5]},
"detail": detail,
},
}
saved_to = _store_path("movie", path)
_write_json(saved_to, payload)
return payload, saved_to
async def _process_tv_dir(
self,
tmdb: TMDBClient,
path: str,
query_name: str,
language: str,
scraped_at: str,
config: Dict[str, Any],
) -> Tuple[dict, Path]:
search = await tmdb.get("/search/tv", {"query": query_name, "language": language})
results = search.get("results") or []
if not results:
raise RuntimeError(f"未找到电视剧条目:{query_name}")
chosen = results[0] or {}
tv_id = chosen.get("id")
if not tv_id:
raise RuntimeError("TMDB 搜索结果缺少 id")
detail = await tmdb.get(
f"/tv/{tv_id}",
{
"language": language,
"append_to_response": "credits,images,external_ids,videos",
},
)
include_episodes = _as_bool(config.get("include_episodes"), True)
episodes: List[dict] = []
seasons_detail: Dict[str, Any] = {}
if include_episodes:
episodes = await self._collect_episode_files(path)
seasons = sorted({ep["season"] for ep in episodes if ep.get("season") is not None})
for season in seasons:
seasons_detail[str(season)] = await tmdb.get(
f"/tv/{tv_id}/season/{int(season)}",
{"language": language},
)
self._attach_tmdb_episode_detail(episodes, seasons_detail)
payload = {
"type": "tv",
"source_path": path,
"query": {"name": query_name, "language": language},
"scraped_at": scraped_at,
"tmdb": {
"id": tv_id,
"search": {"page": search.get("page"), "total_results": search.get("total_results"), "results": results[:5]},
"detail": detail,
"seasons": seasons_detail,
},
"episodes": episodes,
}
saved_to = _store_path("tv", path)
_write_json(saved_to, payload)
return payload, saved_to
async def _collect_episode_files(self, dir_path: str) -> List[dict]:
adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(dir_path)
rel = rel.rstrip("/")
list_dir = await VirtualFSService._ensure_method(adapter_instance, "list_dir")
stack: List[str] = [rel]
page_size = 200
out: List[dict] = []
while stack:
current_rel = stack.pop()
page = 1
while True:
entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
entries = entries or []
if not entries and (total or 0) == 0:
break
for entry in entries:
name = entry.get("name")
if not name:
continue
child_rel = VirtualFSService._join_rel(current_rel, name)
if entry.get("is_dir"):
stack.append(child_rel.rstrip("/"))
continue
if not is_video_filename(name):
continue
absolute_path = VirtualFSService._build_absolute_path(adapter_model.path, child_rel)
rel_in_show = child_rel
if rel and child_rel.startswith(rel.rstrip("/") + "/"):
rel_in_show = child_rel[len(rel.rstrip("/")) + 1 :]
season, episode = _parse_season_episode(rel_in_show)
out.append(
{
"path": absolute_path,
"rel": rel_in_show,
"name": name,
"size": entry.get("size"),
"mtime": entry.get("mtime"),
"season": season,
"episode": episode,
}
)
if total is None or page * page_size >= total:
break
page += 1
return out
def _attach_tmdb_episode_detail(self, episodes: List[dict], seasons_detail: Dict[str, Any]) -> None:
episode_maps: Dict[str, Dict[int, Any]] = {}
for season_str, season_payload in (seasons_detail or {}).items():
items = (season_payload or {}).get("episodes") or []
m: Dict[int, Any] = {}
for item in items:
try:
number = int(item.get("episode_number"))
except Exception:
continue
m[number] = item
episode_maps[season_str] = m
for ep in episodes:
season = ep.get("season")
episode = ep.get("episode")
if season is None or episode is None:
continue
m = episode_maps.get(str(season))
if not m:
continue
detail = m.get(int(episode))
if detail:
ep["tmdb_episode"] = detail
PROCESSOR_TYPE = "video_library"
PROCESSOR_NAME = VideoLibraryProcessor.name
SUPPORTED_EXTS = VideoLibraryProcessor.supported_exts
CONFIG_SCHEMA = VideoLibraryProcessor.config_schema
PROCESSOR_FACTORY = lambda: VideoLibraryProcessor()


@@ -5,7 +5,7 @@ from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, Optional
from domain.processors.base import BaseProcessor
from .base import BaseProcessor
ProcessorFactory = Callable[[], BaseProcessor]
TYPE_MAP: Dict[str, ProcessorFactory] = {}
@@ -16,7 +16,7 @@ LAST_DISCOVERY_ERRORS: list[str] = []
def discover_processors(force_reload: bool = False) -> list[str]:
"""扫描并缓存可用的处理器模块。"""
from domain.processors import builtin as processors_pkg
from . import builtin as processors_pkg
TYPE_MAP.clear()
CONFIG_SCHEMAS.clear()


@@ -3,20 +3,20 @@ from typing import List, Tuple
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from domain.processors.registry import (
from domain.tasks import task_queue_service
from domain.virtual_fs import VirtualFSService
from .registry import (
get,
get_config_schema,
get_config_schemas,
get_module_path,
reload_processors,
)
from domain.processors.types import (
from .types import (
ProcessDirectoryRequest,
ProcessRequest,
UpdateSourceRequest,
)
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import task_queue_service
class ProcessorService:
@@ -85,6 +85,44 @@ class ProcessorService:
suffix = raw_suffix
overwrite = req.overwrite
if produces_file:
if not overwrite and not suffix:
raise HTTPException(400, detail="Suffix is required when not overwriting files")
else:
overwrite = False
suffix = None
payload = {
"path": req.path,
"processor_type": req.processor_type,
"config": req.config,
"overwrite": overwrite,
"max_depth": req.max_depth,
"suffix": suffix,
}
task = await task_queue_service.add_task("process_directory_scan", payload)
return {"task_id": task.id}
@classmethod
async def scan_directory(cls, req: ProcessDirectoryRequest):
if req.max_depth is not None and req.max_depth < 0:
raise HTTPException(400, detail="max_depth must be >= 0")
is_dir = await VirtualFSService.path_is_directory(req.path)
if not is_dir:
raise HTTPException(400, detail="Path must be a directory")
schema = get_config_schema(req.processor_type)
_processor = get(req.processor_type)
if not schema or not _processor:
raise HTTPException(404, detail="Processor not found")
produces_file = bool(schema.get("produces_file"))
raw_suffix = req.suffix if req.suffix is not None else None
if raw_suffix is not None and raw_suffix.strip() == "":
raw_suffix = None
suffix = raw_suffix
overwrite = req.overwrite
if produces_file:
if not overwrite and not suffix:
raise HTTPException(400, detail="Suffix is required when not overwriting files")
@@ -133,7 +171,7 @@ class ProcessorService:
new_name = f"{name}{suffix_str}"
return str(path_obj.with_name(new_name))
scheduled_tasks: List[str] = []
scheduled_count = 0
stack: List[Tuple[str, int]] = [(rel, 0)]
page_size = 200
@@ -161,7 +199,7 @@ class ProcessorService:
save_to = None
if produces_file and not overwrite and suffix:
save_to = apply_suffix(absolute_path, suffix)
task = await task_queue_service.add_task(
await task_queue_service.add_task(
"process_file",
{
"path": absolute_path,
@@ -171,16 +209,13 @@ class ProcessorService:
"overwrite": overwrite,
},
)
scheduled_tasks.append(task.id)
scheduled_count += 1
if total is None or page * page_size >= total:
break
page += 1
return {
"task_ids": scheduled_tasks,
"scheduled": len(scheduled_tasks),
}
return {"scheduled": scheduled_count}
@classmethod
async def get_source(cls, processor_type: str):
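Reading this hunk together with the task_queue diff further down: process_directory now only enqueues a single process_directory_scan task, and the queue worker runs scan_directory, which walks the tree and enqueues the per-file process_file tasks, reporting just a count. A condensed sketch of that flow, reconstructed from the diff rather than quoted from it (and assuming ProcessDirectoryRequest is a pydantic model):

    # Condensed two-stage flow, reconstructed from this diff (not verbatim code).
    async def submit(req: ProcessDirectoryRequest) -> dict:
        # Stage 1 (API request): enqueue one scan task and return immediately.
        task = await task_queue_service.add_task(
            "process_directory_scan", req.model_dump()
        )
        return {"task_id": task.id}


    async def worker_step(params: dict) -> dict:
        # Stage 2 (queue worker): walk the tree, enqueue per-file tasks.
        req = ProcessDirectoryRequest(**params)
        return await ProcessorService.scan_directory(req)  # {"scheduled": n}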


@@ -0,0 +1 @@
__all__: list[str] = []

domain/role/__init__.py

@@ -0,0 +1,3 @@
from .service import RoleService
__all__ = ["RoleService"]

domain/role/api.py

@@ -0,0 +1,119 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.permission import require_system_permission
from domain.permission.types import PathRuleCreate, PathRuleInfo, SystemPermission
from domain.user.service import UserService
from domain.user.types import UserInfo
from .service import RoleService
from .types import RoleCreate, RoleDetail, RoleInfo, RolePermissionsUpdate, RoleUpdate
router = APIRouter(prefix="/api", tags=["role"])
@router.get("/roles", response_model=list[RoleInfo])
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def list_roles(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> list[RoleInfo]:
return await RoleService.get_all_roles()
@router.get("/roles/{role_id}", response_model=RoleDetail)
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def get_role(
role_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> RoleDetail:
return await RoleService.get_role(role_id)
@router.get("/roles/{role_id}/users", response_model=list[UserInfo])
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def list_role_users(
role_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[UserInfo]:
return await UserService.get_users_by_role(role_id)
@router.post("/roles", response_model=RoleInfo)
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def create_role(
data: RoleCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> RoleInfo:
return await RoleService.create_role(data)
@router.put("/roles/{role_id}", response_model=RoleInfo)
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def update_role(
role_id: int,
data: RoleUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> RoleInfo:
return await RoleService.update_role(role_id, data)
@router.delete("/roles/{role_id}")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def delete_role(
role_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> dict:
await RoleService.delete_role(role_id)
return {"success": True}
@router.post("/roles/{role_id}/permissions", response_model=list[str])
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def set_role_permissions(
role_id: int,
data: RolePermissionsUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
return await RoleService.set_role_permissions(role_id, data.permission_codes)
@router.get("/roles/{role_id}/path-rules", response_model=list[PathRuleInfo])
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def get_role_path_rules(
role_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[PathRuleInfo]:
return await RoleService.get_role_path_rules(role_id)
@router.post("/roles/{role_id}/path-rules", response_model=PathRuleInfo)
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def add_path_rule(
role_id: int,
data: PathRuleCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> PathRuleInfo:
return await RoleService.add_path_rule(role_id, data)
@router.put("/path-rules/{rule_id}", response_model=PathRuleInfo)
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def update_path_rule(
rule_id: int,
data: PathRuleCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> PathRuleInfo:
return await RoleService.update_path_rule(rule_id, data)
@router.delete("/path-rules/{rule_id}")
@require_system_permission(SystemPermission.ROLE_MANAGE)
async def delete_path_rule(
rule_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> dict:
await RoleService.delete_path_rule(rule_id)
return {"success": True}

domain/role/service.py

@@ -0,0 +1,288 @@
from typing import List
from fastapi import HTTPException
from models.database import Role, RolePermission, PathRule, UserRole
from domain.permission.service import PermissionService
from domain.permission.types import PathRuleCreate, PathRuleInfo, PERMISSION_DEFINITIONS
from .types import RoleInfo, RoleDetail, RoleCreate, RoleUpdate, SystemRoles
class RoleService:
"""角色管理服务"""
@classmethod
async def get_all_roles(cls) -> List[RoleInfo]:
"""获取所有角色"""
roles = await Role.all().order_by("id")
return [
RoleInfo(
id=r.id,
name=r.name,
description=r.description,
is_system=r.is_system,
created_at=r.created_at,
)
for r in roles
]
@classmethod
async def get_role(cls, role_id: int) -> RoleDetail:
"""获取角色详情"""
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
# 获取权限
role_permissions = await RolePermission.filter(role_id=role_id)
permissions = sorted(set(rp.permission_code for rp in role_permissions))
# 获取路径规则数量
path_rules_count = await PathRule.filter(role_id=role_id).count()
return RoleDetail(
id=role.id,
name=role.name,
description=role.description,
is_system=role.is_system,
created_at=role.created_at,
permissions=permissions,
path_rules_count=path_rules_count,
)
@classmethod
async def create_role(cls, data: RoleCreate) -> RoleInfo:
"""创建角色"""
# 检查名称是否已存在
existing = await Role.get_or_none(name=data.name)
if existing:
raise HTTPException(400, detail="角色名称已存在")
role = await Role.create(
name=data.name,
description=data.description,
is_system=False,
)
return RoleInfo(
id=role.id,
name=role.name,
description=role.description,
is_system=role.is_system,
created_at=role.created_at,
)
@classmethod
async def update_role(cls, role_id: int, data: RoleUpdate) -> RoleInfo:
"""更新角色"""
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
if data.name is not None:
# 检查名称是否与其他角色冲突
existing = await Role.filter(name=data.name).exclude(id=role_id).first()
if existing:
raise HTTPException(400, detail="角色名称已存在")
role.name = data.name
if data.description is not None:
role.description = data.description
await role.save()
return RoleInfo(
id=role.id,
name=role.name,
description=role.description,
is_system=role.is_system,
created_at=role.created_at,
)
@classmethod
async def delete_role(cls, role_id: int) -> None:
"""删除角色"""
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
if role.is_system:
raise HTTPException(400, detail="系统内置角色不可删除")
# 检查是否有用户使用此角色
user_count = await UserRole.filter(role_id=role_id).count()
if user_count > 0:
raise HTTPException(400, detail=f"{user_count} 个用户正在使用此角色,无法删除")
await role.delete()
# 清除权限缓存
PermissionService.clear_cache()
@classmethod
async def set_role_permissions(cls, role_id: int, permission_codes: List[str]) -> List[str]:
"""设置角色的权限"""
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
all_permission_codes = {item["code"] for item in PERMISSION_DEFINITIONS}
invalid_codes = set(permission_codes) - all_permission_codes
if invalid_codes:
raise HTTPException(400, detail=f"无效的权限代码: {', '.join(invalid_codes)}")
# 删除现有权限
await RolePermission.filter(role_id=role_id).delete()
# 添加新权限
for code in permission_codes:
await RolePermission.create(
role_id=role_id,
permission_code=code,
)
# 清除权限缓存
PermissionService.clear_cache()
return list(permission_codes)
@classmethod
async def get_role_path_rules(cls, role_id: int) -> List[PathRuleInfo]:
"""获取角色的路径规则"""
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
rules = await PathRule.filter(role_id=role_id).order_by("-priority", "id")
return [
PathRuleInfo(
id=r.id,
role_id=r.role_id,
path_pattern=r.path_pattern,
is_regex=r.is_regex,
can_read=r.can_read,
can_write=r.can_write,
can_delete=r.can_delete,
can_share=r.can_share,
priority=r.priority,
created_at=r.created_at,
)
for r in rules
]
@classmethod
async def add_path_rule(cls, role_id: int, data: PathRuleCreate) -> PathRuleInfo:
"""添加路径规则"""
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
# 验证路径模式
if data.is_regex:
import re
try:
re.compile(data.path_pattern)
except re.error as e:
raise HTTPException(400, detail=f"无效的正则表达式: {e}")
rule = await PathRule.create(
role_id=role_id,
path_pattern=data.path_pattern,
is_regex=data.is_regex,
can_read=data.can_read,
can_write=data.can_write,
can_delete=data.can_delete,
can_share=data.can_share,
priority=data.priority,
)
# Invalidate the permission cache
PermissionService.clear_cache()
return PathRuleInfo(
id=rule.id,
role_id=rule.role_id,
path_pattern=rule.path_pattern,
is_regex=rule.is_regex,
can_read=rule.can_read,
can_write=rule.can_write,
can_delete=rule.can_delete,
can_share=rule.can_share,
priority=rule.priority,
created_at=rule.created_at,
)
@classmethod
async def update_path_rule(cls, rule_id: int, data: PathRuleCreate) -> PathRuleInfo:
"""更新路径规则"""
rule = await PathRule.get_or_none(id=rule_id)
if not rule:
raise HTTPException(404, detail="路径规则不存在")
# 验证路径模式
if data.is_regex:
import re
try:
re.compile(data.path_pattern)
except re.error as e:
raise HTTPException(400, detail=f"无效的正则表达式: {e}")
rule.path_pattern = data.path_pattern
rule.is_regex = data.is_regex
rule.can_read = data.can_read
rule.can_write = data.can_write
rule.can_delete = data.can_delete
rule.can_share = data.can_share
rule.priority = data.priority
await rule.save()
# Invalidate the permission cache
PermissionService.clear_cache()
return PathRuleInfo(
id=rule.id,
role_id=rule.role_id,
path_pattern=rule.path_pattern,
is_regex=rule.is_regex,
can_read=rule.can_read,
can_write=rule.can_write,
can_delete=rule.can_delete,
can_share=rule.can_share,
priority=rule.priority,
created_at=rule.created_at,
)
@classmethod
async def delete_path_rule(cls, rule_id: int) -> None:
"""删除路径规则"""
rule = await PathRule.get_or_none(id=rule_id)
if not rule:
raise HTTPException(404, detail="路径规则不存在")
await rule.delete()
# 清除权限缓存
PermissionService.clear_cache()
@classmethod
async def ensure_system_roles(cls) -> None:
"""确保系统内置角色存在"""
system_roles = [
{
"name": SystemRoles.ADMIN,
"description": "管理员角色,拥有所有系统和适配器权限",
"is_system": True,
},
{
"name": SystemRoles.USER,
"description": "普通用户角色,需要管理员配置路径权限",
"is_system": True,
},
{
"name": SystemRoles.VIEWER,
"description": "只读用户角色,仅可查看文件",
"is_system": True,
},
]
for role_data in system_roles:
existing = await Role.get_or_none(name=role_data["name"])
if not existing:
await Role.create(**role_data)

domain/role/types.py

@@ -0,0 +1,36 @@
from pydantic import BaseModel
from datetime import datetime
class RoleInfo(BaseModel):
id: int
name: str
description: str | None = None
is_system: bool
created_at: datetime
class RoleDetail(RoleInfo):
permissions: list[str]  # permission codes
path_rules_count: int
class RoleCreate(BaseModel):
name: str
description: str | None = None
class RoleUpdate(BaseModel):
name: str | None = None
description: str | None = None
class RolePermissionsUpdate(BaseModel):
permission_codes: list[str]
# Built-in role names
class SystemRoles:
ADMIN = "Admin"
USER = "User"
VIEWER = "Viewer"

domain/share/__init__.py

@@ -0,0 +1,10 @@
from .service import ShareService
from .types import ShareCreate, ShareInfo, ShareInfoWithPassword, SharePassword
__all__ = [
"ShareService",
"ShareCreate",
"ShareInfo",
"ShareInfoWithPassword",
"SharePassword",
]


@@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.share.service import ShareService
from domain.share.types import (
from domain.auth import User, get_current_active_user
from domain.permission import require_path_permission
from domain.permission.types import PathAction
from .service import ShareService
from .types import (
ShareCreate,
ShareInfo,
ShareInfoWithPassword,
@@ -25,6 +26,7 @@ router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
description="创建分享链接",
body_fields=["name", "paths", "expires_in_days", "access_type"],
)
@require_path_permission(PathAction.SHARE, "payload.paths")
async def create_share(
request: Request,
payload: ShareCreate,


@@ -7,7 +7,7 @@ import bcrypt
from fastapi import HTTPException, status
from fastapi.responses import Response
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs import VirtualFSService
from models.database import ShareLink, UserAccount

domain/tasks/__init__.py

@@ -0,0 +1,26 @@
from .service import TaskService
from .scheduler import task_scheduler
from .task_queue import Task, TaskProgress, TaskStatus, task_queue_service
from .types import (
AutomationTaskBase,
AutomationTaskCreate,
AutomationTaskRead,
AutomationTaskUpdate,
TaskQueueSettings,
TaskQueueSettingsResponse,
)
__all__ = [
"TaskService",
"Task",
"TaskProgress",
"TaskStatus",
"task_queue_service",
"task_scheduler",
"AutomationTaskBase",
"AutomationTaskCreate",
"AutomationTaskRead",
"AutomationTaskUpdate",
"TaskQueueSettings",
"TaskQueueSettingsResponse",
]


@@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, Request
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.tasks.service import TaskService
from domain.tasks.types import (
from domain.auth import get_current_active_user
from .service import TaskService
from .types import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
@@ -59,8 +59,7 @@ async def get_task_status(task_id: str, request: Request, current_user: CurrentU
body_fields=[
"name",
"event",
"path_pattern",
"filename_regex",
"trigger_config",
"processor_type",
"processor_config",
"enabled",
@@ -93,8 +92,7 @@ async def list_tasks(request: Request, current_user: CurrentUser):
body_fields=[
"name",
"event",
"path_pattern",
"filename_regex",
"trigger_config",
"processor_type",
"processor_config",
"enabled",

domain/tasks/scheduler.py

@@ -0,0 +1,102 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime
from croniter import croniter
from models.database import AutomationTask
from .task_queue import task_queue_service
@dataclass
class CronTaskItem:
task_id: int
processor_type: str
path: str
cron: croniter
next_run: datetime
class AutomationTaskScheduler:
def __init__(self):
self._items: list[CronTaskItem] = []
self._worker: asyncio.Task | None = None
self._reload_event = asyncio.Event()
self._stop_event = asyncio.Event()
async def start(self) -> None:
if self._worker and not self._worker.done():
return
self._stop_event.clear()
await self._load_tasks()
self._worker = asyncio.create_task(self._run_loop())
async def stop(self) -> None:
if not self._worker:
return
self._stop_event.set()
self._reload_event.set()
await self._worker
self._worker = None
def refresh(self) -> None:
if self._worker and not self._worker.done():
self._reload_event.set()
async def _load_tasks(self) -> None:
tasks = await AutomationTask.filter(event="cron", enabled=True)
items: list[CronTaskItem] = []
now = datetime.now()
for task in tasks:
trigger = task.trigger_config or {}
if not isinstance(trigger, dict):
continue
cron_expr = trigger.get("cron_expr")
path = trigger.get("path")
if not cron_expr or not path:
continue
cron = self._build_cron(cron_expr, now)
if not cron:
continue
next_run = cron.get_next(datetime)
items.append(
CronTaskItem(
task_id=task.id,
processor_type=task.processor_type,
path=path,
cron=cron,
next_run=next_run,
)
)
self._items = items
def _build_cron(self, expr: str, base_time: datetime) -> croniter | None:
expr = str(expr or "").strip()
if not expr:
return None
parts = [p for p in expr.split() if p]
if len(parts) not in (5, 6):
return None
second_at_beginning = len(parts) == 6
try:
return croniter(expr, base_time, second_at_beginning=second_at_beginning)
except Exception:
return None
async def _run_loop(self) -> None:
while not self._stop_event.is_set():
if self._reload_event.is_set():
self._reload_event.clear()
await self._load_tasks()
now = datetime.now()
for item in list(self._items):
if item.next_run <= now:
await task_queue_service.add_task(
item.processor_type,
{"task_id": item.task_id, "path": item.path},
)
item.next_run = item.cron.get_next(datetime)
await asyncio.sleep(1)
task_scheduler = AutomationTaskScheduler()
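_build_cron accepts both classic 5-field and seconds-first 6-field expressions; a quick illustration of the difference, using croniter directly (printed values assume the base time below):

    # 5 fields (minute hour day month weekday) vs. 6 with a leading seconds field.
    from datetime import datetime

    from croniter import croniter

    base = datetime(2026, 1, 1, 0, 0, 0)

    five = croniter("*/15 * * * *", base)
    print(five.get_next(datetime))  # 2026-01-01 00:15:00

    six = croniter("30 */15 * * * *", base, second_at_beginning=True)
    print(six.get_next(datetime))   # 2026-01-01 00:00:30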


@@ -3,17 +3,17 @@ from typing import Annotated, Any, Dict, Optional
from fastapi import Depends, HTTPException
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.tasks.types import (
from domain.auth import User, get_current_active_user
from domain.config import ConfigService
from .scheduler import task_scheduler
from .task_queue import task_queue_service
from .types import (
AutomationTaskCreate,
AutomationTaskUpdate,
TaskQueueSettings,
TaskQueueSettingsResponse,
)
from models.database import AutomationTask
from domain.tasks.task_queue import task_queue_service
class TaskService:
@@ -47,6 +47,7 @@ class TaskService:
@classmethod
async def create_task(cls, payload: AutomationTaskCreate, user: Optional[User]) -> AutomationTask:
task = await AutomationTask.create(**payload.model_dump())
task_scheduler.refresh()
return task
@classmethod
@@ -70,6 +71,7 @@ class TaskService:
for key, value in update_data.items():
setattr(task, key, value)
await task.save()
task_scheduler.refresh()
return task
@classmethod
@@ -77,6 +79,7 @@ class TaskService:
deleted_count = await AutomationTask.filter(id=task_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
task_scheduler.refresh()
@classmethod
async def trigger_tasks(cls, event: str, path: str):
@@ -87,11 +90,16 @@ class TaskService:
@classmethod
def match(cls, task: AutomationTask, path: str) -> bool:
if task.path_pattern and not path.startswith(task.path_pattern):
trigger_config = task.trigger_config or {}
if not isinstance(trigger_config, dict):
trigger_config = {}
path_prefix = trigger_config.get("path_prefix")
filename_regex = trigger_config.get("filename_regex")
if path_prefix and not path.startswith(path_prefix):
return False
if task.filename_regex:
if filename_regex:
filename = path.split("/")[-1]
if not re.match(task.filename_regex, filename):
if not re.match(filename_regex, filename):
return False
return True
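With path_pattern and filename_regex folded into trigger_config, a file-event task now matches as below; SimpleNamespace stands in for an AutomationTask row, and the values are illustrative:

    # Illustrative check against the match() logic above.
    from types import SimpleNamespace

    task = SimpleNamespace(
        trigger_config={"path_prefix": "/media/incoming", "filename_regex": r".*\.mp4$"}
    )
    assert TaskService.match(task, "/media/incoming/clips/demo.mp4") is True
    assert TaskService.match(task, "/media/incoming/notes/readme.txt") is False  # regex miss
    assert TaskService.match(task, "/archive/demo.mp4") is False                 # prefix miss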


@@ -74,7 +74,7 @@ class TaskQueueService:
try:
# Local import to avoid circular dependency during module load.
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs import VirtualFSService
if task.name == "process_file":
params = task.task_info
@@ -86,37 +86,38 @@ class TaskQueueService:
overwrite=params.get("overwrite", False),
)
task.result = result
elif task.name == "process_directory_scan":
from domain.processors import ProcessDirectoryRequest, ProcessorService
params = task.task_info or {}
req = ProcessDirectoryRequest(**params)
task.result = await ProcessorService.scan_directory(req)
elif task.name == "automation_task" or self._is_processor_task(task.name):
from models.database import AutomationTask
from domain.processors.service import get_processor
params = task.task_info
auto_task = await AutomationTask.get(id=params["task_id"])
path = params["path"]
processor_type = auto_task.processor_type if task.name == "automation_task" else task.name
processor = get_processor(processor_type)
if not processor:
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
if processor_type != auto_task.processor_type:
processor_type = auto_task.processor_type
processor = get_processor(processor_type)
if not processor:
raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")
requires_input_bytes = bool(getattr(processor, "requires_input_bytes", True))
file_content = b""
if requires_input_bytes:
file_content = await VirtualFSService.read_file(path)
result = await processor.process(file_content, path, auto_task.processor_config)
save_to = auto_task.processor_config.get("save_to")
if save_to and getattr(processor, "produces_file", False):
await VirtualFSService.write_file(save_to, result)
processor_type = auto_task.processor_type
config = auto_task.processor_config or {}
save_to = config.get("save_to") if isinstance(config, dict) else None
overwrite = bool(config.get("overwrite")) if isinstance(config, dict) else False
try:
if await VirtualFSService.path_is_directory(path):
overwrite = True
except Exception:
pass
await VirtualFSService.process_file(
path=path,
processor_type=processor_type,
config=config if isinstance(config, dict) else {},
save_to=save_to,
overwrite=overwrite,
)
task.result = "Automation task completed"
elif task.name == "offline_http_download":
from domain.offline_downloads.service import OfflineDownloadService
from domain.offline_downloads import OfflineDownloadService
result_path = await OfflineDownloadService.run_http_download(task)
task.result = {"path": result_path}
@@ -124,12 +125,11 @@ class TaskQueueService:
result = await VirtualFSService.run_cross_mount_transfer_task(task)
task.result = result
elif task.name == "send_email":
from domain.email.service import EmailService
from domain.email import EmailService
await EmailService.send_from_task(task.id, task.task_info)
task.result = "Email sent"
else:
raise ValueError(f"Unknown task name: {task.name}")
task.status = TaskStatus.SUCCESS
except Exception as e:
@@ -141,7 +141,7 @@ class TaskQueueService:
def _is_processor_task(self, task_name: str) -> bool:
try:
from domain.processors.service import get_processor
from domain.processors import get_processor
return get_processor(task_name) is not None
except Exception:
@@ -180,7 +180,7 @@ class TaskQueueService:
async def start_worker(self, concurrency: int | None = None):
if concurrency is None:
from domain.config.service import ConfigService
from domain.config import ConfigService
stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
try:


@@ -6,8 +6,7 @@ from pydantic import BaseModel, Field
class AutomationTaskBase(BaseModel):
name: str
event: str
path_pattern: Optional[str] = None
filename_regex: Optional[str] = None
trigger_config: Dict[str, Any] = {}
processor_type: str
processor_config: Dict[str, Any] = {}
enabled: bool = True
@@ -22,6 +21,7 @@ class AutomationTaskUpdate(AutomationTaskBase):
event: Optional[str] = None
processor_type: Optional[str] = None
processor_config: Optional[Dict[str, Any]] = None
trigger_config: Optional[Dict[str, Any]] = None
enabled: Optional[bool] = None

domain/user/__init__.py

@@ -0,0 +1,4 @@
from .service import UserService
__all__ = ["UserService"]

domain/user/api.py

@@ -0,0 +1,79 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.permission import require_system_permission
from domain.permission.types import SystemPermission
from .service import UserService
from .types import UserCreate, UserDetail, UserInfo, UserRoleAssign, UserUpdate
router = APIRouter(prefix="/api", tags=["user"])
@router.get("/users", response_model=list[UserInfo])
@require_system_permission(SystemPermission.USER_LIST)
async def list_users(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> list[UserInfo]:
return await UserService.get_all_users()
@router.get("/users/{user_id}", response_model=UserDetail)
@require_system_permission(SystemPermission.USER_LIST)
async def get_user(
user_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> UserDetail:
return await UserService.get_user(user_id)
@router.post("/users", response_model=UserDetail)
@require_system_permission(SystemPermission.USER_CREATE)
async def create_user(
data: UserCreate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> UserDetail:
return await UserService.create_user(data, current_user.id)
@router.put("/users/{user_id}", response_model=UserDetail)
@require_system_permission(SystemPermission.USER_EDIT)
async def update_user(
user_id: int,
data: UserUpdate,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> UserDetail:
return await UserService.update_user(user_id, data, current_user.id)
@router.delete("/users/{user_id}")
@require_system_permission(SystemPermission.USER_DELETE)
async def delete_user(
user_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> dict:
await UserService.delete_user(user_id, current_user.id)
return {"success": True}
@router.post("/users/{user_id}/roles", response_model=list[str])
@require_system_permission(SystemPermission.USER_EDIT)
async def set_user_roles(
user_id: int,
data: UserRoleAssign,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
return await UserService.set_user_roles(user_id, data.role_ids)
@router.delete("/users/{user_id}/roles/{role_id}", response_model=list[str])
@require_system_permission(SystemPermission.USER_EDIT)
async def remove_user_role(
user_id: int,
role_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
return await UserService.remove_user_role(user_id, role_id)

domain/user/service.py

@@ -0,0 +1,190 @@
from typing import List
from fastapi import HTTPException
from domain.auth.service import AuthService
from domain.permission.service import PermissionService
from models.database import Role, UserAccount, UserRole
from .types import UserCreate, UserDetail, UserInfo, UserUpdate
class UserService:
"""用户管理服务"""
@classmethod
async def get_all_users(cls) -> List[UserInfo]:
users = await UserAccount.all().order_by("id")
return [
UserInfo(
id=u.id,
username=u.username,
email=u.email,
full_name=u.full_name,
disabled=u.disabled,
is_admin=u.is_admin,
created_at=u.created_at,
last_login=u.last_login,
)
for u in users
]
@classmethod
async def get_user(cls, user_id: int) -> UserDetail:
user = await UserAccount.get_or_none(id=user_id).prefetch_related("created_by")
if not user:
raise HTTPException(404, detail="用户不存在")
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
roles = [ur.role.name for ur in user_roles]
created_by_username = None
if user.created_by_id:
creator = await UserAccount.get_or_none(id=user.created_by_id)
if creator:
created_by_username = creator.username
return UserDetail(
id=user.id,
username=user.username,
email=user.email,
full_name=user.full_name,
disabled=user.disabled,
is_admin=user.is_admin,
created_at=user.created_at,
last_login=user.last_login,
roles=roles,
created_by_username=created_by_username,
)
@classmethod
async def get_users_by_role(cls, role_id: int) -> List[UserInfo]:
role = await Role.get_or_none(id=role_id)
if not role:
raise HTTPException(404, detail="角色不存在")
user_roles = await UserRole.filter(role_id=role_id).prefetch_related("user")
users = [ur.user for ur in user_roles if ur.user]
users.sort(key=lambda u: u.id)
return [
UserInfo(
id=u.id,
username=u.username,
email=u.email,
full_name=u.full_name,
disabled=u.disabled,
is_admin=u.is_admin,
created_at=u.created_at,
last_login=u.last_login,
)
for u in users
]
@classmethod
async def create_user(cls, data: UserCreate, creator_id: int) -> UserDetail:
existing = await UserAccount.get_or_none(username=data.username)
if existing:
raise HTTPException(400, detail="用户名已存在")
if data.email:
existing_email = await UserAccount.get_or_none(email=data.email)
if existing_email:
raise HTTPException(400, detail="邮箱已被使用")
hashed_password = AuthService.get_password_hash(data.password)
user = await UserAccount.create(
username=data.username,
email=data.email,
full_name=data.full_name,
hashed_password=hashed_password,
disabled=data.disabled,
is_admin=data.is_admin,
created_by_id=creator_id,
)
if data.role_ids:
for role_id in data.role_ids:
role = await Role.get_or_none(id=role_id)
if role:
await UserRole.create(user_id=user.id, role_id=role_id)
return await cls.get_user(user.id)
@classmethod
async def update_user(cls, user_id: int, data: UserUpdate, operator_id: int) -> UserDetail:
user = await UserAccount.get_or_none(id=user_id)
if not user:
raise HTTPException(404, detail="用户不存在")
if data.is_admin is not None and user_id == operator_id:
raise HTTPException(400, detail="不能修改自己的管理员状态")
if data.email is not None:
existing = await UserAccount.filter(email=data.email).exclude(id=user_id).first()
if existing:
raise HTTPException(400, detail="邮箱已被使用")
user.email = data.email
if data.full_name is not None:
user.full_name = data.full_name
if data.password is not None:
user.hashed_password = AuthService.get_password_hash(data.password)
if data.is_admin is not None:
user.is_admin = data.is_admin
if data.disabled is not None:
if user_id == operator_id and data.disabled:
raise HTTPException(400, detail="不能禁用自己")
user.disabled = data.disabled
await user.save()
PermissionService.clear_cache(user_id)
return await cls.get_user(user_id)
@classmethod
async def delete_user(cls, user_id: int, operator_id: int) -> None:
if user_id == operator_id:
raise HTTPException(400, detail="不能删除自己")
user = await UserAccount.get_or_none(id=user_id)
if not user:
raise HTTPException(404, detail="用户不存在")
await UserRole.filter(user_id=user_id).delete()
await user.delete()
PermissionService.clear_cache(user_id)
@classmethod
async def set_user_roles(cls, user_id: int, role_ids: List[int]) -> List[str]:
user = await UserAccount.get_or_none(id=user_id)
if not user:
raise HTTPException(404, detail="用户不存在")
roles = await Role.filter(id__in=role_ids)
valid_role_ids = {r.id for r in roles}
invalid_ids = set(role_ids) - valid_role_ids
if invalid_ids:
raise HTTPException(400, detail=f"无效的角色ID: {invalid_ids}")
await UserRole.filter(user_id=user_id).delete()
for role_id in role_ids:
await UserRole.create(user_id=user_id, role_id=role_id)
PermissionService.clear_cache(user_id)
return [r.name for r in roles if r.id in role_ids]
@classmethod
async def remove_user_role(cls, user_id: int, role_id: int) -> List[str]:
user = await UserAccount.get_or_none(id=user_id)
if not user:
raise HTTPException(404, detail="用户不存在")
await UserRole.filter(user_id=user_id, role_id=role_id).delete()
PermissionService.clear_cache(user_id)
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
return [ur.role.name for ur in user_roles]

domain/user/types.py

@@ -0,0 +1,42 @@
from datetime import datetime
from pydantic import BaseModel
class UserInfo(BaseModel):
id: int
username: str
email: str | None = None
full_name: str | None = None
disabled: bool
is_admin: bool
created_at: datetime
last_login: datetime | None = None
class UserDetail(UserInfo):
roles: list[str]
created_by_username: str | None = None
class UserCreate(BaseModel):
username: str
password: str
email: str | None = None
full_name: str | None = None
is_admin: bool = False
disabled: bool = False
role_ids: list[int] = []
class UserUpdate(BaseModel):
email: str | None = None
full_name: str | None = None
password: str | None = None
is_admin: bool | None = None
disabled: bool | None = None
class UserRoleAssign(BaseModel):
role_ids: list[int]


@@ -0,0 +1,11 @@
from .service import VirtualFSService
from .types import DirListing, MkdirRequest, MoveRequest, SearchResultItem, VfsEntry
__all__ = [
"VirtualFSService",
"DirListing",
"MkdirRequest",
"MoveRequest",
"SearchResultItem",
"VfsEntry",
]


@@ -4,16 +4,18 @@ from fastapi import APIRouter, Depends, File, Query, Request, UploadFile
from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.types import MkdirRequest, MoveRequest
from domain.auth import User, get_current_active_user
from domain.permission import require_path_permission
from domain.permission.types import PathAction
from .service import VirtualFSService
from .types import MkdirRequest, MoveRequest
router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
@require_path_permission(PathAction.READ, "full_path")
async def get_file(
full_path: str,
request: Request,
@@ -45,6 +47,7 @@ async def stream_endpoint(
@router.get("/temp-link/{full_path:path}")
@audit(action=AuditAction.SHARE, description="创建临时链接")
@require_path_permission(PathAction.READ, "full_path")
async def get_temp_link(
full_path: str,
request: Request,
@@ -64,8 +67,19 @@ async def access_public_file(
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
@router.get("/public/{token}/{filename}")
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
async def access_public_file_with_name(
token: str,
filename: str,
request: Request,
):
return await VirtualFSService.access_public_file(token, request.headers.get("Range"))
@router.get("/stat/{full_path:path}")
@audit(action=AuditAction.READ, description="查看文件信息")
@require_path_permission(PathAction.READ, "full_path")
async def get_file_stat(
full_path: str,
request: Request,
@@ -77,6 +91,7 @@ async def get_file_stat(
@router.post("/file/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="上传文件")
@require_path_permission(PathAction.WRITE, "full_path")
async def put_file(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -90,6 +105,7 @@ async def put_file(
@router.post("/mkdir")
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
@require_path_permission(PathAction.WRITE, "body.path")
async def api_mkdir(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -101,6 +117,8 @@ async def api_mkdir(
@router.post("/move")
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
@require_path_permission(PathAction.WRITE, "body.dst")
@require_path_permission(PathAction.DELETE, "body.src")
async def api_move(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -113,6 +131,7 @@ async def api_move(
@router.post("/rename")
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
@require_path_permission(PathAction.WRITE, "body.src")
async def api_rename(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -125,6 +144,8 @@ async def api_rename(
@router.post("/copy")
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
@require_path_permission(PathAction.WRITE, "body.dst")
@require_path_permission(PathAction.READ, "body.src")
async def api_copy(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -137,6 +158,7 @@ async def api_copy(
@router.post("/upload/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
@require_path_permission(PathAction.WRITE, "full_path")
async def upload_stream(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -151,6 +173,7 @@ async def upload_stream(
@router.get("/{full_path:path}")
@audit(action=AuditAction.READ, description="浏览目录")
@require_path_permission(PathAction.READ, "full_path")
async def browse_fs(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -160,12 +183,15 @@ async def browse_fs(
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
data = await VirtualFSService.list_directory_with_permission(
full_path, current_user.id, page_num, page_size, sort_by, sort_order
)
return success(data)
@router.delete("/{full_path:path}")
@audit(action=AuditAction.DELETE, description="删除路径")
@require_path_permission(PathAction.DELETE, "full_path")
async def api_delete(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -185,5 +211,8 @@ async def root_listing(
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
# The root itself needs no permission check, but children the user cannot read must be filtered out
data = await VirtualFSService.list_directory_with_permission(
"/", current_user.id, page_num, page_size, sort_by, sort_order
)
return success(data)
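
The require_path_permission decorator itself is not shown in this hunk. Below is a minimal sketch of what a stackable guard could look like, assuming keyword handler arguments, a dict-like body, and the check_path_permission(user_id, path, action) call that appears later in this diff; the selector resolution, the body kwarg name, and the 403 message are assumptions, not code from the commit:

import functools
from fastapi import HTTPException
from domain.permission.service import PermissionService  # import shown elsewhere in this diff
from domain.permission.types import PathAction

def require_path_permission(action: PathAction, selector: str):
    # Hypothetical: resolve "full_path" or "body.<field>" from the handler kwargs.
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if selector.startswith("body."):
                body = kwargs.get("body") or {}
                path = body.get(selector.split(".", 1)[1], "")
            else:
                path = kwargs.get(selector, "")
            user = kwargs.get("current_user")
            if user is not None and not getattr(user, "is_admin", False):
                allowed = await PermissionService.check_path_permission(user.id, path, action)
                if not allowed:
                    raise HTTPException(403, detail=f"Missing {action} permission for {path}")
            return await func(*args, **kwargs)
        return wrapper
    return decorator

Stacked as on /move and /copy above, each decorator runs its own check, so a move requires DELETE on the source and WRITE on the destination before the handler executes.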

View File

@@ -4,13 +4,36 @@ from typing import Any, AsyncIterator, Union
from fastapi import HTTPException
from fastapi.responses import Response
from domain.tasks.service import TaskService
from domain.virtual_fs.thumbnail import is_raw_filename, raw_bytes_to_jpeg
from domain.tasks import TaskService
from .thumbnail import is_raw_filename, raw_bytes_to_jpeg
from .listing import VirtualFSListingMixin
class VirtualFSFileOpsMixin(VirtualFSListingMixin):
@classmethod
def _normalize_written_result(
cls,
original_path: str,
adapter_model: Any,
result: Any,
size_hint: int,
) -> tuple[str, int]:
final_path = original_path
size = size_hint
if isinstance(result, dict):
rel_override = result.get("rel")
if isinstance(rel_override, str) and rel_override:
final_path = cls._build_absolute_path(adapter_model.path, rel_override)
else:
path_override = result.get("path")
if isinstance(path_override, str) and path_override:
final_path = cls._normalize_path(path_override)
size_val = result.get("size")
if isinstance(size_val, int):
size = size_val
return final_path, size
@classmethod
async def read_file(cls, path: str) -> Union[bytes, Any]:
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
@@ -21,16 +44,18 @@ class VirtualFSFileOpsMixin(VirtualFSListingMixin):
@classmethod
async def write_file(cls, path: str, data: bytes):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(path)
if rel.endswith("/"):
raise HTTPException(400, detail="Invalid file path")
write_func = await cls._ensure_method(adapter_instance, "write_file")
await write_func(root, rel, data)
await TaskService.trigger_tasks("file_written", path)
result = await write_func(root, rel, data)
final_path, size = cls._normalize_written_result(path, adapter_model, result, len(data))
await TaskService.trigger_tasks("file_written", final_path)
return {"path": final_path, "size": size}
@classmethod
async def write_file_stream(cls, path: str, data_iter: AsyncIterator[bytes], overwrite: bool = True):
adapter_instance, _, root, rel = await cls.resolve_adapter_and_rel(path)
adapter_instance, adapter_model, root, rel = await cls.resolve_adapter_and_rel(path)
if rel.endswith("/"):
raise HTTPException(400, detail="Invalid file path")
exists_func = getattr(adapter_instance, "exists", None)
@@ -46,18 +71,23 @@ class VirtualFSFileOpsMixin(VirtualFSListingMixin):
size = 0
stream_func = getattr(adapter_instance, "write_file_stream", None)
if callable(stream_func):
size = await stream_func(root, rel, data_iter)
result = await stream_func(root, rel, data_iter)
if isinstance(result, dict):
size = int(result.get("size") or 0)
else:
size = int(result or 0)
else:
buf = bytearray()
async for chunk in data_iter:
if chunk:
buf.extend(chunk)
write_func = await cls._ensure_method(adapter_instance, "write_file")
await write_func(root, rel, bytes(buf))
result = await write_func(root, rel, bytes(buf))
size = len(buf)
await TaskService.trigger_tasks("file_written", path)
return size
final_path, size = cls._normalize_written_result(path, adapter_model, result, size)
await TaskService.trigger_tasks("file_written", final_path)
return {"path": final_path, "size": size}
@classmethod
async def make_dir(cls, path: str):
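
For reference, the adapter return shapes that _normalize_written_result above tolerates can be summarized as follows; this is an illustration, not code from the diff, and the /media mount is a made-up example:

result = None                           # legacy adapter: keep the caller's path and size hint
result = {"size": 1024}                 # size override, path unchanged
result = {"path": "/media/photo.jpg"}   # absolute path override
result = {"rel": "2026/02/photo.jpg"}   # relative to the adapter's mount; with a mount at
                                        # /media this normalizes to "/media/2026/02/photo.jpg"
                                        # via _build_absolute_path(adapter_model.path, rel)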

View File

@@ -3,9 +3,11 @@ from typing import Any, Dict, List, Tuple
from fastapi import HTTPException
from api.response import page
from domain.adapters.registry import runtime_registry
from domain.ai.service import VectorDBService, VECTOR_COLLECTION_NAME, FILE_COLLECTION_NAME
from domain.virtual_fs.thumbnail import is_image_filename, is_video_filename
from domain.adapters import runtime_registry
from domain.ai import FILE_COLLECTION_NAME, VECTOR_COLLECTION_NAME, VectorDBService
from domain.permission.service import PermissionService
from domain.permission.types import PathAction
from .thumbnail import is_image_filename, is_video_filename
from models import StorageAdapter
from .resolver import VirtualFSResolverMixin
@@ -225,7 +227,10 @@ class VirtualFSListingMixin(VirtualFSResolverMixin):
stat_func = getattr(adapter_instance, "stat_file", None)
if not callable(stat_func):
raise HTTPException(501, detail="Adapter does not implement stat_file")
info = await stat_func(root, rel)
try:
info = await stat_func(root, rel)
except FileNotFoundError as exc:
raise HTTPException(404, detail=str(exc))
if isinstance(info, dict):
info.setdefault("path", path)
@@ -242,3 +247,54 @@ class VirtualFSListingMixin(VirtualFSResolverMixin):
info["vector_index"] = vector_index
return info
@classmethod
async def list_virtual_dir_with_permission(
cls,
path: str,
user_id: int,
page_num: int = 1,
page_size: int = 50,
sort_by: str = "name",
sort_order: str = "asc",
) -> Dict:
"""
带权限过滤的目录列表
过滤掉用户没有读取权限的条目
"""
# First fetch the full directory listing
result = await cls.list_virtual_dir(path, page_num, page_size, sort_by, sort_order)
# Check whether the user is an admin (admins see everything)
from models.database import UserAccount
user = await UserAccount.get_or_none(id=user_id)
if user and user.is_admin:
return result
# Filter out entries the user cannot read
items = result.get("items", [])
if not items:
return result
norm = cls._normalize_path(path).rstrip("/") or "/"
filtered_items = []
for item in items:
item_name = item.get("name", "")
if norm == "/":
item_path = f"/{item_name}"
else:
item_path = f"{norm}/{item_name}"
# Check read permission for this entry
has_permission = await PermissionService.check_path_permission(
user_id, item_path, PathAction.READ
)
if has_permission:
filtered_items.append(item)
# Replace the page's items with the filtered set
result["items"] = filtered_items
return result
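
Note that this filter awaits one check_path_permission call per entry, so a 50-item page costs up to 50 sequential awaits. A hedged alternative, assuming check_path_permission is safe to run concurrently, batches the checks:

import asyncio

# Sketch: replace the sequential loop above with concurrent permission checks.
paths = [
    f"/{it.get('name', '')}" if norm == "/" else f"{norm}/{it.get('name', '')}"
    for it in items
]
grants = await asyncio.gather(*(
    PermissionService.check_path_permission(user_id, p, PathAction.READ)
    for p in paths
))
result["items"] = [it for it, ok in zip(items, grants) if ok]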

View File

@@ -0,0 +1 @@
__all__: list[str] = []

View File

@@ -15,8 +15,8 @@ from fastapi import APIRouter, Request, Response
from fastapi import HTTPException
from domain.audit import AuditAction, audit
from domain.config.service import ConfigService
from domain.virtual_fs.service import VirtualFSService
from domain.config import ConfigService
from domain.virtual_fs import VirtualFSService
router = APIRouter(prefix="/s3", tags=["s3"])

View File

@@ -9,10 +9,11 @@ from fastapi import APIRouter, Request, Response, HTTPException, Depends
import xml.etree.ElementTree as ET
from domain.audit import AuditAction, audit
from domain.auth.service import AuthService
from domain.auth.types import User, UserInDB
from domain.virtual_fs.service import VirtualFSService
from domain.config.service import ConfigService
from domain.auth import AuthService, User, UserInDB
from domain.config import ConfigService
from domain.permission.service import PermissionService
from domain.permission.types import PathAction
from domain.virtual_fs import VirtualFSService
_WEBDAV_ENABLED_KEY = "WEBDAV_MAPPING_ENABLED"
@@ -66,11 +67,26 @@ async def _get_basic_user(request: Request) -> User:
if not user_or_false:
raise HTTPException(401, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic realm=webdav"})
u: UserInDB = user_or_false
return User(id=u.id, username=u.username, email=u.email, full_name=u.full_name, disabled=u.disabled)
return User(
id=u.id,
username=u.username,
email=u.email,
full_name=u.full_name,
disabled=u.disabled,
is_admin=u.is_admin,
)
elif scheme_lower == "bearer":
if not param:
raise HTTPException(401, detail="Invalid Bearer token")
return User(id=0, username="bearer", email=None, full_name=None, disabled=False)
u = await AuthService.get_current_user(param)
return User(
id=u.id,
username=u.username,
email=u.email,
full_name=u.full_name,
disabled=u.disabled,
is_admin=u.is_admin,
)
else:
raise HTTPException(401, detail="Unsupported auth", headers={"WWW-Authenticate": "Basic realm=webdav"})
@@ -156,6 +172,8 @@ async def propfind(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
if full_path != "/":
await PermissionService.require_path_permission(user.id, full_path, PathAction.READ)
depth = request.headers.get("Depth", "1").lower()
if depth not in ("0", "1", "infinity"):
depth = "1"
@@ -172,12 +190,34 @@ async def propfind(
ctype = None if is_dir else (mimetypes.guess_type(name)[0] or "application/octet-stream")
responses.append(_build_prop_response(full_path, name, is_dir, size, mtime, ctype))
except FileNotFoundError:
raise HTTPException(404, detail="Not found")
st = None
except HTTPException as e:
if e.status_code != 404:
raise
st = None
if st is None:
is_mount_root = False
try:
_, rel = await VirtualFSService.resolve_adapter_by_path(full_path)
is_mount_root = rel == ""
except HTTPException:
is_mount_root = False
if not is_mount_root and full_path != "/":
listing_probe = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1)
if not (listing_probe.get("items") or []):
raise HTTPException(404, detail="Not found")
name = "/" if full_path == "/" else (full_path.rstrip("/").rsplit("/", 1)[-1] or "/")
responses.append(_build_prop_response(full_path, name, True, None, 0, None))
if depth in ("1", "infinity"):
try:
listing = await VirtualFSService.list_virtual_dir(full_path, page_num=1, page_size=1000)
for ent in listing["items"]:
listing = await VirtualFSService.list_virtual_dir_with_permission(
full_path, user.id, page_num=1, page_size=1000
)
for ent in (listing.get("items") or []):
is_dir = bool(ent.get("is_dir"))
name = ent.get("name")
child_path = full_path.rstrip("/") + "/" + name
@@ -204,6 +244,8 @@ async def dav_get(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
if full_path != "/":
await PermissionService.require_path_permission(user.id, full_path, PathAction.READ)
range_header = request.headers.get("Range")
return await VirtualFSService.stream_file(full_path, range_header)
@@ -217,6 +259,8 @@ async def dav_head(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
if full_path != "/":
await PermissionService.require_path_permission(user.id, full_path, PathAction.READ)
try:
st = await VirtualFSService.stat_file(full_path)
except FileNotFoundError:
@@ -245,6 +289,7 @@ async def dav_put(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await PermissionService.require_path_permission(user.id, full_path, PathAction.WRITE)
async def body_iter():
async for chunk in request.stream():
if chunk:
@@ -262,6 +307,7 @@ async def dav_delete(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await PermissionService.require_path_permission(user.id, full_path, PathAction.DELETE)
await VirtualFSService.delete_path(full_path)
return Response(status_code=204, headers=_dav_headers())
@@ -275,6 +321,7 @@ async def dav_mkcol(
user: User = Depends(_get_basic_user),
):
full_path = _normalize_fs_path(path)
await PermissionService.require_path_permission(user.id, full_path, PathAction.WRITE)
await VirtualFSService.make_dir(full_path)
return Response(status_code=201, headers=_dav_headers())
@@ -303,6 +350,8 @@ async def dav_move(
dest_header = request.headers.get("Destination")
dst = _parse_destination(dest_header or "")
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
await PermissionService.require_path_permission(user.id, full_src, PathAction.DELETE)
await PermissionService.require_path_permission(user.id, dst, PathAction.WRITE)
await VirtualFSService.move_path(full_src, dst, overwrite=overwrite)
return Response(status_code=204, headers=_dav_headers())
@@ -319,5 +368,7 @@ async def dav_copy(
dest_header = request.headers.get("Destination")
dst = _parse_destination(dest_header or "")
overwrite = request.headers.get("Overwrite", "T").upper() != "F"
await PermissionService.require_path_permission(user.id, full_src, PathAction.READ)
await PermissionService.require_path_permission(user.id, dst, PathAction.WRITE)
await VirtualFSService.copy_path(full_src, dst, overwrite=overwrite)
return Response(status_code=201 if not overwrite else 204, headers=_dav_headers())
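
Taken together, the checks added in this file map each WebDAV verb onto the PathAction values used in this diff. The summary below is a restatement for reference, not code from the commit (the dict name is invented):

DAV_PERMISSIONS = {
    "PROPFIND": [PathAction.READ],
    "GET":      [PathAction.READ],
    "HEAD":     [PathAction.READ],
    "PUT":      [PathAction.WRITE],
    "MKCOL":    [PathAction.WRITE],
    "DELETE":   [PathAction.DELETE],
    "MOVE":     [PathAction.DELETE, PathAction.WRITE],  # DELETE on src, WRITE on dst
    "COPY":     [PathAction.READ, PathAction.WRITE],    # READ on src, WRITE on dst
}

The root path "/" is exempted for the read-only verbs, matching the root-listing exemption in the fs router earlier in this diff.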

View File

@@ -16,7 +16,7 @@ class VirtualFSProcessingMixin(VirtualFSTransferMixin):
save_to: str | None = None,
overwrite: bool = False,
) -> Any:
from domain.processors.service import get_processor
from domain.processors import get_processor
processor = get_processor(processor_type)
if not processor:

View File

@@ -3,7 +3,7 @@ from typing import Tuple
from fastapi import HTTPException
from fastapi.responses import Response
from domain.adapters.registry import runtime_registry
from domain.adapters import runtime_registry
from models import StorageAdapter
from .common import VirtualFSCommonMixin

View File

@@ -1,11 +1,13 @@
import mimetypes
import re
from urllib.parse import quote
from fastapi import HTTPException, UploadFile
from fastapi.responses import Response
from domain.config.service import ConfigService
from domain.virtual_fs.thumbnail import (
from domain.config import ConfigService
from domain.tasks import TaskService
from .thumbnail import (
get_or_create_thumb,
is_image_filename,
is_raw_filename,
@@ -112,12 +114,14 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
async def create_temp_link(cls, full_path: str, expires_in: int):
full_path = cls._normalize_path(full_path)
token = await cls.generate_temp_link_token(full_path, expires_in=expires_in)
filename = full_path.rstrip("/").split("/")[-1]
filename_part = f"/{quote(filename, safe='')}" if filename else ""
file_domain = await ConfigService.get("FILE_DOMAIN")
if file_domain:
file_domain = file_domain.rstrip("/")
url = f"{file_domain}/api/fs/public/{token}"
url = f"{file_domain}/api/fs/public/{token}{filename_part}"
else:
url = f"/api/fs/public/{token}"
url = f"/api/fs/public/{token}{filename_part}"
return {"token": token, "path": full_path, "url": url}
@classmethod
@@ -128,12 +132,17 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
raise exc
try:
return await cls.stream_file(path, range_header)
response = await cls.stream_file(path, range_header)
except FileNotFoundError:
raise HTTPException(404, detail="File not found via token")
except Exception as exc:
raise HTTPException(500, detail=f"File access error: {exc}")
filename = path.rstrip("/").split("/")[-1]
if filename and not response.headers.get("Content-Disposition"):
response.headers["Content-Disposition"] = f"inline; filename*=UTF-8''{quote(filename, safe='')}"
return response
@classmethod
async def stat(cls, full_path: str):
full_path = cls._normalize_path(full_path)
@@ -142,8 +151,15 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
@classmethod
async def write_uploaded_file(cls, full_path: str, data: bytes):
full_path = cls._normalize_path(full_path)
await cls.write_file(full_path, data)
return {"written": True, "path": full_path, "size": len(data)}
result = await cls.write_file(full_path, data)
path = full_path
size = len(data)
if isinstance(result, dict):
path = result.get("path") or path
size_val = result.get("size")
if isinstance(size_val, int):
size = size_val
return {"written": True, "path": path, "size": size}
@classmethod
async def mkdir(cls, path: str):
@@ -201,7 +217,7 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
full_path = cls._normalize_path(full_path)
if full_path.endswith("/"):
raise HTTPException(400, detail="Path must be a file")
adapter, _m, root, rel = await cls.resolve_adapter_and_rel(full_path)
adapter, adapter_model, root, rel = await cls.resolve_adapter_and_rel(full_path)
exists_func = getattr(adapter, "exists", None)
if not overwrite and callable(exists_func):
try:
@@ -212,6 +228,21 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
except Exception:
pass
upload_func = getattr(adapter, "write_upload_file", None)
if callable(upload_func):
try:
await file.seek(0)
except Exception:
pass
size_hint = getattr(file, "size", None)
if not isinstance(size_hint, int):
size_hint = None
filename = file.filename or (rel.rsplit("/", 1)[-1] if rel else "file")
result = await upload_func(root, rel, file.file, filename, size_hint, file.content_type)
final_path, size = cls._normalize_written_result(full_path, adapter_model, result, size_hint or 0)
await TaskService.trigger_tasks("file_written", final_path)
return {"uploaded": True, "path": final_path, "size": size, "overwrite": overwrite}
async def gen():
while True:
chunk = await file.read(chunk_size)
@@ -219,8 +250,17 @@ class VirtualFSRouteMixin(VirtualFSTempLinkMixin):
break
yield chunk
size = await cls.write_file_stream(full_path, gen(), overwrite=overwrite)
return {"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite}
result = await cls.write_file_stream(full_path, gen(), overwrite=overwrite)
path = full_path
size = 0
if isinstance(result, dict):
path = result.get("path") or path
size_val = result.get("size")
if isinstance(size_val, int):
size = size_val
else:
size = int(result or 0)
return {"uploaded": True, "path": path, "size": size, "overwrite": overwrite}
@classmethod
async def list_directory(cls, full_path: str, page_num: int, page_size: int, sort_by: str, sort_order: str):
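
A small illustration of the filename handling added in this file: the name is percent-encoded into the public URL, and access_public_file falls back to an RFC 5987 Content-Disposition when the streamed response lacks one (the token value below is hypothetical):

from urllib.parse import quote

token = "abc123"
filename = "annual report.pdf"
url = f"/api/fs/public/{token}/{quote(filename, safe='')}"
# -> /api/fs/public/abc123/annual%20report.pdf
disposition = f"inline; filename*=UTF-8''{quote(filename, safe='')}"
# -> inline; filename*=UTF-8''annual%20report.pdf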

View File

@@ -0,0 +1,3 @@
from .search_service import VirtualFSSearchService
__all__ = ["VirtualFSSearchService"]

Some files were not shown because too many files have changed in this diff.