mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-08 19:22:42 +08:00
Compare commits
206 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
984b7a74ae | ||
|
|
97a3c58f0f | ||
|
|
451e8555d5 | ||
|
|
f444ec46cc | ||
|
|
103beb7dad | ||
|
|
c5e4b3ef43 | ||
|
|
4014a4dd74 | ||
|
|
d0c6e1882f | ||
|
|
434715fc8b | ||
|
|
a127987a3f | ||
|
|
edf95e897d | ||
|
|
b72f8152b6 | ||
|
|
aacddb1208 | ||
|
|
1d6d793f7a | ||
|
|
d9d2ddf2d1 | ||
|
|
e6ab01ef9d | ||
|
|
4a2e01196d | ||
|
|
f22ca62902 | ||
|
|
a394ffa46b | ||
|
|
d003e53a3a | ||
|
|
060a427fe4 | ||
|
|
f4c18f991f | ||
|
|
58c2cdd440 | ||
|
|
7d861ca5f7 | ||
|
|
52bac11760 | ||
|
|
c441d8776f | ||
|
|
45e0194465 | ||
|
|
540065f195 | ||
|
|
4f86e2da4d | ||
|
|
31d347d24f | ||
|
|
7a9a20509c | ||
|
|
373b6410c2 | ||
|
|
d6eb6e1605 | ||
|
|
1d66fb56c8 | ||
|
|
bb9589fa62 | ||
|
|
ab89451b2d | ||
|
|
3e1b75d81a | ||
|
|
1679b03d3a | ||
|
|
ab6562fc79 | ||
|
|
87770176b6 | ||
|
|
e7cf8dbdb8 | ||
|
|
e7eafdee97 | ||
|
|
051b49d3f6 | ||
|
|
b059b0eb44 | ||
|
|
59ad2cb622 | ||
|
|
6b2ada0b42 | ||
|
|
a727e77341 | ||
|
|
4638356a45 | ||
|
|
e51344b43e | ||
|
|
b7685db0e8 | ||
|
|
4e16de973c | ||
|
|
4dd0a4b1d6 | ||
|
|
5703825c31 | ||
|
|
24255744df | ||
|
|
31d97b2968 | ||
|
|
35abd080be | ||
|
|
2fa93a1eeb | ||
|
|
ff7eb13187 | ||
|
|
ed9090c3d0 | ||
|
|
d430254868 | ||
|
|
a8870f80da | ||
|
|
14ef2a4ccc | ||
|
|
dd41941b04 | ||
|
|
01a259bae0 | ||
|
|
ef5ef2730c | ||
|
|
8b8772b064 | ||
|
|
5393a973eb | ||
|
|
cc1f130099 | ||
|
|
c8b3817805 | ||
|
|
b1ea181f96 | ||
|
|
078709b871 | ||
|
|
d788bde44f | ||
|
|
28ede26801 | ||
|
|
53130383c1 | ||
|
|
036eeb92c2 | ||
|
|
5701a13f4f | ||
|
|
184997deed | ||
|
|
1d5824d498 | ||
|
|
91ff1860b7 | ||
|
|
56f947d0bf | ||
|
|
ad016baaf9 | ||
|
|
ad2e2858da | ||
|
|
a69d6c21a6 | ||
|
|
2a4a3c44b9 | ||
|
|
cdb8543370 | ||
|
|
2dabe9255f | ||
|
|
239216e574 | ||
|
|
09c65bffb7 | ||
|
|
ff1c06ad18 | ||
|
|
d88e95a9af | ||
|
|
ae80a751a8 | ||
|
|
b40e700a64 | ||
|
|
040d8346b3 | ||
|
|
55d062f0a7 | ||
|
|
cfaaff8a8c | ||
|
|
d6d41333fd | ||
|
|
a4efba94d5 | ||
|
|
00e6419b12 | ||
|
|
bbe8465aa0 | ||
|
|
baadaa70a7 | ||
|
|
e7e34cda54 | ||
|
|
adb80d0a6c | ||
|
|
bcd4ae7aef | ||
|
|
1ef80a087c | ||
|
|
f503d521e6 | ||
|
|
7c38c0045b | ||
|
|
b582a89d08 | ||
|
|
4ea0b9884a | ||
|
|
dfeec58ed9 | ||
|
|
e2f0037053 | ||
|
|
e34ee6f70d | ||
|
|
0f856bb5b7 | ||
|
|
3b4b01a18d | ||
|
|
2e1f76d0bc | ||
|
|
18ed7dcee1 | ||
|
|
5c3ab65cee | ||
|
|
1ddd2e464c | ||
|
|
aeb7cf75a1 | ||
|
|
648fd51d26 | ||
|
|
98c7b3af9b | ||
|
|
fc3b6a9d70 | ||
|
|
1c0fc24cfa | ||
|
|
5127d9f0fc | ||
|
|
ba1feb150b | ||
|
|
6a1ff3afa6 | ||
|
|
724f551b00 | ||
|
|
8cf147bf34 | ||
|
|
c2a473fac9 | ||
|
|
aaae37e7cb | ||
|
|
78de3b46be | ||
|
|
388ddfd869 | ||
|
|
18f59f8d33 | ||
|
|
b319b545fc | ||
|
|
0fcb3b8ce0 | ||
|
|
686202a0dd | ||
|
|
1cda987723 | ||
|
|
49a4300fc3 | ||
|
|
d7260e8863 | ||
|
|
62d0316d48 | ||
|
|
fc85f21aaa | ||
|
|
16283dea09 | ||
|
|
055c240079 | ||
|
|
12a3bb8efc | ||
|
|
050577cf62 | ||
|
|
394c2f7229 | ||
|
|
8f515aaaf4 | ||
|
|
cf8d10f71c | ||
|
|
5c4d3a625b | ||
|
|
f0a51c3369 | ||
|
|
3278896d4b | ||
|
|
219f3e81b8 | ||
|
|
8ef0a34642 | ||
|
|
8aaa2900ef | ||
|
|
e3e68f5397 | ||
|
|
78dfbac458 | ||
|
|
583db651a7 | ||
|
|
3a15362422 | ||
|
|
e55a09d84f | ||
|
|
8957174e6f | ||
|
|
abb6b0ce22 | ||
|
|
74df438053 | ||
|
|
f271a8bee5 | ||
|
|
17236e601f | ||
|
|
71e5f84eb7 | ||
|
|
4e724b9c4a | ||
|
|
ba62bd0d4a | ||
|
|
138296e5a6 | ||
|
|
51326dea08 | ||
|
|
ac6d8ff7ad | ||
|
|
029aa2574d | ||
|
|
eeb0e6aa70 | ||
|
|
d1ceb7ddba | ||
|
|
63b54458e9 | ||
|
|
f7e6815265 | ||
|
|
4d6e0b86ad | ||
|
|
77a4749fec | ||
|
|
8eaa025f7e | ||
|
|
11799cd97c | ||
|
|
c14224827d | ||
|
|
130a304f25 | ||
|
|
bc595310a6 | ||
|
|
bf83187d8c | ||
|
|
02cc31d296 | ||
|
|
c66ca181c6 | ||
|
|
5815e6a545 | ||
|
|
7cf335ab19 | ||
|
|
36365d7410 | ||
|
|
90ddeef027 | ||
|
|
8ac3acebb4 | ||
|
|
5625f2d8bf | ||
|
|
7f33eb85ba | ||
|
|
0da64b8d9c | ||
|
|
7caa602d93 | ||
|
|
a4af9475ef | ||
|
|
ee6e570ccb | ||
|
|
ce45fca8bd | ||
|
|
77058f3535 | ||
|
|
738f3c9718 | ||
|
|
f3d9220569 | ||
|
|
da41393db3 | ||
|
|
0399011406 | ||
|
|
00462f2259 | ||
|
|
f0892ebcd6 | ||
|
|
cf5f19043b | ||
|
|
6444ed264c | ||
|
|
bed8c8b19c |
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
custom: https://foxel.cc/sponsor
|
||||
75
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
75
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
name: Bug Report / 缺陷报告
|
||||
description: Report reproducible defects with clear context / 请提供可复现的缺陷信息
|
||||
title: "[Bug] "
|
||||
labels:
|
||||
- bug
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for helping us improve Foxel! / 感谢你帮助改进 Foxel!
|
||||
Please confirm the checklist below before filing. / 在提交前请确认以下事项。
|
||||
- type: checkboxes
|
||||
id: validations
|
||||
attributes:
|
||||
label: Pre-flight Check / 提交前检查
|
||||
options:
|
||||
- label: I searched existing issues and docs / 我已搜索现有 Issue 与文档
|
||||
required: true
|
||||
- label: This is not a question or feature request / 这不是问题咨询或功能需求
|
||||
required: true
|
||||
- type: textarea
|
||||
id: summary
|
||||
attributes:
|
||||
label: Bug Summary / 缺陷摘要
|
||||
description: Briefly describe what is wrong / 简要说明出现了什么问题
|
||||
placeholder: e.g. Upload fails with 500 error / 例如:上传时报 500 错误
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to Reproduce / 复现步骤
|
||||
description: List numbered steps to trigger the bug / 列出触发问题的步骤
|
||||
placeholder: |
|
||||
1. ...
|
||||
2. ...
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: expected
|
||||
attributes:
|
||||
label: Expected Behavior / 预期行为
|
||||
description: What should happen instead? / 期望看到什么结果?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: actual
|
||||
attributes:
|
||||
label: Actual Behavior / 实际行为
|
||||
description: What actually happens? Include messages or screenshots / 实际发生了什么?可附报错或截图
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Version / 版本信息
|
||||
description: Git commit, tag, or build number / 提供 Git 提交、标签或构建号
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Environment / 运行环境
|
||||
description: OS, browser, API server config, etc. / 操作系统、浏览器、服务端配置等
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs & Attachments / 日志与附件
|
||||
description: Paste relevant logs, stack traces, screenshots / 粘贴相关日志、堆栈或截图
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
56
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
56
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Feature Request / 功能需求
|
||||
description: Suggest enhancements or new capabilities / 提出改进或新增能力
|
||||
title: "[Feature] "
|
||||
labels:
|
||||
- enhancement
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Tell us about your idea! / 欢迎分享你的想法!
|
||||
Please complete the sections below so we can evaluate it quickly. / 请完整填写以下信息,便于快速评估。
|
||||
- type: checkboxes
|
||||
id: prechecks
|
||||
attributes:
|
||||
label: Pre-flight Check / 提交前检查
|
||||
options:
|
||||
- label: I searched existing issues and roadmap / 我已搜索现有 Issue 与路线图
|
||||
required: true
|
||||
- label: This is not a bug report or question / 这不是缺陷或问题咨询
|
||||
required: true
|
||||
- type: textarea
|
||||
id: summary
|
||||
attributes:
|
||||
label: Feature Summary / 功能概述
|
||||
description: What do you want to build? / 希望新增什么能力?
|
||||
placeholder: e.g. Support sharing download links / 例如:支持分享下载链接
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation / 背景与价值
|
||||
description: Why is this feature important? Who benefits? / 为什么重要?受益者是谁?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: scope
|
||||
attributes:
|
||||
label: Proposed Solution / 建议方案
|
||||
description: Outline how the feature might work, including API or UI hints / 描述可能的实现方式,包含 API 或 UI 提示
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives / 可选方案
|
||||
description: List any alternatives considered / 如有考虑过其他方案请列出
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: extra
|
||||
attributes:
|
||||
label: Additional Context / 补充信息
|
||||
description: Diagrams, sketches, links, constraints, etc. / 可附上草图、链接或约束
|
||||
validations:
|
||||
required: false
|
||||
42
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
42
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Question / 问题咨询
|
||||
description: Ask about usage, configuration, or clarification / 用于使用、配置或澄清问题
|
||||
title: "[Question] "
|
||||
labels:
|
||||
- question
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Need help? You're in the right place. / 需要帮助?请按以下提示填写。
|
||||
Check the docs before filing. / 提交前请先查阅文档。
|
||||
- type: checkboxes
|
||||
id: prechecks
|
||||
attributes:
|
||||
label: Pre-flight Check / 提交前检查
|
||||
options:
|
||||
- label: I searched existing issues and discussions / 我已搜索现有 Issue 和讨论
|
||||
required: true
|
||||
- label: I read the relevant documentation / 我已阅读相关文档
|
||||
required: true
|
||||
- type: textarea
|
||||
id: question
|
||||
attributes:
|
||||
label: Question Details / 问题详情
|
||||
description: What do you need help with? Be specific. / 具体说明需要帮助的内容
|
||||
placeholder: Describe the scenario, expectation, and blockers / 说明场景、期望结果与阻碍
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: tried
|
||||
attributes:
|
||||
label: What You Tried / 已尝试方案
|
||||
description: List commands, configs, or steps attempted / 列出尝试过的命令、配置或步骤
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context / 补充信息
|
||||
description: Environment details, logs, screenshots / 可补充运行环境、日志或截图
|
||||
validations:
|
||||
required: false
|
||||
16
.github/dependabot.yml
vendored
Normal file
16
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
|
||||
- package-ecosystem: "bun"
|
||||
directory: "/web"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
51
.github/workflows/docker-clean.yml
vendored
Normal file
51
.github/workflows/docker-clean.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: Clean dangling Docker images
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
docker-clean:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Delete untagged GHCR versions
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
OWNER="${GITHUB_REPOSITORY_OWNER}"
|
||||
PACKAGE="$(echo "${GITHUB_REPOSITORY##*/}" | tr '[:upper:]' '[:lower:]')"
|
||||
|
||||
OWNER_TYPE="$(gh api "/users/${OWNER}" -q '.type')"
|
||||
if [[ "${OWNER_TYPE}" == "Organization" ]]; then
|
||||
SCOPE="orgs/${OWNER}"
|
||||
else
|
||||
SCOPE="users/${OWNER}"
|
||||
fi
|
||||
|
||||
BASE_PATH="/${SCOPE}/packages/container/${PACKAGE}"
|
||||
|
||||
if ! gh api "${BASE_PATH}" >/dev/null 2>&1; then
|
||||
echo "Package ghcr.io/${OWNER}/${PACKAGE} not found or accessible. Nothing to clean."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
mapfile -t VERSION_IDS < <(gh api --paginate "${BASE_PATH}/versions?per_page=100" \
|
||||
-q '.[] | select(.metadata.container.tags | length == 0) | .id')
|
||||
|
||||
if [[ ${#VERSION_IDS[@]} -eq 0 ]]; then
|
||||
echo "No untagged versions to delete."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Deleting ${#VERSION_IDS[@]} untagged versions from ghcr.io/${OWNER}/${PACKAGE}..."
|
||||
for id in "${VERSION_IDS[@]}"; do
|
||||
gh api -X DELETE "${BASE_PATH}/versions/${id}" >/dev/null
|
||||
echo "Deleted version ${id}"
|
||||
done
|
||||
|
||||
echo "Cleanup complete."
|
||||
8
.github/workflows/docker.yml
vendored
8
.github/workflows/docker.yml
vendored
@@ -2,6 +2,8 @@ name: Build and Push Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
@@ -15,7 +17,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -43,9 +45,9 @@ jobs:
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push Docker image (multi arch)
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
|
||||
2
.github/workflows/release-drafter.yml
vendored
2
.github/workflows/release-drafter.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: release-drafter/release-drafter@v5
|
||||
- uses: release-drafter/release-drafter@v6
|
||||
with:
|
||||
config-name: release-drafter.yml
|
||||
env:
|
||||
|
||||
28
.gitignore
vendored
28
.gitignore
vendored
@@ -5,6 +5,30 @@ __pycache__/
|
||||
.venv/
|
||||
.vscode/
|
||||
data/
|
||||
migrate/
|
||||
.env
|
||||
AGENTS.md
|
||||
AGENTS.md
|
||||
|
||||
# Logs
|
||||
/web/logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
/web/node_modules
|
||||
/web/dist
|
||||
/web/dist-ssr
|
||||
/web/*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
@@ -1 +1 @@
|
||||
3.13
|
||||
3.14
|
||||
|
||||
@@ -137,8 +137,8 @@ Install the following tooling first:
|
||||
|
||||
Storage adapters integrate new storage providers (for example S3, FTP, or Alist).
|
||||
|
||||
1. Create a new module under [`services/adapters/`](services/adapters/) (for example `my_new_adapter.py`).
|
||||
2. Implement a class that inherits from [`services.adapters.base.BaseAdapter`](services/adapters/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
|
||||
1. Create a new module under [`domain/adapters/providers/`](domain/adapters/providers/) (for example `my_new_adapter.py`).
|
||||
2. Implement a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
|
||||
|
||||
### Frontend Apps
|
||||
|
||||
|
||||
@@ -143,9 +143,9 @@
|
||||
|
||||
存储适配器是 Foxel 的核心扩展点,用于接入不同的存储后端 (如 S3, FTP, Alist 等)。
|
||||
|
||||
1. **创建适配器文件**: 在 [`services/adapters/`](services/adapters/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
|
||||
1. **创建适配器文件**: 在 [`domain/adapters/providers/`](domain/adapters/providers/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
|
||||
2. **实现适配器类**:
|
||||
- 创建一个类,继承自 [`services.adapters.base.BaseAdapter`](services/adapters/base.py)。
|
||||
- 创建一个类,继承自 [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py)。
|
||||
- 实现 `BaseAdapter` 中定义的所有抽象方法,如 `list_dir`, `get_meta`, `upload`, `download` 等。请仔细阅读基类中的文档注释以理解每个方法的作用和参数。
|
||||
|
||||
### 贡献前端应用 (App)
|
||||
|
||||
23
Dockerfile
23
Dockerfile
@@ -9,30 +9,37 @@ COPY web/ ./
|
||||
|
||||
RUN bun run build
|
||||
|
||||
FROM python:3.13-slim
|
||||
FROM python:3.14-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y nginx git && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ffmpeg curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install uv
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv pip install --system . gunicorn
|
||||
RUN uv pip install --system . gunicorn \
|
||||
&& rm -rf /root/.cache
|
||||
|
||||
RUN git clone https://github.com/DrizzleTime/FoxelUpgrade /app/migrate
|
||||
RUN curl -L https://github.com/DrizzleTime/FoxelUpgrade/archive/refs/heads/main.tar.gz -o /tmp/migrate.tgz \
|
||||
&& mkdir -p /app/migrate \
|
||||
&& tar -xzf /tmp/migrate.tgz --strip-components=1 -C /app/migrate \
|
||||
&& rm -rf /tmp/migrate.tgz
|
||||
|
||||
COPY --from=frontend-builder /app/web/dist /app/web/dist
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY nginx.conf /etc/nginx/nginx.conf
|
||||
|
||||
RUN mkdir -p data/db data/mount && \
|
||||
chmod 777 data/db data/mount
|
||||
chmod 777 data/db data/mount && \
|
||||
chmod +x setup/foxel_cli.py && \
|
||||
ln -sf /app/setup/foxel_cli.py /usr/local/bin/foxel && \
|
||||
rm -rf /var/log/apt /var/cache/apt/archives
|
||||
|
||||
EXPOSE 80
|
||||
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
CMD ["/entrypoint.sh"]
|
||||
CMD ["/entrypoint.sh"]
|
||||
|
||||
45
README.md
45
README.md
@@ -8,15 +8,17 @@
|
||||
|
||||
**A highly extensible private cloud storage solution for individuals and teams, featuring AI-powered semantic search.**
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||

|
||||
|
||||
---
|
||||
<blockquote>
|
||||
<em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
|
||||
</blockquote>
|
||||
<img src="https://foxel.cc/image/ad-min-en.png" alt="UI Screenshot">
|
||||
</div>
|
||||
|
||||
## 👀 Online Demo
|
||||
@@ -38,36 +40,37 @@
|
||||
|
||||
Using Docker Compose is the most recommended way to start Foxel.
|
||||
|
||||
1. **Create Data Directories**:
|
||||
Create a `data` folder for persistent data:
|
||||
1. **Create Data Directories**
|
||||
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
mkdir -p data/mount
|
||||
chmod 777 data/db data/mount
|
||||
```
|
||||
Create a `data` folder for persistent data:
|
||||
|
||||
2. **Download Docker Compose File**:
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
mkdir -p data/mount
|
||||
chmod 777 data/db data/mount
|
||||
```
|
||||
|
||||
```bash
|
||||
curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
|
||||
```
|
||||
2. **Download Docker Compose File**
|
||||
|
||||
After downloading, it is **strongly recommended** to modify the environment variables in the `compose.yaml` file to ensure security:
|
||||
```bash
|
||||
curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
|
||||
```
|
||||
|
||||
- Modify `SECRET_KEY` and `TEMP_LINK_SECRET_KEY`: Replace the default keys with randomly generated strong keys.
|
||||
After downloading, it is **strongly recommended** to modify the environment variables in the `compose.yaml` file to ensure security:
|
||||
|
||||
3. **Start the Services**:
|
||||
- Modify `SECRET_KEY` and `TEMP_LINK_SECRET_KEY`: Replace the default keys with randomly generated strong keys.
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
3. **Start the Services**
|
||||
|
||||
4. **Access the Application**:
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
Once the services are running, open the page in your browser.
|
||||
4. **Access the Application**
|
||||
|
||||
> On the first launch, please follow the setup guide to initialize the administrator account.
|
||||
Once the services are running, open the page in your browser.
|
||||
|
||||
> On the first launch, please follow the setup guide to initialize the administrator account.
|
||||
|
||||
## 🤝 How to Contribute
|
||||
|
||||
|
||||
46
README_zh.md
46
README_zh.md
@@ -8,16 +8,17 @@
|
||||
|
||||
**一个面向个人和团队的、高度可扩展的私有云盘解决方案,支持 AI 语义搜索。**
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||

|
||||
|
||||
---
|
||||
<blockquote>
|
||||
<em><strong>数据之洋浩瀚无涯,当以洞察之目引航,然其脉络深隐,非表象所能尽窥。</strong></em><br>
|
||||
<em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
|
||||
</blockquote>
|
||||
<img src="https://foxel.cc/image/ad-min-zh.png" alt="UI Screenshot">
|
||||
</div>
|
||||
|
||||
## 👀 在线体验
|
||||
@@ -39,36 +40,37 @@
|
||||
|
||||
使用 Docker Compose 是启动 Foxel 最推荐的方式。
|
||||
|
||||
1. **创建数据目录**:
|
||||
新建 `data` 文件夹用于持久化数据:
|
||||
1. **创建数据目录**
|
||||
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
mkdir -p data/mount
|
||||
chmod 777 data/db data/mount
|
||||
```
|
||||
新建 `data` 文件夹用于持久化数据:
|
||||
|
||||
2. **下载 Docker Compose 文件**:
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
mkdir -p data/mount
|
||||
chmod 777 data/db data/mount
|
||||
```
|
||||
|
||||
```bash
|
||||
curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
|
||||
```
|
||||
2. **下载 Docker Compose 文件**
|
||||
|
||||
下载完成后,**强烈建议**修改 `compose.yaml` 文件中的环境变量以确保安全:
|
||||
```bash
|
||||
curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
|
||||
```
|
||||
|
||||
- 修改 `SECRET_KEY` 和 `TEMP_LINK_SECRET_KEY`:将默认的密钥替换为随机生成的强密钥
|
||||
下载完成后,**强烈建议**修改 `compose.yaml` 文件中的环境变量以确保安全:
|
||||
|
||||
3. **启动服务**:
|
||||
- 修改 `SECRET_KEY` 和 `TEMP_LINK_SECRET_KEY`:将默认的密钥替换为随机生成的强密钥
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
3. **启动服务**
|
||||
|
||||
4. **访问应用**:
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
服务启动后,在浏览器中打开页面。
|
||||
4. **访问应用**
|
||||
|
||||
> 首次启动,请根据引导页面完成管理员账号的初始化设置。
|
||||
服务启动后,在浏览器中打开页面。
|
||||
|
||||
> 首次启动,请根据引导页面完成管理员账号的初始化设置。
|
||||
|
||||
## 🤝 如何贡献
|
||||
|
||||
|
||||
@@ -1,23 +1,46 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search, vector_db, offline_downloads
|
||||
from .routes import webdav
|
||||
from .routes import plugins
|
||||
from domain.adapters import api as adapters
|
||||
from domain.auth import api as auth
|
||||
from domain.backup import api as backup
|
||||
from domain.config import api as config
|
||||
from domain.email import api as email
|
||||
from domain.offline_downloads import api as offline_downloads
|
||||
from domain.plugins import api as plugins
|
||||
from domain.processors import api as processors
|
||||
from domain.share import api as share
|
||||
from domain.tasks import api as tasks
|
||||
from domain.ai import api as ai
|
||||
from domain.agent import api as agent
|
||||
from domain.virtual_fs import api as virtual_fs
|
||||
from domain.virtual_fs.mapping import s3_api, webdav_api
|
||||
from domain.virtual_fs.search import search_api
|
||||
from domain.audit import api as audit
|
||||
from domain.permission import api as permission
|
||||
from domain.user import api as user
|
||||
from domain.role import api as role
|
||||
|
||||
|
||||
def include_routers(app: FastAPI):
|
||||
app.include_router(adapters.router)
|
||||
app.include_router(search_api.router)
|
||||
app.include_router(virtual_fs.router)
|
||||
app.include_router(search.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(config.router)
|
||||
app.include_router(processors.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(logs.router)
|
||||
app.include_router(share.router)
|
||||
app.include_router(share.public_router)
|
||||
app.include_router(backup.router)
|
||||
app.include_router(vector_db.router)
|
||||
app.include_router(ai.router_vector_db)
|
||||
app.include_router(ai.router_ai)
|
||||
app.include_router(agent.router)
|
||||
app.include_router(plugins.router)
|
||||
app.include_router(webdav.router)
|
||||
app.include_router(webdav_api.router)
|
||||
app.include_router(s3_api.router)
|
||||
app.include_router(offline_downloads.router)
|
||||
app.include_router(email.router)
|
||||
app.include_router(audit.router)
|
||||
app.include_router(permission.router)
|
||||
app.include_router(user.router)
|
||||
app.include_router(role.router)
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from tortoise.transactions import in_transaction
|
||||
from typing import Annotated
|
||||
|
||||
from models import StorageAdapter
|
||||
from schemas import AdapterCreate, AdapterOut
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.adapters.registry import runtime_registry, get_config_schemas
|
||||
from api.response import success
|
||||
from services.logging import LogService
|
||||
|
||||
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
|
||||
|
||||
def validate_and_normalize_config(adapter_type: str, cfg):
|
||||
schemas = get_config_schemas()
|
||||
if not isinstance(cfg, dict):
|
||||
raise HTTPException(400, detail="config 必须是对象")
|
||||
schema = schemas.get(adapter_type)
|
||||
if not schema:
|
||||
raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
|
||||
out = {}
|
||||
missing = []
|
||||
for f in schema:
|
||||
k = f["key"]
|
||||
if k in cfg and cfg[k] not in (None, ""):
|
||||
out[k] = cfg[k]
|
||||
elif "default" in f:
|
||||
out[k] = f["default"]
|
||||
elif f.get("required"):
|
||||
missing.append(k)
|
||||
if missing:
|
||||
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
|
||||
return out
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_adapter(
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
exists = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if exists:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
adapter_fields = {
|
||||
"name": data.name,
|
||||
"type": data.type,
|
||||
"config": validate_and_normalize_config(data.type, data.config or {}),
|
||||
"enabled": data.enabled,
|
||||
"path": norm_path,
|
||||
"sub_path": data.sub_path,
|
||||
}
|
||||
|
||||
rec = await StorageAdapter.create(**adapter_fields)
|
||||
await runtime_registry.upsert(rec)
|
||||
await LogService.action(
|
||||
"route:adapters",
|
||||
f"Created adapter {rec.name}",
|
||||
details=adapter_fields,
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
return success(rec)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_adapters(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapters = await StorageAdapter.all()
|
||||
out = [AdapterOut.model_validate(a) for a in adapters]
|
||||
return success(out)
|
||||
|
||||
|
||||
@router.get("/available")
|
||||
async def available_adapter_types(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
data = []
|
||||
for t, fields in get_config_schemas().items():
|
||||
data.append({
|
||||
"type": t,
|
||||
"name": "本地文件系统" if t == "local" else ("WebDAV" if t == "webdav" else t),
|
||||
"config_schema": fields,
|
||||
})
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{adapter_id}")
|
||||
async def get_adapter(
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
return success(AdapterOut.model_validate(rec))
|
||||
|
||||
|
||||
@router.put("/{adapter_id}")
|
||||
async def update_adapter(
|
||||
adapter_id: int,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
existing = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if existing and existing.id != adapter_id:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
rec.name = data.name
|
||||
rec.type = data.type
|
||||
rec.config = validate_and_normalize_config(data.type, data.config or {})
|
||||
rec.enabled = data.enabled
|
||||
rec.path = norm_path
|
||||
rec.sub_path = data.sub_path
|
||||
await rec.save()
|
||||
|
||||
await runtime_registry.upsert(rec)
|
||||
await LogService.action(
|
||||
"route:adapters",
|
||||
f"Updated adapter {rec.name}",
|
||||
details=data.model_dump(),
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
return success(rec)
|
||||
|
||||
|
||||
@router.delete("/{adapter_id}")
|
||||
async def delete_adapter(
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
deleted = await StorageAdapter.filter(id=adapter_id).delete()
|
||||
if not deleted:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
runtime_registry.remove(adapter_id)
|
||||
await LogService.action(
|
||||
"route:adapters",
|
||||
f"Deleted adapter {adapter_id}",
|
||||
details={"adapter_id": adapter_id},
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
return success({"deleted": True})
|
||||
@@ -1,122 +0,0 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, HTTPException, Depends, Form
|
||||
import hashlib
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from services.auth import (
|
||||
authenticate_user_db,
|
||||
create_access_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
register_user,
|
||||
Token,
|
||||
get_current_active_user,
|
||||
User,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from datetime import timedelta
|
||||
from api.response import success
|
||||
from models.database import UserAccount
|
||||
from services.auth import verify_password, get_password_hash
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
|
||||
|
||||
@router.post("/register", summary="注册第一个管理员用户")
|
||||
async def register(data: RegisterRequest):
|
||||
"""
|
||||
仅当系统中没有用户时,才允许注册。
|
||||
"""
|
||||
user = await register_user(
|
||||
username=data.username,
|
||||
password=data.password,
|
||||
email=data.email,
|
||||
full_name=data.full_name,
|
||||
)
|
||||
return success({"username": user.username}, msg="初始用户注册成功")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login_for_access_token(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
user = await authenticate_user_db(form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = await create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
@router.get("/me", summary="获取当前登录用户信息")
|
||||
async def get_me(current_user: Annotated[User, Depends(get_current_active_user)]):
|
||||
"""
|
||||
返回当前登录用户的基本信息,并附带 gravatar 头像链接。
|
||||
"""
|
||||
email = (current_user.email or "").strip().lower()
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://www.gravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return success({
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"full_name": current_user.full_name,
|
||||
"gravatar_url": gravatar_url,
|
||||
})
|
||||
|
||||
|
||||
class UpdateMeRequest(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
old_password: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
@router.put("/me", summary="更新当前登录用户信息")
|
||||
async def update_me(
|
||||
payload: UpdateMeRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
db_user = await UserAccount.get_or_none(id=current_user.id)
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if payload.email is not None:
|
||||
exists = await UserAccount.filter(email=payload.email).exclude(id=db_user.id).exists()
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被占用")
|
||||
db_user.email = payload.email
|
||||
|
||||
if payload.full_name is not None:
|
||||
db_user.full_name = payload.full_name
|
||||
|
||||
if payload.new_password:
|
||||
if not payload.old_password:
|
||||
raise HTTPException(status_code=400, detail="请提供原密码")
|
||||
if not verify_password(payload.old_password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="原密码错误")
|
||||
db_user.hashed_password = get_password_hash(payload.new_password)
|
||||
|
||||
await db_user.save()
|
||||
|
||||
email = (db_user.email or "").strip().lower()
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return success({
|
||||
"id": db_user.id,
|
||||
"username": db_user.username,
|
||||
"email": db_user.email,
|
||||
"full_name": db_user.full_name,
|
||||
"gravatar_url": gravatar_url,
|
||||
})
|
||||
@@ -1,50 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from services.auth import get_current_active_user
|
||||
from services.backup import BackupService
|
||||
from models.database import UserAccount
|
||||
import json
|
||||
import datetime
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/backup",
|
||||
tags=["Backup & Restore"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
)
|
||||
|
||||
@router.get("/export", summary="导出全站数据")
|
||||
async def export_backup():
|
||||
"""
|
||||
生成并下载一个包含所有关键数据的JSON文件。
|
||||
"""
|
||||
try:
|
||||
data = await BackupService.export_data()
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"
|
||||
}
|
||||
return JSONResponse(content=data, headers=headers)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/import", summary="导入数据")
|
||||
async def import_backup(file: UploadFile = File(...)):
|
||||
"""
|
||||
从上传的JSON文件恢复数据。
|
||||
**警告**: 这将会覆盖所有现有数据!
|
||||
"""
|
||||
|
||||
if not file.filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
|
||||
|
||||
try:
|
||||
contents = await file.read()
|
||||
data = json.loads(contents)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="无法解析JSON文件")
|
||||
|
||||
try:
|
||||
await BackupService.import_data(data)
|
||||
return {"message": "数据导入成功。"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"导入失败: {e}")
|
||||
@@ -1,100 +0,0 @@
|
||||
import httpx
|
||||
import time
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from typing import Annotated
|
||||
from services.config import ConfigCenter, VERSION
|
||||
from services.auth import get_current_active_user, User, has_users
|
||||
from api.response import success
|
||||
from services.vector_db import VectorDBService
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_config(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str
|
||||
):
|
||||
value = await ConfigCenter.get(key)
|
||||
return success({"key": key, "value": value})
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def set_config(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str = Form(...),
|
||||
value: str = Form(...)
|
||||
):
|
||||
original_value = await ConfigCenter.get(key)
|
||||
value_to_save = value
|
||||
if key == "AI_EMBED_DIM":
|
||||
try:
|
||||
parsed_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="AI_EMBED_DIM must be an integer")
|
||||
if parsed_value <= 0:
|
||||
raise HTTPException(status_code=400, detail="AI_EMBED_DIM must be greater than zero")
|
||||
value_to_save = str(parsed_value)
|
||||
|
||||
await ConfigCenter.set(key, value_to_save)
|
||||
|
||||
if key == "AI_EMBED_DIM" and str(original_value) != value_to_save:
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}")
|
||||
|
||||
return success({"key": key, "value": value_to_save})
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
async def get_all_config(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
configs = await ConfigCenter.get_all()
|
||||
return success(configs)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_system_status():
|
||||
system_info = {
|
||||
"version": VERSION,
|
||||
"title": await ConfigCenter.get("APP_NAME", "Foxel"),
|
||||
"logo": await ConfigCenter.get("APP_LOGO", "/logo.svg"),
|
||||
"is_initialized": await has_users(),
|
||||
"app_domain": await ConfigCenter.get("APP_DOMAIN"),
|
||||
"file_domain": await ConfigCenter.get("FILE_DOMAIN"),
|
||||
}
|
||||
return success(system_info)
|
||||
|
||||
|
||||
latest_version_cache = {
|
||||
"timestamp": 0,
|
||||
"data": None
|
||||
}
|
||||
|
||||
|
||||
@router.get("/latest-version")
|
||||
async def get_latest_version():
|
||||
current_time = time.time()
|
||||
if current_time - latest_version_cache["timestamp"] < 3600 and latest_version_cache["data"]:
|
||||
return success(latest_version_cache["data"])
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
version_info = {
|
||||
"latest_version": data.get("tag_name"),
|
||||
"body": data.get("body")
|
||||
}
|
||||
latest_version_cache["timestamp"] = current_time
|
||||
latest_version_cache["data"] = version_info
|
||||
return success(version_info)
|
||||
except httpx.RequestError as e:
|
||||
if latest_version_cache["data"]:
|
||||
return success(latest_version_cache["data"])
|
||||
return success({"latest_version": None, "body": None})
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from models.database import Log
|
||||
from api.response import page, success
|
||||
from tortoise.expressions import Q
|
||||
from datetime import datetime
|
||||
|
||||
router = APIRouter(prefix="/api/logs", tags=["Logs"])
|
||||
|
||||
@router.get("")
|
||||
async def get_logs(
|
||||
page_num: int = Query(1, alias="page"),
|
||||
page_size: int = Query(20, alias="page_size"),
|
||||
level: Optional[str] = Query(None),
|
||||
source: Optional[str] = Query(None),
|
||||
start_time: Optional[datetime] = Query(None),
|
||||
end_time: Optional[datetime] = Query(None),
|
||||
):
|
||||
"""获取日志列表,支持分页和筛选"""
|
||||
query = Log.all()
|
||||
if level:
|
||||
query = query.filter(level=level)
|
||||
if source:
|
||||
query = query.filter(source__icontains=source)
|
||||
if start_time:
|
||||
query = query.filter(timestamp__gte=start_time)
|
||||
if end_time:
|
||||
query = query.filter(timestamp__lte=end_time)
|
||||
|
||||
total = await query.count()
|
||||
logs = await query.order_by("-timestamp").offset((page_num - 1) * page_size).limit(page_size)
|
||||
|
||||
return success(page([log for log in logs], total, page_num, page_size))
|
||||
|
||||
@router.delete("")
|
||||
async def clear_logs(
|
||||
start_time: Optional[datetime] = Query(None),
|
||||
end_time: Optional[datetime] = Query(None),
|
||||
):
|
||||
"""清理指定时间范围内的日志"""
|
||||
query = Log.all()
|
||||
if start_time:
|
||||
query = query.filter(timestamp__gte=start_time)
|
||||
if end_time:
|
||||
query = query.filter(timestamp__lte=end_time)
|
||||
|
||||
deleted_count = await query.delete()
|
||||
return success({"deleted_count": deleted_count})
|
||||
@@ -1,79 +0,0 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.response import success
|
||||
from schemas.offline_downloads import OfflineDownloadCreate
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.logging import LogService
|
||||
from services.task_queue import task_queue_service, TaskProgress
|
||||
from services.virtual_fs import path_is_directory
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/offline-downloads",
|
||||
tags=["OfflineDownloads"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_offline_download(
|
||||
payload: OfflineDownloadCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
dest_dir = payload.dest_dir
|
||||
try:
|
||||
is_dir = await path_is_directory(dest_dir)
|
||||
except HTTPException:
|
||||
is_dir = False
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Destination directory not found")
|
||||
|
||||
task = await task_queue_service.add_task(
|
||||
"offline_http_download",
|
||||
{
|
||||
"url": str(payload.url),
|
||||
"dest_dir": dest_dir,
|
||||
"filename": payload.filename,
|
||||
},
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="queued",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="Waiting to start",
|
||||
),
|
||||
)
|
||||
|
||||
await LogService.action(
|
||||
"route:offline_downloads",
|
||||
f"Offline download task created {task.id}",
|
||||
details={"url": str(payload.url), "dest_dir": dest_dir, "filename": payload.filename},
|
||||
user_id=current_user.id if hasattr(current_user, "id") else None,
|
||||
)
|
||||
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_offline_downloads(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
|
||||
data = [t.dict() for t in tasks]
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
async def get_offline_download(
|
||||
task_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task or task.name != "offline_http_download":
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return success(task.dict())
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import List, Any, Dict
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from models import database
|
||||
from schemas import PluginCreate, PluginOut
|
||||
|
||||
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
|
||||
|
||||
|
||||
@router.post("", response_model=PluginOut)
|
||||
async def create_plugin(payload: PluginCreate):
|
||||
rec = await database.Plugin.create(
|
||||
url=payload.url,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
|
||||
@router.get("", response_model=List[PluginOut])
|
||||
async def list_plugins():
|
||||
rows = await database.Plugin.all().order_by("-id")
|
||||
return [PluginOut.model_validate(r) for r in rows]
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}")
|
||||
async def delete_plugin(plugin_id: int):
|
||||
rec = await database.Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
await rec.delete()
|
||||
return {"code": 0, "msg": "ok"}
|
||||
|
||||
|
||||
@router.put("/{plugin_id}", response_model=PluginOut)
|
||||
async def update_plugin(plugin_id: int, payload: PluginCreate):
|
||||
rec = await database.Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
rec.url = payload.url
|
||||
rec.enabled = payload.enabled
|
||||
await rec.save()
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/metadata", response_model=PluginOut)
|
||||
async def update_manifest(plugin_id: int, manifest: Dict[str, Any] = Body(...)):
|
||||
rec = await database.Plugin.get_or_none(id=plugin_id)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
key_map = {
|
||||
'key': 'key',
|
||||
'name': 'name',
|
||||
'version': 'version',
|
||||
'supported_exts': 'supported_exts',
|
||||
'supportedExts': 'supported_exts',
|
||||
'default_bounds': 'default_bounds',
|
||||
'defaultBounds': 'default_bounds',
|
||||
'default_maximized': 'default_maximized',
|
||||
'defaultMaximized': 'default_maximized',
|
||||
'icon': 'icon',
|
||||
'description': 'description',
|
||||
'author': 'author',
|
||||
'website': 'website',
|
||||
'github': 'github',
|
||||
}
|
||||
for k, v in list(manifest.items()):
|
||||
if v is None:
|
||||
continue
|
||||
attr = key_map.get(k)
|
||||
if not attr:
|
||||
continue
|
||||
setattr(rec, attr, v)
|
||||
await rec.save()
|
||||
return PluginOut.model_validate(rec)
|
||||
@@ -1,116 +0,0 @@
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from typing import Annotated
|
||||
from services.processors.registry import (
|
||||
get_config_schemas,
|
||||
get_module_path,
|
||||
reload_processors,
|
||||
)
|
||||
from services.task_queue import task_queue_service
|
||||
from services.auth import get_current_active_user, User
|
||||
from api.response import success
|
||||
from pydantic import BaseModel
|
||||
from services.virtual_fs import path_is_directory
|
||||
|
||||
router = APIRouter(prefix="/api/processors", tags=["processors"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_processors(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
schemas = get_config_schemas()
|
||||
out = []
|
||||
for t, meta in schemas.items():
|
||||
out.append({
|
||||
"type": meta["type"],
|
||||
"name": meta["name"],
|
||||
"supported_exts": meta.get("supported_exts", []),
|
||||
"config_schema": meta["config_schema"],
|
||||
"produces_file": meta.get("produces_file", False),
|
||||
"module_path": meta.get("module_path"),
|
||||
})
|
||||
return success(out)
|
||||
|
||||
|
||||
class ProcessRequest(BaseModel):
|
||||
path: str
|
||||
processor_type: str
|
||||
config: dict
|
||||
save_to: str | None = None
|
||||
overwrite: bool = False
|
||||
|
||||
|
||||
class UpdateSourceRequest(BaseModel):
|
||||
source: str
|
||||
|
||||
|
||||
@router.post("/process")
|
||||
async def process_file_with_processor(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
req: ProcessRequest = Body(...)
|
||||
):
|
||||
is_dir = await path_is_directory(req.path)
|
||||
if is_dir and not req.overwrite:
|
||||
raise HTTPException(400, detail="Directory processing requires overwrite")
|
||||
|
||||
save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
|
||||
task = await task_queue_service.add_task(
|
||||
"process_file",
|
||||
{
|
||||
"path": req.path,
|
||||
"processor_type": req.processor_type,
|
||||
"config": req.config,
|
||||
"save_to": save_to,
|
||||
"overwrite": req.overwrite,
|
||||
},
|
||||
)
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.get("/source/{processor_type}")
|
||||
async def get_processor_source(
|
||||
processor_type: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
module_path = get_module_path(processor_type)
|
||||
if not module_path:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
path_obj = Path(module_path)
|
||||
if not path_obj.exists():
|
||||
raise HTTPException(404, detail="Processor source not found")
|
||||
try:
|
||||
content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"Failed to read source: {exc}")
|
||||
return success({"source": content, "module_path": str(path_obj)})
|
||||
|
||||
|
||||
@router.put("/source/{processor_type}")
|
||||
async def update_processor_source(
|
||||
processor_type: str,
|
||||
req: UpdateSourceRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
module_path = get_module_path(processor_type)
|
||||
if not module_path:
|
||||
raise HTTPException(404, detail="Processor not found")
|
||||
path_obj = Path(module_path)
|
||||
if not path_obj.exists():
|
||||
raise HTTPException(404, detail="Processor source not found")
|
||||
try:
|
||||
await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail=f"Failed to write source: {exc}")
|
||||
return success(True)
|
||||
|
||||
|
||||
@router.post("/reload")
|
||||
async def reload_processor_modules(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
errors = reload_processors()
|
||||
if errors:
|
||||
raise HTTPException(500, detail="; ".join(errors))
|
||||
return success(True)
|
||||
@@ -1,41 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from schemas.fs import SearchResultItem
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.ai import get_text_embedding
|
||||
from services.vector_db import VectorDBService
|
||||
|
||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||
|
||||
async def search_files_by_vector(q: str, top_k: int):
|
||||
embedding = await get_text_embedding(q)
|
||||
vector_db = VectorDBService()
|
||||
results = await vector_db.search_vectors("vector_collection", embedding, top_k)
|
||||
items = [
|
||||
SearchResultItem(id=res["id"], path=res["entity"]["path"], score=res["distance"])
|
||||
for res in results[0]
|
||||
]
|
||||
return {"items": items, "query": q}
|
||||
|
||||
async def search_files_by_name(q: str, top_k: int):
|
||||
vector_db = VectorDBService()
|
||||
results = await vector_db.search_by_path("vector_collection", q, top_k)
|
||||
items = [
|
||||
SearchResultItem(id=idx, path=res["entity"]["path"], score=res["distance"])
|
||||
for idx, res in enumerate(results[0])
|
||||
]
|
||||
return {"items": items, "query": q}
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search_files(
|
||||
q: str = Query(..., description="搜索查询"),
|
||||
top_k: int = Query(10, description="返回结果数量"),
|
||||
mode: str = Query("vector", description="搜索模式: 'vector' 或 'filename'"),
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
if mode == "vector":
|
||||
return await search_files_by_vector(q, top_k)
|
||||
elif mode == "filename":
|
||||
return await search_files_by_name(q, top_k)
|
||||
else:
|
||||
return {"items": [], "query": q, "error": "Invalid search mode"}
|
||||
@@ -1,217 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from urllib.parse import quote
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.response import success
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.share import share_service
|
||||
from services.virtual_fs import stream_file, stat_file
|
||||
from models.database import ShareLink, UserAccount
|
||||
|
||||
public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
|
||||
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])
|
||||
|
||||
class ShareCreate(BaseModel):
|
||||
name: str
|
||||
paths: List[str]
|
||||
expires_in_days: Optional[int] = 7
|
||||
access_type: str = "public"
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
class ShareInfo(BaseModel):
|
||||
id: int
|
||||
token: str
|
||||
name: str
|
||||
paths: List[str]
|
||||
created_at: str
|
||||
expires_at: Optional[str] = None
|
||||
access_type: str
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj: ShareLink):
|
||||
return cls(
|
||||
id=obj.id,
|
||||
token=obj.token,
|
||||
name=obj.name,
|
||||
paths=obj.paths,
|
||||
created_at=obj.created_at.isoformat(),
|
||||
expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
|
||||
access_type=obj.access_type,
|
||||
)
|
||||
|
||||
|
||||
class ShareInfoWithPassword(ShareInfo):
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
# --- Management Routes ---
|
||||
|
||||
@router.post("", response_model=ShareInfoWithPassword)
|
||||
async def create_share(
|
||||
payload: ShareCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
创建一个新的分享链接。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
share = await share_service.create_share_link(
|
||||
user=user_account,
|
||||
name=payload.name,
|
||||
paths=payload.paths,
|
||||
expires_in_days=payload.expires_in_days,
|
||||
access_type=payload.access_type,
|
||||
password=payload.password,
|
||||
)
|
||||
share_info_base = ShareInfo.from_orm(share)
|
||||
response_data = share_info_base.model_dump()
|
||||
if payload.access_type == "password" and payload.password:
|
||||
response_data['password'] = payload.password
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
@router.get("", response_model=List[ShareInfo])
|
||||
async def get_my_shares(current_user: User = Depends(get_current_active_user)):
|
||||
"""
|
||||
获取当前用户的所有分享链接。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
shares = await share_service.get_user_shares(user=user_account)
|
||||
return [ShareInfo.from_orm(s) for s in shares]
|
||||
|
||||
|
||||
@router.delete("/expired")
|
||||
async def delete_expired_shares(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
删除当前用户的所有已过期分享。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
deleted_count = await share_service.delete_expired_shares(user=user_account)
|
||||
return success({"deleted_count": deleted_count})
|
||||
|
||||
|
||||
@router.delete("/{share_id}")
|
||||
async def delete_share(
|
||||
share_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
删除一个分享链接。
|
||||
"""
|
||||
await share_service.delete_share_link(user=current_user, share_id=share_id)
|
||||
return success(msg="分享已取消")
|
||||
|
||||
|
||||
# --- Public Routes ---
|
||||
|
||||
class SharePassword(BaseModel):
|
||||
password: str
|
||||
|
||||
@public_router.post("/{token}/verify")
|
||||
async def verify_password(token: str, payload: SharePassword):
|
||||
"""
|
||||
验证分享链接的密码。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
if share.access_type != "password":
|
||||
raise HTTPException(status_code=400, detail="此分享不需要密码")
|
||||
|
||||
if not share_service._verify_password(payload.password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
# 在这里可以考虑返回一个有时效性的token用于后续访问,但为了简单起见,
|
||||
# 我们让前端在每次请求时都带上密码或一个会话标识。
|
||||
# 简单起见,我们只返回成功状态。
|
||||
return success(msg="验证成功")
|
||||
|
||||
|
||||
@public_router.get("/{token}/ls")
|
||||
async def list_share_content(token: str, path: str = "/", password: Optional[str] = None):
|
||||
"""
|
||||
列出分享链接中的文件和目录。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
|
||||
if share.access_type == "password":
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share_service._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
content = await share_service.get_shared_item_details(share, path)
|
||||
return success({
|
||||
"path": path,
|
||||
"entries": content.get("items", []),
|
||||
"pagination": {
|
||||
"total": content.get("total", 0),
|
||||
"page": content.get("page", 1),
|
||||
"page_size": content.get("page_size", 1),
|
||||
"pages": content.get("pages", 1),
|
||||
}
|
||||
})
|
||||
|
||||
@public_router.get("/{token}")
|
||||
async def get_share_info(token: str):
|
||||
"""
|
||||
获取分享链接的元数据信息。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
return success(ShareInfo.from_orm(share))
|
||||
|
||||
|
||||
|
||||
@public_router.get("/{token}/download")
|
||||
async def download_shared_file(token: str, path: str, request: Request, password: Optional[str] = None):
|
||||
"""
|
||||
下载分享链接中的单个文件。
|
||||
"""
|
||||
if not path or path == "/" or ".." in path.split('/'):
|
||||
raise HTTPException(status_code=400, detail="无效的文件路径")
|
||||
|
||||
share = await share_service.get_share_by_token(token)
|
||||
if share.access_type == "password":
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share_service._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
base_shared_path = share.paths[0]
|
||||
|
||||
# 判断分享的是文件还是目录
|
||||
is_dir = False
|
||||
try:
|
||||
stat = await stat_file(base_shared_path)
|
||||
if stat and stat.get("is_dir"):
|
||||
is_dir = True
|
||||
except HTTPException as e:
|
||||
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
|
||||
is_dir = True
|
||||
else:
|
||||
# The shared path itself doesn't exist, which is an issue.
|
||||
raise HTTPException(status_code=404, detail="分享的源文件不存在")
|
||||
|
||||
if is_dir:
|
||||
# 目录分享:拼接路径
|
||||
full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
|
||||
if not full_virtual_path.startswith(base_shared_path):
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
else:
|
||||
# 文件分享:路径应为分享的根路径
|
||||
shared_filename = base_shared_path.split('/')[-1]
|
||||
request_filename = path.lstrip('/')
|
||||
if shared_filename != request_filename:
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
full_virtual_path = base_shared_path
|
||||
|
||||
range_header = request.headers.get("Range")
|
||||
response = await stream_file(full_virtual_path, range_header)
|
||||
|
||||
# 设置 Content-Disposition 头来强制下载
|
||||
filename = full_virtual_path.split('/')[-1]
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
return response
|
||||
@@ -1,141 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Annotated
|
||||
|
||||
from models.database import AutomationTask
|
||||
from schemas.tasks import (
|
||||
AutomationTaskCreate,
|
||||
AutomationTaskUpdate,
|
||||
TaskQueueSettings,
|
||||
TaskQueueSettingsResponse,
|
||||
)
|
||||
from api.response import success
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.logging import LogService
|
||||
from services.task_queue import task_queue_service
|
||||
from services.config import ConfigCenter
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/tasks",
|
||||
tags=["Tasks"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/queue")
|
||||
async def get_task_queue_status(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
tasks = task_queue_service.get_all_tasks()
|
||||
return success([task.dict() for task in tasks])
|
||||
|
||||
|
||||
@router.get("/queue/settings")
|
||||
async def get_task_queue_settings(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
payload = TaskQueueSettingsResponse(
|
||||
concurrency=task_queue_service.get_concurrency(),
|
||||
active_workers=task_queue_service.get_active_worker_count(),
|
||||
)
|
||||
return success(payload.model_dump())
|
||||
|
||||
|
||||
@router.post("/queue/settings")
|
||||
async def update_task_queue_settings(
|
||||
settings: TaskQueueSettings,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await task_queue_service.set_concurrency(settings.concurrency)
|
||||
await ConfigCenter.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
"Updated task queue settings",
|
||||
details={"concurrency": settings.concurrency},
|
||||
user_id=getattr(current_user, "id", None),
|
||||
)
|
||||
payload = TaskQueueSettingsResponse(
|
||||
concurrency=task_queue_service.get_concurrency(),
|
||||
active_workers=task_queue_service.get_active_worker_count(),
|
||||
)
|
||||
return success(payload.model_dump())
|
||||
|
||||
|
||||
@router.get("/queue/{task_id}")
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return success(task.dict())
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_task(
|
||||
task_in: AutomationTaskCreate,
|
||||
user: User = Depends(get_current_active_user)
|
||||
):
|
||||
task = await AutomationTask.create(**task_in.model_dump())
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
f"Created task {task.name}",
|
||||
details=task_in.model_dump(),
|
||||
user_id=user.id if hasattr(user, "id") else None,
|
||||
)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
async def get_task(task_id: int):
|
||||
task = await AutomationTask.get_or_none(id=task_id)
|
||||
if not task:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Task {task_id} not found")
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_tasks():
|
||||
tasks = await AutomationTask.all()
|
||||
return success(tasks)
|
||||
|
||||
|
||||
@router.put("/{task_id}")
|
||||
async def update_task(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
task_id: int, task_in: AutomationTaskUpdate):
|
||||
task = await AutomationTask.get_or_none(id=task_id)
|
||||
if not task:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Task {task_id} not found")
|
||||
update_data = task_in.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(task, key, value)
|
||||
await task.save()
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
f"Updated task {task.name}",
|
||||
details=task_in.model_dump(),
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return success(task)
|
||||
|
||||
|
||||
@router.delete("/{task_id}")
|
||||
async def delete_task(
|
||||
task_id: int,
|
||||
user: User = Depends(get_current_active_user)
|
||||
):
|
||||
deleted_count = await AutomationTask.filter(id=task_id).delete()
|
||||
if not deleted_count:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Task {task_id} not found")
|
||||
await LogService.action(
|
||||
"route:tasks",
|
||||
f"Deleted task {task_id}",
|
||||
details={"task_id": task_id},
|
||||
user_id=user.id if hasattr(user, "id") else None,
|
||||
)
|
||||
return success(msg="Task deleted")
|
||||
@@ -1,100 +0,0 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.auth import get_current_active_user
|
||||
from models.database import UserAccount
|
||||
from services.vector_db import (
|
||||
VectorDBService,
|
||||
VectorDBConfigManager,
|
||||
list_providers,
|
||||
get_provider_entry,
|
||||
)
|
||||
from services.vector_db.providers import get_provider_class
|
||||
from api.response import success
|
||||
|
||||
router = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
||||
|
||||
|
||||
class VectorDBConfigPayload(BaseModel):
|
||||
type: str = Field(..., description="向量数据库提供者类型")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
|
||||
|
||||
|
||||
@router.post("/clear-all", summary="清空向量数据库")
|
||||
async def clear_vector_db(user: UserAccount = Depends(get_current_active_user)):
|
||||
if user.username != 'admin':
|
||||
raise HTTPException(status_code=403, detail="仅管理员可操作")
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
return success(msg="向量数据库已清空")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stats", summary="获取向量数据库统计")
|
||||
async def get_vector_db_stats(user: UserAccount = Depends(get_current_active_user)):
|
||||
if user.username != 'admin':
|
||||
raise HTTPException(status_code=403, detail="仅管理员可操作")
|
||||
try:
|
||||
service = VectorDBService()
|
||||
data = await service.get_all_stats()
|
||||
return success(data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/providers", summary="列出可用向量数据库提供者")
|
||||
async def list_vector_providers(user: UserAccount = Depends(get_current_active_user)):
|
||||
if user.username != 'admin':
|
||||
raise HTTPException(status_code=403, detail="仅管理员可操作")
|
||||
return success(list_providers())
|
||||
|
||||
|
||||
@router.get("/config", summary="获取当前向量数据库配置")
|
||||
async def get_vector_db_config(user: UserAccount = Depends(get_current_active_user)):
|
||||
if user.username != 'admin':
|
||||
raise HTTPException(status_code=403, detail="仅管理员可操作")
|
||||
service = VectorDBService()
|
||||
data = await service.current_provider()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/config", summary="更新向量数据库配置")
|
||||
async def update_vector_db_config(payload: VectorDBConfigPayload, user: UserAccount = Depends(get_current_active_user)):
|
||||
if user.username != 'admin':
|
||||
raise HTTPException(status_code=403, detail="仅管理员可操作")
|
||||
|
||||
entry = get_provider_entry(payload.type)
|
||||
if not entry:
|
||||
raise HTTPException(status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
||||
|
||||
provider_cls = get_provider_class(payload.type)
|
||||
if not provider_cls:
|
||||
raise HTTPException(status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
||||
|
||||
# 先尝试建立连接,确保配置有效
|
||||
test_provider = provider_cls(payload.config)
|
||||
try:
|
||||
await test_provider.initialize()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
client = getattr(test_provider, "client", None)
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
||||
service = VectorDBService()
|
||||
await service.reload()
|
||||
config_data = await service.current_provider()
|
||||
stats = await service.get_all_stats()
|
||||
return success({"config": config_data, "stats": stats})
|
||||
@@ -1,369 +0,0 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Response, Query, Request, Depends
|
||||
import mimetypes
|
||||
import re
|
||||
from typing import Annotated
|
||||
|
||||
from services.auth import get_current_active_user, User
|
||||
from services.virtual_fs import (
|
||||
list_virtual_dir,
|
||||
read_file,
|
||||
write_file,
|
||||
make_dir,
|
||||
delete_path,
|
||||
move_path,
|
||||
resolve_adapter_and_rel,
|
||||
stream_file,
|
||||
generate_temp_link_token,
|
||||
verify_temp_link_token,
|
||||
)
|
||||
from services.thumbnail import is_image_filename, get_or_create_thumb, is_raw_filename
|
||||
from schemas import MkdirRequest, MoveRequest
|
||||
from api.response import success
|
||||
from services.config import ConfigCenter
|
||||
|
||||
router = APIRouter(prefix='/api/fs', tags=["virtual-fs"])
|
||||
|
||||
|
||||
@router.get("/file/{full_path:path}")
|
||||
async def get_file(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
|
||||
if is_raw_filename(full_path):
|
||||
import rawpy
|
||||
from PIL import Image
|
||||
import io
|
||||
try:
|
||||
raw_data = await read_file(full_path)
|
||||
with rawpy.imread(io.BytesIO(raw_data)) as raw:
|
||||
rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
|
||||
im = Image.fromarray(rgb)
|
||||
buf = io.BytesIO()
|
||||
im.save(buf, 'JPEG', quality=90)
|
||||
content = buf.getvalue()
|
||||
return Response(content=content, media_type='image/jpeg')
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"RAW file processing failed: {e}")
|
||||
|
||||
try:
|
||||
content = await read_file(full_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
|
||||
if not isinstance(content, (bytes, bytearray)):
|
||||
return Response(content=content, media_type="application/octet-stream")
|
||||
|
||||
content_length = len(content)
|
||||
content_type = mimetypes.guess_type(
|
||||
full_path)[0] or "application/octet-stream"
|
||||
|
||||
range_header = request.headers.get('Range')
|
||||
if range_header:
|
||||
range_match = re.match(r'bytes=(\d+)-(\d*)', range_header)
|
||||
if range_match:
|
||||
start = int(range_match.group(1))
|
||||
end = int(range_match.group(2)) if range_match.group(
|
||||
2) else content_length - 1
|
||||
|
||||
start = max(0, min(start, content_length - 1))
|
||||
end = max(start, min(end, content_length - 1))
|
||||
|
||||
chunk = content[start:end + 1]
|
||||
chunk_size = len(chunk)
|
||||
|
||||
headers = {
|
||||
'Content-Range': f'bytes {start}-{end}/{content_length}',
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Content-Length': str(chunk_size),
|
||||
'Content-Type': content_type,
|
||||
}
|
||||
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
headers = {
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Content-Length': str(content_length),
|
||||
'Content-Type': content_type,
|
||||
}
|
||||
|
||||
if content_type.startswith('video/'):
|
||||
headers['Cache-Control'] = 'public, max-age=3600'
|
||||
|
||||
return Response(content=content, headers=headers)
|
||||
|
||||
|
||||
@router.get("/thumb/{full_path:path}")
|
||||
async def get_thumb(
|
||||
full_path: str,
|
||||
w: int = Query(256, ge=8, le=1024),
|
||||
h: int = Query(256, ge=8, le=1024),
|
||||
fit: str = Query("cover"),
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
if fit not in ("cover", "contain"):
|
||||
raise HTTPException(400, detail="fit must be cover|contain")
|
||||
adapter, mount, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
if not rel or rel.endswith('/'):
|
||||
raise HTTPException(400, detail="Not a file")
|
||||
if not is_image_filename(rel):
|
||||
raise HTTPException(404, detail="Not an image")
|
||||
# type: ignore
|
||||
data, mime, key = await get_or_create_thumb(adapter, mount.id, root, rel, w, h, fit)
|
||||
headers = {
|
||||
'Cache-Control': 'public, max-age=3600',
|
||||
'ETag': key,
|
||||
}
|
||||
return Response(content=data, media_type=mime, headers=headers)
|
||||
|
||||
|
||||
@router.get("/stream/{full_path:path}")
|
||||
async def stream_endpoint(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
):
|
||||
"""支持 Range 的视频/大文件流式读取,优先使用底层适配器 Range 能力。"""
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
range_header = request.headers.get('Range')
|
||||
try:
|
||||
return await stream_file(full_path, range_header)
|
||||
except HTTPException:
|
||||
raise
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"Stream error: {e}")
|
||||
|
||||
|
||||
@router.get("/temp-link/{full_path:path}")
|
||||
async def get_temp_link(
|
||||
full_path: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久")
|
||||
):
|
||||
"""获取文件的临时公开访问令牌"""
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
token = await generate_temp_link_token(full_path, expires_in=expires_in)
|
||||
file_domain = await ConfigCenter.get("FILE_DOMAIN")
|
||||
if file_domain:
|
||||
file_domain = file_domain.rstrip('/')
|
||||
url = f"{file_domain}/api/fs/public/{token}"
|
||||
else:
|
||||
url = f"/api/fs/public/{token}"
|
||||
return success({"token": token, "path": full_path, "url": url})
|
||||
|
||||
|
||||
@router.get("/public/{token}")
|
||||
async def access_public_file(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""通过令牌公开访问文件,支持 Range 请求"""
|
||||
try:
|
||||
path = await verify_temp_link_token(token)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
range_header = request.headers.get('Range')
|
||||
try:
|
||||
return await stream_file(path, range_header)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found via token")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"File access error: {e}")
|
||||
|
||||
|
||||
@router.get("/stat/{full_path:path}")
|
||||
async def get_file_stat(
|
||||
full_path: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
from services.virtual_fs import stat_file
|
||||
stat = await stat_file(full_path)
|
||||
return success(stat)
|
||||
|
||||
|
||||
@router.post("/file/{full_path:path}")
|
||||
async def put_file(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
data = await file.read()
|
||||
await write_file(full_path, data)
|
||||
return success({"written": True, "path": full_path, "size": len(data)})
|
||||
|
||||
|
||||
@router.post("/mkdir")
|
||||
async def api_mkdir(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MkdirRequest
|
||||
):
|
||||
path = body.path if body.path.startswith('/') else '/' + body.path
|
||||
if not path or path == '/':
|
||||
raise HTTPException(400, detail="Invalid path")
|
||||
await make_dir(path)
|
||||
return success({"created": True, "path": path})
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
async def api_move(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
):
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
debug_info = await move_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
|
||||
queued = bool(debug_info.get("queued"))
|
||||
response = {
|
||||
"moved": not queued,
|
||||
"queued": queued,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
}
|
||||
if queued:
|
||||
response["task_id"] = debug_info.get("task_id")
|
||||
response["task_name"] = debug_info.get("task_name")
|
||||
return success(response)
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
async def api_rename(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标")
|
||||
):
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
from services.virtual_fs import rename_path
|
||||
await rename_path(src, dst, overwrite=overwrite, return_debug=False)
|
||||
return success({
|
||||
"renamed": True,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/copy")
|
||||
async def api_copy(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
|
||||
):
|
||||
from services.virtual_fs import copy_path
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
debug_info = await copy_path(src, dst, overwrite=overwrite, return_debug=True, allow_cross=True)
|
||||
queued = bool(debug_info.get("queued"))
|
||||
response = {
|
||||
"copied": not queued,
|
||||
"queued": queued,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
}
|
||||
if queued:
|
||||
response["task_id"] = debug_info.get("task_id")
|
||||
response["task_name"] = debug_info.get("task_name")
|
||||
return success(response)
|
||||
|
||||
|
||||
@router.post("/upload/{full_path:path}")
|
||||
async def upload_stream(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
|
||||
chunk_size: int = Query(1024 * 1024, ge=8 * 1024,
|
||||
le=8 * 1024 * 1024, description="单次读取块大小")
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
if full_path.endswith('/'):
|
||||
raise HTTPException(400, detail="Path must be a file")
|
||||
from services.virtual_fs import write_file_stream, resolve_adapter_and_rel
|
||||
adapter, _m, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
exists_func = getattr(adapter, "exists", None)
|
||||
if not overwrite and callable(exists_func):
|
||||
try:
|
||||
if await exists_func(root, rel):
|
||||
raise HTTPException(409, detail="Destination exists")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def gen():
|
||||
while True:
|
||||
chunk = await file.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
size = await write_file_stream(full_path, gen(), overwrite=overwrite)
|
||||
return success({"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite})
|
||||
|
||||
|
||||
@router.get("/{full_path:path}")
|
||||
async def browse_fs(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc")
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
result = await list_virtual_dir(full_path, page_num, page_size, sort_by, sort_order)
|
||||
return success({
|
||||
"path": full_path,
|
||||
"entries": result["items"],
|
||||
"pagination": {
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
"pages": result["pages"]
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@router.delete("/{full_path:path}")
|
||||
async def api_delete(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
await delete_path(full_path)
|
||||
return success({"deleted": True, "path": full_path})
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root_listing(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数"),
|
||||
sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
|
||||
sort_order: str = Query("asc", description="排序顺序: asc, desc")
|
||||
):
|
||||
result = await list_virtual_dir("/", page_num, page_size, sort_by, sort_order)
|
||||
return success({
|
||||
"path": "/",
|
||||
"entries": result["items"],
|
||||
"pagination": {
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
"pages": result["pages"]
|
||||
}
|
||||
})
|
||||
@@ -5,9 +5,10 @@ services:
|
||||
container_name: foxel
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8088:80"
|
||||
- "${FOXEL_HOST_PORT:-8088}:${FOXEL_PORT:-80}"
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- FOXEL_PORT=${FOXEL_PORT:-80}
|
||||
- SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
- TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
volumes:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from tortoise import Tortoise
|
||||
|
||||
from services.adapters.registry import runtime_registry
|
||||
from domain.adapters import runtime_registry
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {"default": "sqlite://data/db/db.sqlite3"},
|
||||
@@ -12,7 +12,6 @@ TORTOISE_ORM = {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def init_db():
|
||||
await Tortoise.init(config=TORTOISE_ORM)
|
||||
await Tortoise.generate_schemas()
|
||||
|
||||
7
domain/__init__.py
Normal file
7
domain/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
domain:业务域层
|
||||
|
||||
约定:跨包只从各子包 `__init__.py` 导入公开 API。
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
24
domain/adapters/__init__.py
Normal file
24
domain/adapters/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from .providers import BaseAdapter
|
||||
from .registry import (
|
||||
RuntimeRegistry,
|
||||
discover_adapters,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from .service import AdapterService
|
||||
from .types import AdapterCreate, AdapterOut
|
||||
|
||||
__all__ = [
|
||||
"BaseAdapter",
|
||||
"RuntimeRegistry",
|
||||
"discover_adapters",
|
||||
"get_config_schema",
|
||||
"get_config_schemas",
|
||||
"normalize_adapter_type",
|
||||
"runtime_registry",
|
||||
"AdapterService",
|
||||
"AdapterCreate",
|
||||
"AdapterOut",
|
||||
]
|
||||
92
domain/adapters/api.py
Normal file
92
domain/adapters/api.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import AdapterPermission
|
||||
from .service import AdapterService
|
||||
from .types import AdapterCreate
|
||||
|
||||
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
@require_system_permission(AdapterPermission.CREATE)
|
||||
async def create_adapter(
|
||||
request: Request,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.create_adapter(data, current_user)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.get("")
|
||||
@audit(action=AuditAction.READ, description="获取适配器列表")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def list_adapters(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapters = await AdapterService.list_adapters()
|
||||
return success(adapters)
|
||||
|
||||
|
||||
@router.get("/available")
|
||||
@audit(action=AuditAction.READ, description="获取可用适配器类型")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def available_adapter_types(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
data = await AdapterService.available_adapter_types()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{adapter_id}")
|
||||
@audit(action=AuditAction.READ, description="获取适配器详情")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def get_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.get_adapter(adapter_id)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.put("/{adapter_id}")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
@require_system_permission(AdapterPermission.EDIT)
|
||||
async def update_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.update_adapter(adapter_id, data, current_user)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.delete("/{adapter_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除存储适配器")
|
||||
@require_system_permission(AdapterPermission.DELETE)
|
||||
async def delete_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
result = await AdapterService.delete_adapter(adapter_id, current_user)
|
||||
return success(result)
|
||||
3
domain/adapters/providers/__init__.py
Normal file
3
domain/adapters/providers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAdapter
|
||||
|
||||
__all__ = ["BaseAdapter"]
|
||||
515
domain/adapters/providers/alist.py
Normal file
515
domain/adapters/providers/alist.py
Normal file
@@ -0,0 +1,515 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import re
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Dict, List, Tuple
|
||||
from urllib.parse import quote, urljoin
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _normalize_fs_path(path: str) -> str:
|
||||
path = (path or "").replace("\\", "/").strip()
|
||||
if not path or path == "/":
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = re.sub(r"/{2,}", "/", path)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _join_fs_path(base: str, rel: str) -> str:
|
||||
base = _normalize_fs_path(base)
|
||||
rel = (rel or "").replace("\\", "/").lstrip("/")
|
||||
if not rel:
|
||||
return base
|
||||
if base == "/":
|
||||
return "/" + rel
|
||||
return f"{base}/{rel}"
|
||||
|
||||
|
||||
def _split_parent_and_name(path: str) -> Tuple[str, str]:
|
||||
path = _normalize_fs_path(path)
|
||||
if path == "/":
|
||||
return "/", ""
|
||||
parent, _, name = path.rpartition("/")
|
||||
if not parent:
|
||||
parent = "/"
|
||||
return parent, name
|
||||
|
||||
|
||||
def _parse_iso_to_epoch(value: str | None) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
if text.endswith("Z"):
|
||||
text = text[:-1] + "+00:00"
|
||||
m = re.match(r"^(.*?)(\.\d+)([+-]\d\d:\d\d)?$", text)
|
||||
if m:
|
||||
head, frac, tz = m.group(1), m.group(2), m.group(3) or ""
|
||||
digits = frac[1:]
|
||||
if len(digits) > 6:
|
||||
frac = "." + digits[:6]
|
||||
text = head + frac + tz
|
||||
dt = datetime.fromisoformat(text)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return int(dt.timestamp())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
class AListApiAdapterBase:
|
||||
def __init__(self, record: StorageAdapter, *, product_name: str):
|
||||
self.record = record
|
||||
self.product_name = product_name
|
||||
|
||||
cfg = record.config or {}
|
||||
self.base_url: str = str(cfg.get("base_url", "")).rstrip("/")
|
||||
if not self.base_url.startswith("http"):
|
||||
raise ValueError(f"{product_name} requires base_url http/https")
|
||||
self.username: str = str(cfg.get("username") or "")
|
||||
self.password: str = str(cfg.get("password") or "")
|
||||
if (self.username and not self.password) or (self.password and not self.username):
|
||||
raise ValueError(f"{product_name} requires both username and password")
|
||||
self.use_auth: bool = bool(self.username and self.password)
|
||||
|
||||
self.timeout: float = float(cfg.get("timeout", 30))
|
||||
self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
|
||||
self.enable_redirect_307: bool = bool(cfg.get("enable_direct_download_307"))
|
||||
|
||||
self._token: str | None = None
|
||||
self._login_lock = asyncio.Lock()
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = _normalize_fs_path(self.root_path)
|
||||
if sub_path:
|
||||
return _join_fs_path(base, sub_path)
|
||||
return base
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
if not self.use_auth:
|
||||
return ""
|
||||
if self._token:
|
||||
return self._token
|
||||
async with self._login_lock:
|
||||
if self._token:
|
||||
return self._token
|
||||
self._token = await self._login()
|
||||
return self._token
|
||||
|
||||
async def _login(self) -> str:
|
||||
url = self.base_url + "/api/auth/login"
|
||||
body = {"username": self.username, "password": self.password}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json=body)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} login: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
raise HTTPException(502, detail=f"{self.product_name} login failed: {payload.get('message')}")
|
||||
data = payload.get("data") or {}
|
||||
token = (data.get("token") if isinstance(data, dict) else None) or ""
|
||||
token = str(token).strip()
|
||||
if not token:
|
||||
raise HTTPException(502, detail=f"{self.product_name} login: missing token")
|
||||
return token
|
||||
|
||||
async def _api_json(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
*,
|
||||
json: Dict[str, Any] | None = None,
|
||||
headers: Dict[str, str] | None = None,
|
||||
retry: bool = True,
|
||||
files: Any = None,
|
||||
) -> Any:
|
||||
token = await self._ensure_token()
|
||||
url = self.base_url + endpoint
|
||||
req_headers: Dict[str, str] = {}
|
||||
if token:
|
||||
req_headers["Authorization"] = token
|
||||
if headers:
|
||||
req_headers.update(headers)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, json=json, headers=req_headers, files=files)
|
||||
if resp.status_code == 401 and retry and self.use_auth:
|
||||
self._token = None
|
||||
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} api: invalid response")
|
||||
|
||||
code = payload.get("code")
|
||||
if code in (0, 200):
|
||||
return payload.get("data")
|
||||
if code in (401, 403) and retry and self.use_auth:
|
||||
self._token = None
|
||||
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
|
||||
if code == 404:
|
||||
raise FileNotFoundError(json.get("path") if json else "")
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} api error code={code} msg={msg}")
|
||||
|
||||
def _abs_url(self, url: str) -> str:
|
||||
u = (url or "").strip()
|
||||
if not u:
|
||||
return ""
|
||||
if u.startswith("http://") or u.startswith("https://"):
|
||||
return u
|
||||
return urljoin(self.base_url.rstrip("/") + "/", u.lstrip("/"))
|
||||
|
||||
async def _fs_list(self, path: str) -> Dict[str, Any]:
|
||||
body = {"path": path, "password": "", "page": 1, "per_page": 0, "refresh": False}
|
||||
data = await self._api_json("POST", "/api/fs/list", json=body)
|
||||
return data or {}
|
||||
|
||||
async def _fs_get(self, path: str) -> Dict[str, Any]:
|
||||
body = {"path": path, "password": "", "page": 1, "per_page": 0, "refresh": False}
|
||||
data = await self._api_json("POST", "/api/fs/get", json=body)
|
||||
return data or {}
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
path = _join_fs_path(root, rel)
|
||||
data = await self._fs_list(path)
|
||||
content = data.get("content") or []
|
||||
if not isinstance(content, list):
|
||||
raise HTTPException(502, detail=f"{self.product_name} list_dir: invalid content")
|
||||
|
||||
entries: List[Dict] = []
|
||||
for it in content:
|
||||
if not isinstance(it, dict):
|
||||
continue
|
||||
name = str(it.get("name") or "")
|
||||
if not name:
|
||||
continue
|
||||
is_dir = bool(it.get("is_dir"))
|
||||
size = int(it.get("size") or 0) if not is_dir else 0
|
||||
mtime = _parse_iso_to_epoch(it.get("modified"))
|
||||
entries.append(
|
||||
{
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
)
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item: Dict) -> Tuple:
|
||||
key = (not item.get("is_dir"),)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (str(item.get("name", "")).lower(),)
|
||||
elif f == "size":
|
||||
key += (int(item.get("size", 0)),)
|
||||
elif f == "mtime":
|
||||
key += (int(item.get("mtime", 0)),)
|
||||
else:
|
||||
key += (str(item.get("name", "")).lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(entries)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return entries[start:end], total
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_fs_path(root, rel)
|
||||
data = await self._fs_get(path)
|
||||
if not data:
|
||||
raise FileNotFoundError(rel)
|
||||
is_dir = bool(data.get("is_dir"))
|
||||
name = str(data.get("name") or (rel.rstrip("/").split("/")[-1] if rel else ""))
|
||||
size = int(data.get("size") or 0) if not is_dir else 0
|
||||
mtime = _parse_iso_to_epoch(data.get("modified"))
|
||||
info = {
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
"path": path,
|
||||
}
|
||||
return info
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
try:
|
||||
info = await self.stat_file(root, rel)
|
||||
return {"exists": True, "is_dir": bool(info.get("is_dir")), "path": info.get("path")}
|
||||
except FileNotFoundError:
|
||||
return {"exists": False, "is_dir": None, "path": _join_fs_path(root, rel)}
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
await self.stat_file(root, rel)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
data = await self._fs_get(_join_fs_path(root, rel))
|
||||
if not data:
|
||||
raise FileNotFoundError(rel)
|
||||
if bool(data.get("is_dir")):
|
||||
raise IsADirectoryError(rel)
|
||||
raw_url = self._abs_url(str(data.get("raw_url") or ""))
|
||||
if not raw_url:
|
||||
return None
|
||||
return Response(status_code=307, headers={"Location": raw_url})
|
||||
|
||||
async def _get_raw_url_and_meta(self, root: str, rel: str) -> Tuple[str, int, str]:
|
||||
data = await self._fs_get(_join_fs_path(root, rel))
|
||||
if not data:
|
||||
raise FileNotFoundError(rel)
|
||||
if bool(data.get("is_dir")):
|
||||
raise IsADirectoryError(rel)
|
||||
raw_url = self._abs_url(str(data.get("raw_url") or ""))
|
||||
if not raw_url:
|
||||
raise HTTPException(502, detail=f"{self.product_name} missing raw_url")
|
||||
size = int(data.get("size") or 0)
|
||||
name = str(data.get("name") or "")
|
||||
return raw_url, size, name
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
raw_url, _, _ = await self._get_raw_url_and_meta(root, rel)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(raw_url)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
raw_url, file_size, name = await self._get_raw_url_and_meta(root, rel)
|
||||
mime, _ = mimetypes.guess_type(name or rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
start = 0
|
||||
end = max(file_size - 1, 0)
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
if file_size >= 0:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if file_size and start >= file_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if file_size and end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
|
||||
async def agen():
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
req_headers = {"Range": f"bytes={start}-{end}"} if status == 206 else {}
|
||||
async with client.stream("GET", raw_url, headers=req_headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def _upload_file(self, full_path: str, file_path: Path) -> Any:
|
||||
token = await self._ensure_token()
|
||||
headers = {"File-Path": quote(full_path, safe="/")}
|
||||
if token:
|
||||
headers["Authorization"] = token
|
||||
with file_path.open("rb") as f:
|
||||
files = {"file": (file_path.name, f, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
|
||||
return payload.get("data")
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
tf.write(data)
|
||||
tmp_path = Path(tf.name)
|
||||
try:
|
||||
await self._upload_file(full_path, tmp_path)
|
||||
finally:
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
token = await self._ensure_token()
|
||||
headers = {"File-Path": quote(full_path, safe="/")}
|
||||
if token:
|
||||
headers["Authorization"] = token
|
||||
name = filename or Path(rel).name or "file"
|
||||
mime = content_type or "application/octet-stream"
|
||||
files = {"file": (name, file_obj, mime)}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict) and file_size is not None and "size" not in data:
|
||||
data["size"] = file_size
|
||||
return data
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
tmp_path = Path(tf.name)
|
||||
size = 0
|
||||
try:
|
||||
with tmp_path.open("wb") as f:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
size += len(chunk)
|
||||
await self._upload_file(full_path, tmp_path)
|
||||
return size
|
||||
finally:
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_fs_path(root, rel)
|
||||
await self._api_json("POST", "/api/fs/mkdir", json={"path": path})
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_fs_path(root, rel)
|
||||
parent, name = _split_parent_and_name(path)
|
||||
if not name:
|
||||
return
|
||||
await self._api_json("POST", "/api/fs/remove", json={"dir": parent, "names": [name]})
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_path = _join_fs_path(root, src_rel)
|
||||
dst_path = _join_fs_path(root, dst_rel)
|
||||
src_dir, src_name = _split_parent_and_name(src_path)
|
||||
dst_dir, dst_name = _split_parent_and_name(dst_path)
|
||||
if not src_name or not dst_name:
|
||||
raise HTTPException(400, detail="Invalid move path")
|
||||
|
||||
if src_dir == dst_dir:
|
||||
if src_name == dst_name:
|
||||
return
|
||||
await self._api_json("POST", "/api/fs/rename", json={"path": src_path, "name": dst_name})
|
||||
return
|
||||
|
||||
await self._api_json("POST", "/api/fs/move", json={"src_dir": src_dir, "dst_dir": dst_dir, "names": [src_name]})
|
||||
if src_name != dst_name:
|
||||
await self._api_json("POST", "/api/fs/rename", json={"path": _join_fs_path(dst_dir, src_name), "name": dst_name})
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src_path = _join_fs_path(root, src_rel)
|
||||
dst_path = _join_fs_path(root, dst_rel)
|
||||
src_dir, src_name = _split_parent_and_name(src_path)
|
||||
dst_dir, dst_name = _split_parent_and_name(dst_path)
|
||||
if not src_name or not dst_name:
|
||||
raise HTTPException(400, detail="Invalid copy path")
|
||||
|
||||
src_info = await self._fs_get(src_path)
|
||||
if not src_info:
|
||||
raise FileNotFoundError(src_rel)
|
||||
|
||||
if src_name != dst_name and not bool(src_info.get("is_dir")):
|
||||
raw_url, _, _ = await self._get_raw_url_and_meta(root, src_rel)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
async with client.stream("GET", raw_url) as resp:
|
||||
resp.raise_for_status()
|
||||
|
||||
async def gen():
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
await self.write_file_stream(root, dst_rel, gen())
|
||||
return
|
||||
|
||||
await self._api_json("POST", "/api/fs/copy", json={"src_dir": src_dir, "dst_dir": dst_dir, "names": [src_name]})
|
||||
if src_name != dst_name:
|
||||
await self._api_json("POST", "/api/fs/rename", json={"path": _join_fs_path(dst_dir, src_name), "name": dst_name})
|
||||
|
||||
|
||||
class AListAdapter(AListApiAdapterBase):
|
||||
def __init__(self, record: StorageAdapter):
|
||||
super().__init__(record, product_name="AList")
|
||||
|
||||
|
||||
class OpenListAdapter(AListApiAdapterBase):
|
||||
def __init__(self, record: StorageAdapter):
|
||||
super().__init__(record, product_name="OpenList")
|
||||
|
||||
|
||||
ADAPTER_TYPES = {"alist": AListAdapter, "openlist": OpenListAdapter}
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "base_url", "label": "基础地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:5244"},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": False, "placeholder": "留空则匿名访问"},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": False, "placeholder": "留空则匿名访问"},
|
||||
{"key": "root", "label": "根目录", "type": "string", "required": False, "default": "/"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 30},
|
||||
{"key": "enable_direct_download_307", "label": "启用 307 直链下载", "type": "boolean", "default": False},
|
||||
]
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Protocol, runtime_checkable, Tuple, AsyncIterator
|
||||
from models import StorageAdapter
|
||||
|
||||
471
domain/adapters/providers/dropbox.py
Normal file
471
domain/adapters/providers/dropbox.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import AsyncIterator, Dict, List, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
DROPBOX_OAUTH_URL = "https://api.dropboxapi.com/oauth2/token"
|
||||
DROPBOX_API_URL = "https://api.dropboxapi.com/2"
|
||||
DROPBOX_CONTENT_URL = "https://content.dropboxapi.com/2"
|
||||
|
||||
|
||||
def _normalize_dbx_path(path: str | None) -> str:
|
||||
path = (path or "").replace("\\", "/").strip()
|
||||
if not path or path == "/":
|
||||
return ""
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = re.sub(r"/{2,}", "/", path)
|
||||
if path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path
|
||||
|
||||
|
||||
def _join_dbx_path(base: str, rel: str) -> str:
|
||||
base = _normalize_dbx_path(base)
|
||||
rel = (rel or "").replace("\\", "/").strip("/")
|
||||
if not rel:
|
||||
return base
|
||||
if not base:
|
||||
return "/" + rel
|
||||
return f"{base}/{rel}"
|
||||
|
||||
|
||||
def _parse_iso_to_epoch(value: str | None) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
if text.endswith("Z"):
|
||||
text = text[:-1] + "+00:00"
|
||||
dt = datetime.fromisoformat(text)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return int(dt.timestamp())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
class DropboxAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config or {}
|
||||
|
||||
self.app_key: str = str(cfg.get("app_key") or "").strip()
|
||||
self.app_secret: str = str(cfg.get("app_secret") or "").strip()
|
||||
self.refresh_token: str = str(cfg.get("refresh_token") or "").strip()
|
||||
self.root_path: str = _normalize_dbx_path(str(cfg.get("root") or "/"))
|
||||
self.enable_redirect_307: bool = bool(cfg.get("enable_direct_download_307"))
|
||||
self.timeout: float = float(cfg.get("timeout", 60))
|
||||
|
||||
if not (self.app_key and self.app_secret and self.refresh_token):
|
||||
raise ValueError("Dropbox 适配器需要 app_key, app_secret, refresh_token")
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: datetime | None = None
|
||||
self._token_lock = asyncio.Lock()
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = _normalize_dbx_path(self.root_path)
|
||||
if sub_path:
|
||||
return _join_dbx_path(base, sub_path)
|
||||
return base
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
async with self._token_lock:
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
basic = base64.b64encode(f"{self.app_key}:{self.app_secret}".encode("utf-8")).decode("ascii")
|
||||
headers = {"Authorization": f"Basic {basic}"}
|
||||
data = {"grant_type": "refresh_token", "refresh_token": self.refresh_token}
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
resp = await client.post(DROPBOX_OAUTH_URL, data=data, headers=headers)
|
||||
resp.raise_for_status()
|
||||
|
||||
payload = resp.json()
|
||||
token = str(payload.get("access_token") or "").strip()
|
||||
if not token:
|
||||
raise HTTPException(502, detail="Dropbox oauth: missing access_token")
|
||||
expires_in = int(payload.get("expires_in") or 3600)
|
||||
self._access_token = token
|
||||
self._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=max(60, expires_in - 300))
|
||||
return token
|
||||
|
||||
async def _api_json(self, endpoint: str, body: Dict) -> httpx.Response:
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
return await client.post(f"{DROPBOX_API_URL}{endpoint}", json=body, headers=headers)
|
||||
|
||||
async def _content_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
api_arg: Dict,
|
||||
*,
|
||||
content: bytes | None = None,
|
||||
data_iter: AsyncIterator[bytes] | None = None,
|
||||
extra_headers: Dict[str, str] | None = None,
|
||||
) -> httpx.Response:
|
||||
token = await self._get_access_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Dropbox-API-Arg": json.dumps(api_arg, separators=(",", ":"), ensure_ascii=False),
|
||||
}
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
if data_iter is None:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
return await client.post(f"{DROPBOX_CONTENT_URL}{endpoint}", headers=headers, content=content or b"")
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
return await client.post(f"{DROPBOX_CONTENT_URL}{endpoint}", headers=headers, content=data_iter)
|
||||
|
||||
@staticmethod
|
||||
def _raise_dbx_error(resp: httpx.Response, *, rel: str):
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
summary = ""
|
||||
if isinstance(payload, dict):
|
||||
summary = str(payload.get("error_summary") or "")
|
||||
if "not_found" in summary:
|
||||
raise FileNotFoundError(rel)
|
||||
if "conflict" in summary or "already_exists" in summary:
|
||||
raise FileExistsError(rel)
|
||||
if "is_folder" in summary:
|
||||
raise IsADirectoryError(rel)
|
||||
if "not_folder" in summary:
|
||||
raise NotADirectoryError(rel)
|
||||
raise HTTPException(502, detail=f"Dropbox API error: {summary or resp.text}")
|
||||
|
||||
def _format_entry(self, entry: Dict) -> Dict:
|
||||
tag = entry.get(".tag")
|
||||
is_dir = tag == "folder"
|
||||
mtime = _parse_iso_to_epoch(entry.get("server_modified") if not is_dir else None)
|
||||
return {
|
||||
"name": entry.get("name") or "",
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(entry.get("size") or 0),
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
path = _join_dbx_path(root, rel)
|
||||
body = {"path": path, "recursive": False, "include_deleted": False, "limit": 2000}
|
||||
resp = await self._api_json("/files/list_folder", body)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
summary = str((payload or {}).get("error_summary") or "")
|
||||
if "not_found" in summary:
|
||||
return [], 0
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
|
||||
all_entries: List[Dict] = []
|
||||
all_entries.extend(payload.get("entries") or [])
|
||||
cursor = payload.get("cursor")
|
||||
has_more = bool(payload.get("has_more"))
|
||||
while has_more and cursor:
|
||||
resp2 = await self._api_json("/files/list_folder/continue", {"cursor": cursor})
|
||||
resp2.raise_for_status()
|
||||
p2 = resp2.json()
|
||||
all_entries.extend(p2.get("entries") or [])
|
||||
cursor = p2.get("cursor")
|
||||
has_more = bool(p2.get("has_more"))
|
||||
|
||||
items = [self._format_entry(e) for e in all_entries if isinstance(e, dict)]
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif f == "size":
|
||||
key += (item["size"],)
|
||||
elif f == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
items.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total = len(items)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return items[start:end], total
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/get_metadata", {"path": path, "include_deleted": False})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
meta = resp.json()
|
||||
if not isinstance(meta, dict):
|
||||
raise HTTPException(502, detail="Dropbox metadata: invalid response")
|
||||
return self._format_entry(meta)
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
await self.stat_file(root, rel)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._content_request("/files/download", {"path": path})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_dbx_path(root, rel)
|
||||
arg = {
|
||||
"path": path,
|
||||
"mode": "overwrite",
|
||||
"autorename": False,
|
||||
"mute": False,
|
||||
"strict_conflict": False,
|
||||
}
|
||||
resp = await self._content_request(
|
||||
"/files/upload",
|
||||
arg,
|
||||
content=data,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
path = _join_dbx_path(root, rel)
|
||||
|
||||
size = 0
|
||||
session_id: str | None = None
|
||||
offset = 0
|
||||
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
if session_id is None:
|
||||
resp = await self._content_request(
|
||||
"/files/upload_session_start",
|
||||
{"close": False},
|
||||
content=chunk,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
session_id = str(payload.get("session_id") or "")
|
||||
if not session_id:
|
||||
raise HTTPException(502, detail="Dropbox upload_session_start: missing session_id")
|
||||
offset += len(chunk)
|
||||
size += len(chunk)
|
||||
continue
|
||||
|
||||
arg = {"cursor": {"session_id": session_id, "offset": offset}, "close": False}
|
||||
resp = await self._content_request(
|
||||
"/files/upload_session_append_v2",
|
||||
arg,
|
||||
content=chunk,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
offset += len(chunk)
|
||||
size += len(chunk)
|
||||
|
||||
if session_id is None:
|
||||
await self.write_file(root, rel, b"")
|
||||
return 0
|
||||
|
||||
finish_arg = {
|
||||
"cursor": {"session_id": session_id, "offset": offset},
|
||||
"commit": {
|
||||
"path": path,
|
||||
"mode": "overwrite",
|
||||
"autorename": False,
|
||||
"mute": False,
|
||||
"strict_conflict": False,
|
||||
},
|
||||
}
|
||||
resp = await self._content_request(
|
||||
"/files/upload_session_finish",
|
||||
finish_arg,
|
||||
content=b"",
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/create_folder_v2", {"path": path, "autorename": False})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/delete_v2", {"path": path})
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
summary = str((payload or {}).get("error_summary") or "")
|
||||
if "not_found" in summary:
|
||||
return
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_dbx_path(root, src_rel)
|
||||
dst = _join_dbx_path(root, dst_rel)
|
||||
resp = await self._api_json(
|
||||
"/files/move_v2",
|
||||
{"from_path": src, "to_path": dst, "autorename": False, "allow_shared_folder": True},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=src_rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
return await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _join_dbx_path(root, src_rel)
|
||||
dst = _join_dbx_path(root, dst_rel)
|
||||
resp = await self._api_json(
|
||||
"/files/copy_v2",
|
||||
{"from_path": src, "to_path": dst, "autorename": False, "allow_shared_folder": True},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=dst_rel if overwrite else dst_rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/get_temporary_link", {"path": path})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
link = (payload.get("link") if isinstance(payload, dict) else None) or ""
|
||||
link = str(link).strip()
|
||||
if not link:
|
||||
return None
|
||||
return Response(status_code=307, headers={"Location": link})
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
path = _join_dbx_path(root, rel)
|
||||
token = await self._get_access_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Dropbox-API-Arg": json.dumps({"path": path}, separators=(",", ":"), ensure_ascii=False),
|
||||
}
|
||||
if range_header:
|
||||
headers["Range"] = range_header
|
||||
|
||||
client = httpx.AsyncClient(timeout=None)
|
||||
stream_cm = client.stream("POST", f"{DROPBOX_CONTENT_URL}/files/download", headers=headers)
|
||||
try:
|
||||
resp = await stream_cm.__aenter__()
|
||||
except Exception:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
content = await resp.aread()
|
||||
_ = content
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
resp.raise_for_status()
|
||||
|
||||
content_type = resp.headers.get("Content-Type") or (mimetypes.guess_type(rel)[0] or "application/octet-stream")
|
||||
out_headers = {}
|
||||
for key in ("Accept-Ranges", "Content-Range", "Content-Length"):
|
||||
value = resp.headers.get(key)
|
||||
if value:
|
||||
out_headers[key] = value
|
||||
|
||||
async def iterator():
|
||||
try:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(iterator(), status_code=resp.status_code, headers=out_headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "dropbox"
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "app_key", "label": "App Key", "type": "string", "required": True},
|
||||
{"key": "app_secret", "label": "App Secret", "type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password", "required": True},
|
||||
{"key": "root", "label": "Root Path", "type": "string", "required": False, "default": "/", "placeholder": "/ or /Apps/Foxel"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 60},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec): return DropboxAdapter(rec)
|
||||
|
||||
435
domain/adapters/providers/foxel.py
Normal file
435
domain/adapters/providers/foxel.py
Normal file
@@ -0,0 +1,435 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Dict, List, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _normalize_fs_path(path: str) -> str:
|
||||
path = (path or "").replace("\\", "/").strip()
|
||||
if not path or path == "/":
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = re.sub(r"/{2,}", "/", path)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _join_fs_path(base: str, rel: str | None) -> str:
|
||||
base = _normalize_fs_path(base)
|
||||
rel_norm = (rel or "").replace("\\", "/").strip().lstrip("/")
|
||||
if not rel_norm:
|
||||
return base
|
||||
if base == "/":
|
||||
return "/" + rel_norm
|
||||
return f"{base}/{rel_norm}"
|
||||
|
||||
|
||||
def _unwrap_success(payload: Any, *, context: str) -> Any:
|
||||
if not isinstance(payload, dict):
|
||||
return payload
|
||||
if "data" not in payload:
|
||||
return payload
|
||||
code = payload.get("code")
|
||||
if code not in (None, 0, 200):
|
||||
msg = payload.get("msg") or payload.get("message") or ""
|
||||
raise HTTPException(502, detail=f"Foxel 上游错误({context}): {msg}")
|
||||
return payload.get("data")
|
||||
|
||||
|
||||
class FoxelAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config or {}
|
||||
|
||||
self.base_url: str = str(cfg.get("base_url", "")).rstrip("/")
|
||||
if not self.base_url.startswith("http"):
|
||||
raise ValueError("foxel requires base_url http/https")
|
||||
|
||||
self.username: str = str(cfg.get("username") or "")
|
||||
self.password: str = str(cfg.get("password") or "")
|
||||
if not self.username or not self.password:
|
||||
raise ValueError("foxel requires username and password")
|
||||
|
||||
self.timeout: float = float(cfg.get("timeout", 15))
|
||||
self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
|
||||
|
||||
self._token: str | None = None
|
||||
self._login_lock = asyncio.Lock()
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = _normalize_fs_path(self.root_path)
|
||||
if sub_path:
|
||||
return _join_fs_path(base, sub_path)
|
||||
return base
|
||||
|
||||
async def _login(self) -> str:
|
||||
url = self.base_url + "/api/auth/login"
|
||||
body = {"username": self.username, "password": self.password}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, data=body)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail="Foxel 登录响应异常")
|
||||
token = payload.get("access_token")
|
||||
if not token:
|
||||
raise HTTPException(502, detail="Foxel 登录失败: 缺少 access_token")
|
||||
return str(token)
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
if self._token:
|
||||
return self._token
|
||||
async with self._login_lock:
|
||||
if self._token:
|
||||
return self._token
|
||||
self._token = await self._login()
|
||||
return self._token
|
||||
|
||||
async def _request_json(self, method: str, path: str, *, params: dict | None = None, json: Any = None) -> Any:
|
||||
url = self.base_url + path
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, headers=headers, params=params, json=json)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
raise HTTPException(502, detail="Foxel 上游请求失败")
|
||||
|
||||
@staticmethod
|
||||
def _encode_path(full_path: str) -> str:
|
||||
return quote(full_path.lstrip("/"), safe="/")
|
||||
|
||||
def _browse_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/"
|
||||
return "/api/fs/" + self._encode_path(full_path)
|
||||
|
||||
def _stat_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/stat/"
|
||||
return "/api/fs/stat/" + self._encode_path(full_path)
|
||||
|
||||
def _file_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/file/"
|
||||
return "/api/fs/file/" + self._encode_path(full_path)
|
||||
|
||||
def _stream_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/stream/"
|
||||
return "/api/fs/stream/" + self._encode_path(full_path)
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
payload = await self._request_json(
|
||||
"GET",
|
||||
self._browse_path(full_path),
|
||||
params={
|
||||
"page": page_num,
|
||||
"page_size": page_size,
|
||||
"sort_by": sort_by,
|
||||
"sort_order": sort_order,
|
||||
},
|
||||
)
|
||||
data = _unwrap_success(payload, context="list_dir")
|
||||
if not isinstance(data, dict):
|
||||
raise HTTPException(502, detail="Foxel 浏览响应异常")
|
||||
entries = data.get("entries") or []
|
||||
pagination = data.get("pagination") or {}
|
||||
total = pagination.get("total")
|
||||
try:
|
||||
total_int = int(total) if total is not None else len(entries)
|
||||
except Exception:
|
||||
total_int = len(entries)
|
||||
if not isinstance(entries, list):
|
||||
entries = []
|
||||
return entries, total_int
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
payload = await self._request_json("GET", self._stat_path(full_path))
|
||||
data = _unwrap_success(payload, context="stat_file")
|
||||
if not isinstance(data, dict):
|
||||
raise HTTPException(502, detail="Foxel stat 响应异常")
|
||||
return data
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._stat_path(full_path)
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
return resp.status_code == 200
|
||||
return False
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
raise HTTPException(502, detail="Foxel 读取失败")
|
||||
|
||||
async def _upload_file_path(self, full_path: str, file_path: Path) -> None:
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
filename = Path(full_path).name or file_path.name
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
with file_path.open("rb") as f:
|
||||
files = {"file": (filename, f, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return
|
||||
raise HTTPException(502, detail="Foxel 上传失败")
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
filename = Path(rel).name or "file"
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
files = {"file": (filename, data, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
raise HTTPException(502, detail="Foxel 写入失败")
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
name = filename or Path(rel).name or "file"
|
||||
mime = content_type or "application/octet-stream"
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
files = {"file": (name, file_obj, mime)}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return {"size": file_size or 0}
|
||||
raise HTTPException(502, detail="Foxel 上传失败")
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
tmp_path = Path(tf.name)
|
||||
|
||||
size = 0
|
||||
try:
|
||||
with tmp_path.open("wb") as f:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
size += len(chunk)
|
||||
await self._upload_file_path(full_path, tmp_path)
|
||||
return size
|
||||
finally:
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
payload = await self._request_json("POST", "/api/fs/mkdir", json={"path": full_path})
|
||||
_unwrap_success(payload, context="mkdir")
|
||||
return True
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._browse_path(full_path)
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.delete(url, headers=headers)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
if resp.status_code == 404:
|
||||
return
|
||||
resp.raise_for_status()
|
||||
return
|
||||
raise HTTPException(502, detail="Foxel 删除失败")
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
|
||||
dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
|
||||
payload = await self._request_json("POST", "/api/fs/move", json={"src": src_path, "dst": dst_path})
|
||||
_unwrap_success(payload, context="move")
|
||||
return True
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
|
||||
dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
|
||||
payload = await self._request_json("POST", "/api/fs/rename", json={"src": src_path, "dst": dst_path})
|
||||
_unwrap_success(payload, context="rename")
|
||||
return True
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
|
||||
dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
|
||||
payload = await self._request_json(
|
||||
"POST",
|
||||
"/api/fs/copy",
|
||||
json={"src": src_path, "dst": dst_path},
|
||||
params={"overwrite": overwrite},
|
||||
)
|
||||
_unwrap_success(payload, context="copy")
|
||||
return True
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._stream_path(full_path)
|
||||
|
||||
headers = {}
|
||||
if range_header:
|
||||
headers["Range"] = range_header
|
||||
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
client = httpx.AsyncClient(timeout=None, follow_redirects=True)
|
||||
stream_cm = client.stream("GET", url, headers=headers)
|
||||
try:
|
||||
resp = await stream_cm.__aenter__()
|
||||
except Exception:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
self._token = None
|
||||
continue
|
||||
|
||||
if resp.status_code == 404:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
resp.raise_for_status()
|
||||
|
||||
content_type = resp.headers.get("Content-Type") or (
|
||||
mimetypes.guess_type(rel)[0] or "application/octet-stream"
|
||||
)
|
||||
out_headers = {}
|
||||
for key in ("Accept-Ranges", "Content-Range", "Content-Length"):
|
||||
value = resp.headers.get(key)
|
||||
if value:
|
||||
out_headers[key] = value
|
||||
|
||||
async def iterator():
|
||||
try:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
iterator(),
|
||||
status_code=resp.status_code,
|
||||
headers=out_headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
|
||||
raise HTTPException(502, detail="Foxel 流式读取失败")
|
||||
|
||||
|
||||
ADAPTER_TYPE = "foxel"
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "base_url", "label": "节点地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:8000"},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": True},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": True},
|
||||
{"key": "root", "label": "远端根目录", "type": "string", "required": False, "default": "/", "placeholder": "/ 或 /drive"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 60},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter):
|
||||
return FoxelAdapter(rec)
|
||||
645
domain/adapters/providers/ftp.py
Normal file
645
domain/adapters/providers/ftp.py
Normal file
@@ -0,0 +1,645 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Tuple, AsyncIterator, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ftplib import FTP, error_perm
|
||||
import mimetypes
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _join_remote(root: str, rel: str) -> str:
|
||||
root = (root or "/").rstrip("/") or "/"
|
||||
rel = (rel or "").lstrip("/")
|
||||
if not rel:
|
||||
return root
|
||||
return f"{root}/{rel}"
|
||||
|
||||
|
||||
def _parse_mlst_line(line: str) -> Dict[str, str]:
|
||||
out: Dict[str, str] = {}
|
||||
try:
|
||||
facts, _, name = line.partition(" ")
|
||||
for part in facts.split(";"):
|
||||
if not part or "=" not in part:
|
||||
continue
|
||||
k, v = part.split("=", 1)
|
||||
out[k.strip().lower()] = v.strip()
|
||||
if name:
|
||||
out["name"] = name.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
|
||||
def _parse_modify_to_epoch(mod: str) -> int:
|
||||
# Formats we may see: YYYYMMDDHHMMSS or YYYYMMDDHHMMSS(.sss)
|
||||
try:
|
||||
mod = mod.strip()
|
||||
mod = mod.split(".")[0]
|
||||
if len(mod) >= 14:
|
||||
y = int(mod[0:4])
|
||||
m = int(mod[4:6])
|
||||
d = int(mod[6:8])
|
||||
hh = int(mod[8:10])
|
||||
mm = int(mod[10:12])
|
||||
ss = int(mod[12:14])
|
||||
import datetime as _dt
|
||||
return int(_dt.datetime(y, m, d, hh, mm, ss, tzinfo=_dt.timezone.utc).timestamp())
|
||||
except Exception:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Range:
|
||||
start: int
|
||||
end: Optional[int] # inclusive
|
||||
|
||||
|
||||
class FTPAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.host: str = cfg.get("host")
|
||||
self.port: int = int(cfg.get("port", 21))
|
||||
self.username: Optional[str] = cfg.get("username")
|
||||
self.password: Optional[str] = cfg.get("password")
|
||||
self.passive: bool = bool(cfg.get("passive", True))
|
||||
self.timeout: int = int(cfg.get("timeout", 15))
|
||||
self.root_path: str = cfg.get("root", "/") or "/"
|
||||
|
||||
if not self.host:
|
||||
raise ValueError("FTP adapter requires 'host'")
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = self.root_path.rstrip("/") or "/"
|
||||
if sub_path:
|
||||
return _join_remote(base, sub_path)
|
||||
return base
|
||||
|
||||
def _connect(self) -> FTP:
|
||||
ftp = FTP()
|
||||
ftp.connect(self.host, self.port, timeout=self.timeout)
|
||||
if self.username:
|
||||
ftp.login(self.username, self.password or "")
|
||||
else:
|
||||
ftp.login()
|
||||
ftp.set_pasv(self.passive)
|
||||
return ftp
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
path = _join_remote(root, rel.strip('/'))
|
||||
|
||||
def _do_list() -> List[Dict]:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
ftp.cwd(path)
|
||||
except error_perm as e:
|
||||
# path may be file
|
||||
ftp.quit()
|
||||
raise NotADirectoryError(rel) from e
|
||||
|
||||
entries: List[Dict] = []
|
||||
# Try MLSD first
|
||||
try:
|
||||
for name, facts in ftp.mlsd():
|
||||
if name in (".", ".."):
|
||||
continue
|
||||
is_dir = (facts.get("type") == "dir")
|
||||
size = int(facts.get("size") or 0)
|
||||
mtime = _parse_modify_to_epoch(facts.get("modify") or "")
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
ftp.quit()
|
||||
return entries
|
||||
except Exception:
|
||||
# Fallback to NLST + probing
|
||||
pass
|
||||
|
||||
names = []
|
||||
try:
|
||||
names = ftp.nlst()
|
||||
except Exception:
|
||||
ftp.quit()
|
||||
return []
|
||||
|
||||
for name in names:
|
||||
if name in (".", ".."):
|
||||
continue
|
||||
is_dir = False
|
||||
size = 0
|
||||
mtime = 0
|
||||
try:
|
||||
# If we can CWD, it's a directory
|
||||
ftp.cwd(_join_remote(path, name))
|
||||
ftp.cwd(path)
|
||||
is_dir = True
|
||||
except Exception:
|
||||
is_dir = False
|
||||
try:
|
||||
size = ftp.size(_join_remote(path, name)) or 0
|
||||
except Exception:
|
||||
size = 0
|
||||
try:
|
||||
mdtm = ftp.sendcmd("MDTM " + _join_remote(path, name))
|
||||
# Example: '213 20241012XXXXXX'
|
||||
if mdtm.startswith("213 "):
|
||||
mtime = _parse_modify_to_epoch(mdtm.split(" ", 1)[1])
|
||||
except Exception:
|
||||
pass
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(size or 0),
|
||||
"mtime": int(mtime or 0),
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
ftp.quit()
|
||||
return entries
|
||||
|
||||
entries = await asyncio.to_thread(_do_list)
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif f == "size":
|
||||
key += (item.get("size", 0),)
|
||||
elif f == "mtime":
|
||||
key += (item.get("mtime", 0),)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(entries)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return entries[start:end], total
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_read() -> bytes:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
chunks: List[bytes] = []
|
||||
ftp.retrbinary("RETR " + path, lambda b: chunks.append(b))
|
||||
return b"".join(chunks)
|
||||
except error_perm as e:
|
||||
if str(e).startswith("550"):
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_read)
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(ftp: FTP, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _do_write():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(ftp, parent)
|
||||
from io import BytesIO
|
||||
bio = BytesIO(data)
|
||||
ftp.storbinary("STOR " + path, bio)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(ftp: FTP, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _do_upload():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(ftp, parent)
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
ftp.storbinary("STOR " + path, file_obj)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_upload)
|
||||
return {"size": file_size or 0}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
# KISS: 聚合后一次性写入
|
||||
buf = bytearray()
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
buf.extend(chunk)
|
||||
await self.write_file(root, rel, bytes(buf))
|
||||
return len(buf)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_mkdir():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_mkdir)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_delete():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
# Try file delete
|
||||
try:
|
||||
ftp.delete(path)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Recursively delete dir
|
||||
def _rm_tree(dir_path: str):
|
||||
try:
|
||||
ftp.cwd(dir_path)
|
||||
except Exception:
|
||||
return
|
||||
items = []
|
||||
try:
|
||||
for name, facts in ftp.mlsd():
|
||||
if name in (".", ".."):
|
||||
continue
|
||||
items.append((name, facts.get("type") == "dir"))
|
||||
except Exception:
|
||||
try:
|
||||
names = ftp.nlst()
|
||||
except Exception:
|
||||
names = []
|
||||
for n in names:
|
||||
if n in (".", ".."):
|
||||
continue
|
||||
# Best-effort dir check
|
||||
try:
|
||||
ftp.cwd(_join_remote(dir_path, n))
|
||||
ftp.cwd(dir_path)
|
||||
items.append((n, True))
|
||||
except Exception:
|
||||
items.append((n, False))
|
||||
for n, is_dir in items:
|
||||
child = _join_remote(dir_path, n)
|
||||
if is_dir:
|
||||
_rm_tree(child)
|
||||
else:
|
||||
try:
|
||||
ftp.delete(child)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
ftp.rmd(dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_rm_tree(path)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_delete)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
def _do_move():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
# Ensure dst parent exists
|
||||
parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
|
||||
parts = [p for p in parent.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
ftp.rename(src, dst)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_move)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
# naive implementation: download then upload; recursively for dirs
|
||||
async def _is_dir(path: str) -> bool:
|
||||
def _probe() -> bool:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
try:
|
||||
ftp.cwd(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
return await asyncio.to_thread(_probe)
|
||||
|
||||
if await _is_dir(src):
|
||||
# list children, create dst dir, copy recursively
|
||||
await self.mkdir(root, dst_rel)
|
||||
|
||||
children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
|
||||
for ent in children:
|
||||
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
|
||||
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
|
||||
await self.copy(root, child_src, child_dst, overwrite)
|
||||
return
|
||||
|
||||
# file
|
||||
data = await self.read_file(root, src_rel)
|
||||
if not overwrite:
|
||||
# best-effort existence check
|
||||
try:
|
||||
await self.stat_file(root, dst_rel)
|
||||
raise FileExistsError(dst_rel)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
await self.write_file(root, dst_rel, data)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_stat():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
# Try MLST
|
||||
try:
|
||||
resp: List[str] = []
|
||||
ftp.retrlines("MLST " + path, resp.append)
|
||||
# The last line usually contains facts
|
||||
facts = {}
|
||||
if resp:
|
||||
facts = _parse_mlst_line(resp[-1])
|
||||
name = rel.split("/")[-1]
|
||||
t = facts.get("type") or "file"
|
||||
is_dir = t == "dir"
|
||||
size = int(facts.get("size") or 0)
|
||||
mtime = _parse_modify_to_epoch(facts.get("modify") or "")
|
||||
return {
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
"path": path,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Probe directory
|
||||
try:
|
||||
ftp.cwd(path)
|
||||
return {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": True,
|
||||
"size": 0,
|
||||
"mtime": 0,
|
||||
"type": "dir",
|
||||
"path": path,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Treat as file
|
||||
try:
|
||||
size = ftp.size(path) or 0
|
||||
except Exception:
|
||||
size = 0
|
||||
try:
|
||||
mdtm = ftp.sendcmd("MDTM " + path)
|
||||
mtime = _parse_modify_to_epoch(mdtm.split(" ", 1)[1]) if mdtm.startswith("213 ") else 0
|
||||
except Exception:
|
||||
mtime = 0
|
||||
return {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": False,
|
||||
"size": int(size or 0),
|
||||
"mtime": int(mtime or 0),
|
||||
"type": "file",
|
||||
"path": path,
|
||||
}
|
||||
except error_perm as e:
|
||||
if str(e).startswith("550"):
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_stat)
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
path = _join_remote(root, rel)
|
||||
# Get size (best-effort)
|
||||
def _get_size() -> Optional[int]:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
try:
|
||||
return int(ftp.size(path) or 0)
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_size = await asyncio.to_thread(_get_size)
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
rng: Optional[_Range] = None
|
||||
status = 200
|
||||
headers = {"Accept-Ranges": "bytes", "Content-Type": content_type}
|
||||
if range_header and range_header.startswith("bytes=") and total_size is not None:
|
||||
try:
|
||||
s, e = (range_header.removeprefix("bytes=").split("-", 1))
|
||||
start = int(s) if s.strip() else 0
|
||||
end = int(e) if e.strip() else (total_size - 1)
|
||||
if start >= total_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if end >= total_size:
|
||||
end = total_size - 1
|
||||
rng = _Range(start, end)
|
||||
status = 206
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{total_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
elif total_size is not None:
|
||||
headers["Content-Length"] = str(total_size)
|
||||
|
||||
queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)
|
||||
|
||||
class _Stop(Exception):
|
||||
pass
|
||||
|
||||
def _worker():
|
||||
ftp = self._connect()
|
||||
remaining = None
|
||||
if rng is not None:
|
||||
remaining = (rng.end - rng.start + 1) if rng.end is not None else None
|
||||
|
||||
def _cb(chunk: bytes):
|
||||
nonlocal remaining
|
||||
if not chunk:
|
||||
return
|
||||
try:
|
||||
if remaining is not None:
|
||||
if len(chunk) > remaining:
|
||||
part = chunk[:remaining]
|
||||
queue.put_nowait(part)
|
||||
remaining = 0
|
||||
raise _Stop()
|
||||
else:
|
||||
queue.put_nowait(chunk)
|
||||
remaining -= len(chunk)
|
||||
if remaining <= 0:
|
||||
raise _Stop()
|
||||
else:
|
||||
queue.put_nowait(chunk)
|
||||
except _Stop:
|
||||
raise
|
||||
except Exception:
|
||||
# queue full or event loop closed
|
||||
raise _Stop()
|
||||
|
||||
try:
|
||||
if rng is not None:
|
||||
ftp.retrbinary("RETR " + path, _cb, rest=rng.start)
|
||||
else:
|
||||
ftp.retrbinary("RETR " + path, _cb)
|
||||
queue.put_nowait(None)
|
||||
except _Stop:
|
||||
try:
|
||||
queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
except error_perm as e:
|
||||
try:
|
||||
queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
if str(e).startswith("550"):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def agen():
|
||||
worker_fut = asyncio.to_thread(_worker)
|
||||
try:
|
||||
while True:
|
||||
chunk = await queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
try:
|
||||
await worker_fut
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "ftp"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "host", "label": "主机", "type": "string", "required": True, "placeholder": "ftp.example.com"},
|
||||
{"key": "port", "label": "端口", "type": "number", "required": False, "default": 21},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": False},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": False},
|
||||
{"key": "passive", "label": "被动模式", "type": "boolean", "required": False, "default": True},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 15},
|
||||
{"key": "root", "label": "根路径", "type": "string", "required": False, "default": "/"},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter):
|
||||
return FTPAdapter(rec)
|
||||
559
domain/adapters/providers/googledrive.py
Normal file
559
domain/adapters/providers/googledrive.py
Normal file
@@ -0,0 +1,559 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import httpx
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from fastapi import HTTPException
|
||||
from models import StorageAdapter
|
||||
|
||||
GOOGLE_OAUTH_URL = "https://oauth2.googleapis.com/token"
|
||||
GOOGLE_DRIVE_API_URL = "https://www.googleapis.com/drive/v3"
|
||||
|
||||
|
||||
class GoogleDriveAdapter:
|
||||
"""Google Drive 存储适配器"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.client_id = cfg.get("client_id")
|
||||
self.client_secret = cfg.get("client_secret")
|
||||
self.refresh_token = cfg.get("refresh_token")
|
||||
self.root_folder_id = cfg.get("root_folder_id", "root")
|
||||
self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))
|
||||
|
||||
if not all([self.client_id, self.client_secret, self.refresh_token]):
|
||||
raise ValueError(
|
||||
"Google Drive 适配器需要 client_id, client_secret, 和 refresh_token")
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: datetime | None = None
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
"""
|
||||
获取有效根路径。
|
||||
:param sub_path: 子路径。
|
||||
:return: 完整的有效路径。
|
||||
"""
|
||||
if sub_path:
|
||||
return f"{sub_path.strip('/')}".strip()
|
||||
return ""
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
"""
|
||||
获取或刷新 access token。
|
||||
:return: access token。
|
||||
"""
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"refresh_token": self.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(GOOGLE_OAUTH_URL, data=data)
|
||||
resp.raise_for_status()
|
||||
token_data = resp.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._token_expiry = datetime.now(
|
||||
timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300)
|
||||
return self._access_token
|
||||
|
||||
async def _request(self, method: str, endpoint: str, **kwargs):
|
||||
"""
|
||||
向 Google Drive API 发送请求。
|
||||
:param method: HTTP 方法。
|
||||
:param endpoint: API 端点。
|
||||
:param kwargs: 其他请求参数。
|
||||
:return: 响应对象。
|
||||
"""
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
if "headers" in kwargs:
|
||||
headers.update(kwargs.pop("headers"))
|
||||
|
||||
url = f"{GOOGLE_DRIVE_API_URL}{endpoint}"
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.request(method, url, headers=headers, **kwargs)
|
||||
if resp.status_code == 401:
|
||||
self._access_token = None
|
||||
token = await self._get_access_token()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
resp = await client.request(method, url, headers=headers, **kwargs)
|
||||
return resp
|
||||
|
||||
async def _get_folder_id_by_path(self, path: str) -> str:
|
||||
"""
|
||||
通过路径获取文件夹 ID。
|
||||
:param path: 路径。
|
||||
:return: 文件夹 ID。
|
||||
"""
|
||||
if not path or path == "/":
|
||||
return self.root_folder_id
|
||||
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
current_id = self.root_folder_id
|
||||
|
||||
for part in parts:
|
||||
query = f"name='{part}' and '{current_id}' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
|
||||
params = {"q": query, "fields": "files(id, name)"}
|
||||
resp = await self._request("GET", "/files", params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
files = data.get("files", [])
|
||||
if not files:
|
||||
raise FileNotFoundError(f"文件夹不存在: {part}")
|
||||
current_id = files[0]["id"]
|
||||
|
||||
return current_id
|
||||
|
||||
async def _get_file_id_by_path(self, path: str) -> str | None:
|
||||
"""
|
||||
通过路径获取文件 ID。
|
||||
:param path: 文件路径。
|
||||
:return: 文件 ID 或 None。
|
||||
"""
|
||||
if not path or path == "/":
|
||||
return self.root_folder_id
|
||||
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
parent_id = self.root_folder_id
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
is_last = i == len(parts) - 1
|
||||
mime_filter = "" if is_last else "and mimeType='application/vnd.google-apps.folder'"
|
||||
query = f"name='{part}' and '{parent_id}' in parents {mime_filter} and trashed=false"
|
||||
params = {"q": query, "fields": "files(id, name)"}
|
||||
resp = await self._request("GET", "/files", params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
files = data.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
parent_id = files[0]["id"]
|
||||
|
||||
return parent_id
|
||||
|
||||
def _format_item(self, item: Dict) -> Dict:
|
||||
"""
|
||||
将 Google Drive API 返回的 item 格式化为统一的格式。
|
||||
:param item: Google Drive API 返回的 item 字典。
|
||||
:return: 格式化后的字典。
|
||||
"""
|
||||
is_dir = item["mimeType"] == "application/vnd.google-apps.folder"
|
||||
mtime_str = item.get("modifiedTime", item.get("createdTime", ""))
|
||||
try:
|
||||
mtime = int(datetime.fromisoformat(mtime_str.replace("Z", "+00:00")).timestamp())
|
||||
except:
|
||||
mtime = 0
|
||||
|
||||
return {
|
||||
"name": item["name"],
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(item.get("size", 0)),
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
"""
|
||||
列出目录内容。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param page_num: 页码。
|
||||
:param page_size: 每页大小。
|
||||
:param sort_by: 排序字段
|
||||
:param sort_order: 排序顺序
|
||||
:return: 文件/目录列表和总数。
|
||||
"""
|
||||
try:
|
||||
folder_id = await self._get_folder_id_by_path(rel)
|
||||
except FileNotFoundError:
|
||||
return [], 0
|
||||
|
||||
query = f"'{folder_id}' in parents and trashed=false"
|
||||
params = {
|
||||
"q": query,
|
||||
"fields": "files(id, name, mimeType, size, modifiedTime, createdTime)",
|
||||
"pageSize": 1000,
|
||||
}
|
||||
|
||||
all_items = []
|
||||
page_token = None
|
||||
|
||||
while True:
|
||||
if page_token:
|
||||
params["pageToken"] = page_token
|
||||
|
||||
resp = await self._request("GET", "/files", params=params)
|
||||
if resp.status_code == 404:
|
||||
return [], 0
|
||||
resp.raise_for_status()
|
||||
|
||||
data = resp.json()
|
||||
all_items.extend(data.get("files", []))
|
||||
page_token = data.get("nextPageToken")
|
||||
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
formatted_items = [self._format_item(item) for item in all_items]
|
||||
|
||||
# 排序
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
formatted_items.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(formatted_items)
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
|
||||
return formatted_items[start_idx:end_idx], total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
"""
|
||||
读取文件内容。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 文件内容的字节流。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"alt": "media"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
"""
|
||||
写入文件。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param data: 文件内容的字节流。
|
||||
"""
|
||||
parent_path = "/".join(rel.strip("/").split("/")[:-1])
|
||||
file_name = rel.strip("/").split("/")[-1]
|
||||
parent_id = await self._get_folder_id_by_path(parent_path)
|
||||
|
||||
# 检查文件是否已存在
|
||||
existing_id = await self._get_file_id_by_path(rel)
|
||||
|
||||
if existing_id:
|
||||
# 更新现有文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
url = f"https://www.googleapis.com/upload/drive/v3/files/{existing_id}?uploadType=media"
|
||||
resp = await client.patch(url, headers=headers, content=data)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
# 创建新文件
|
||||
metadata = {
|
||||
"name": file_name,
|
||||
"parents": [parent_id]
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# 使用 multipart 上传
|
||||
import json
|
||||
boundary = "===============boundary==============="
|
||||
headers["Content-Type"] = f"multipart/related; boundary={boundary}"
|
||||
|
||||
body = (
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Type: application/json; charset=UTF-8\r\n\r\n"
|
||||
f"{json.dumps(metadata)}\r\n"
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
||||
).encode() + data + f"\r\n--{boundary}--".encode()
|
||||
|
||||
url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=multipart"
|
||||
resp = await client.post(url, headers=headers, content=body)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
"""
|
||||
以流式方式写入文件。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param data_iter: 文件内容的异步迭代器。
|
||||
:return: 文件大小。
|
||||
"""
|
||||
# 先收集所有数据
|
||||
chunks = []
|
||||
total_size = 0
|
||||
async for chunk in data_iter:
|
||||
chunks.append(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
data = b"".join(chunks)
|
||||
await self.write_file(root, rel, data)
|
||||
return total_size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
"""
|
||||
创建目录。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
"""
|
||||
parent_path = "/".join(rel.strip("/").split("/")[:-1])
|
||||
folder_name = rel.strip("/").split("/")[-1]
|
||||
parent_id = await self._get_folder_id_by_path(parent_path)
|
||||
|
||||
metadata = {
|
||||
"name": folder_name,
|
||||
"mimeType": "application/vnd.google-apps.folder",
|
||||
"parents": [parent_id]
|
||||
}
|
||||
|
||||
resp = await self._request("POST", "/files", json=metadata)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
"""
|
||||
删除文件或目录。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
return
|
||||
|
||||
resp = await self._request("DELETE", f"/files/{file_id}")
|
||||
if resp.status_code not in (204, 404):
|
||||
resp.raise_for_status()
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
"""
|
||||
移动或重命名文件/目录。
|
||||
:param root: 根路径。
|
||||
:param src_rel: 源相对路径。
|
||||
:param dst_rel: 目标相对路径。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(src_rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(src_rel)
|
||||
|
||||
# 获取当前父文件夹
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "parents"})
|
||||
resp.raise_for_status()
|
||||
current_parents = resp.json().get("parents", [])
|
||||
|
||||
# 获取目标父文件夹和新名称
|
||||
dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
|
||||
dst_name = dst_rel.strip("/").split("/")[-1]
|
||||
dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)
|
||||
|
||||
# 更新文件
|
||||
params = {
|
||||
"addParents": dst_parent_id,
|
||||
"removeParents": ",".join(current_parents) if current_parents else None,
|
||||
}
|
||||
metadata = {"name": dst_name}
|
||||
|
||||
resp = await self._request("PATCH", f"/files/{file_id}", params=params, json=metadata)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
"""
|
||||
重命名文件或目录。
|
||||
"""
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
"""
|
||||
复制文件或目录。
|
||||
:param root: 根路径。
|
||||
:param src_rel: 源相对路径。
|
||||
:param dst_rel: 目标相对路径。
|
||||
:param overwrite: 是否覆盖。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(src_rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(src_rel)
|
||||
|
||||
dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
|
||||
dst_name = dst_rel.strip("/").split("/")[-1]
|
||||
dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)
|
||||
|
||||
metadata = {
|
||||
"name": dst_name,
|
||||
"parents": [dst_parent_id]
|
||||
}
|
||||
|
||||
resp = await self._request("POST", f"/files/{file_id}/copy", json=metadata)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
"""
|
||||
流式传输文件(支持范围请求)。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param range_header: HTTP Range 头。
|
||||
:return: FastAPI StreamingResponse 对象。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
# 获取文件元数据
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "name, size, mimeType"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
item_data = resp.json()
|
||||
|
||||
file_size = int(item_data.get("size", 0))
|
||||
content_type = item_data.get("mimeType", "application/octet-stream")
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"Content-Disposition": f"inline; filename=\"{item_data.get('name')}\""
|
||||
}
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if start >= file_size:
|
||||
raise HTTPException(416, "Requested Range Not Satisfiable")
|
||||
if end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid Range header")
|
||||
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
else:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
async def file_iterator():
|
||||
nonlocal start, end
|
||||
token = await self._get_access_token()
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
req_headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Range': f'bytes={start}-{end}'
|
||||
}
|
||||
url = f"{GOOGLE_DRIVE_API_URL}/files/{file_id}?alt=media"
|
||||
async with client.stream("GET", url, headers=req_headers) as stream_resp:
|
||||
stream_resp.raise_for_status()
|
||||
async for chunk in stream_resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
"""
|
||||
获取文件或目录的元数据。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 格式化后的文件/目录信息。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "id, name, mimeType, size, modifiedTime, createdTime"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return self._format_item(resp.json())
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
"""
|
||||
获取直接下载响应 (307 重定向)。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 307 重定向响应或 None。
|
||||
"""
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
# 获取文件的下载链接
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "webContentLink"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
|
||||
item_data = resp.json()
|
||||
download_url = item_data.get("webContentLink")
|
||||
if not download_url:
|
||||
return None
|
||||
|
||||
return Response(status_code=307, headers={"Location": download_url})
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
"""
|
||||
获取文件的缩略图。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param size: 缩略图大小 (暂未使用,Google Drive 自动决定)。
|
||||
:return: 缩略图内容的字节流,或在不支持时返回 None。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "thumbnailLink"})
|
||||
if resp.status_code == 200:
|
||||
item_data = resp.json()
|
||||
thumbnail_link = item_data.get("thumbnailLink")
|
||||
if thumbnail_link:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
thumb_resp = await client.get(thumbnail_link)
|
||||
thumb_resp.raise_for_status()
|
||||
return thumb_resp.content
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
ADAPTER_TYPE = "googledrive"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
|
||||
{"key": "client_secret", "label": "Client Secret",
|
||||
"type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
|
||||
"required": True, "help_text": "可以通过 Google OAuth 2.0 Playground 获取"},
|
||||
{"key": "root_folder_id", "label": "根文件夹 ID (Root Folder ID)", "type": "string",
|
||||
"required": False, "placeholder": "默认为根目录 (root)", "default": "root"},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec): return GoogleDriveAdapter(rec)
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
@@ -10,7 +9,6 @@ import mimetypes
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from models import StorageAdapter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
def _safe_join(root: str, rel: str) -> Path:
|
||||
@@ -115,11 +113,32 @@ class LocalAdapter:
|
||||
await asyncio.to_thread(fp.write_bytes, data)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Wrote file to {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp), "size": len(data)},
|
||||
)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
fp = _safe_join(root, rel)
|
||||
pre_exists = fp.exists()
|
||||
await asyncio.to_thread(os.makedirs, fp.parent, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
|
||||
def _copy():
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with open(fp, "wb") as f:
|
||||
shutil.copyfileobj(file_obj, f)
|
||||
|
||||
await asyncio.to_thread(_copy)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
|
||||
size = file_size
|
||||
if size is None:
|
||||
try:
|
||||
size = fp.stat().st_size
|
||||
except Exception:
|
||||
size = 0
|
||||
return {"size": int(size or 0)}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
fp = _safe_join(root, rel)
|
||||
@@ -140,21 +159,11 @@ class LocalAdapter:
|
||||
await asyncio.to_thread(f.close)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Wrote file stream to {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp), "size": size},
|
||||
)
|
||||
return size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
await asyncio.to_thread(os.makedirs, fp, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Created directory {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp)},
|
||||
)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
@@ -164,11 +173,6 @@ class LocalAdapter:
|
||||
await asyncio.to_thread(shutil.rmtree, fp)
|
||||
else:
|
||||
await asyncio.to_thread(fp.unlink)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Deleted {rel}",
|
||||
details={"adapter_id": self.record.id, "path": str(fp)},
|
||||
)
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
"""新增: 返回路径状态调试信息"""
|
||||
@@ -203,15 +207,6 @@ class LocalAdapter:
|
||||
except OSError:
|
||||
shutil.move(str(src), str(dst))
|
||||
await asyncio.to_thread(_do_move)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Moved {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src": str(src),
|
||||
"dst": str(dst),
|
||||
},
|
||||
)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _safe_join(root, src_rel)
|
||||
@@ -227,15 +222,6 @@ class LocalAdapter:
|
||||
except OSError:
|
||||
os.replace(src, dst)
|
||||
await asyncio.to_thread(_do_rename)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Renamed {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src": str(src),
|
||||
"dst": str(dst),
|
||||
},
|
||||
)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _safe_join(root, src_rel)
|
||||
@@ -258,15 +244,6 @@ class LocalAdapter:
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
await asyncio.to_thread(_do)
|
||||
await LogService.info(
|
||||
"adapter:local",
|
||||
f"Copied {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src": str(src),
|
||||
"dst": str(dst),
|
||||
},
|
||||
)
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
fp = _safe_join(root, rel)
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import httpx
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from fastapi import HTTPException
|
||||
from models import StorageAdapter
|
||||
|
||||
@@ -20,6 +19,7 @@ class OneDriveAdapter:
|
||||
self.client_secret = cfg.get("client_secret")
|
||||
self.refresh_token = cfg.get("refresh_token")
|
||||
self.root = cfg.get("root", "/").strip("/")
|
||||
self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))
|
||||
|
||||
if not all([self.client_id, self.client_secret, self.refresh_token]):
|
||||
raise ValueError(
|
||||
@@ -380,6 +380,26 @@ class OneDriveAdapter:
|
||||
|
||||
return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise IsADirectoryError("不能对目录进行直链重定向")
|
||||
|
||||
resp = await self._request("GET", api_path_segment=api_path)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
|
||||
item_data = resp.json()
|
||||
download_url = item_data.get("@microsoft.graph.downloadUrl")
|
||||
if not download_url:
|
||||
return None
|
||||
|
||||
return Response(status_code=307, headers={"Location": download_url})
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
"""
|
||||
获取文件的缩略图。
|
||||
@@ -424,16 +444,17 @@ class OneDriveAdapter:
|
||||
return self._format_item(resp.json())
|
||||
|
||||
|
||||
ADAPTER_TYPE = "OneDrive"
|
||||
ADAPTER_TYPE = "onedrive"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
|
||||
{"key": "client_secret", "label": "Client Secret",
|
||||
"type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
|
||||
"required": True, "help_text": "可以通过运行 'python -m services.adapters.onedrive' 获取"},
|
||||
"required": True, "help_text": "可以通过运行 'python -m domain.adapters.providers.onedrive' 获取"},
|
||||
{"key": "root", "label": "根目录 (Root Path)", "type": "string",
|
||||
"required": False, "placeholder": "默认为根目录 /"},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
]
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
@@ -34,8 +33,15 @@ class QuarkAdapter:
|
||||
cfg = record.config or {}
|
||||
self.cookie: str = cfg.get("cookie") or cfg.get("Cookie")
|
||||
self.root_fid: str = cfg.get("root_fid", "0")
|
||||
self.use_transcoding_address: bool = bool(cfg.get("use_transcoding_address", False))
|
||||
self.only_list_video_file: bool = bool(cfg.get("only_list_video_file", False))
|
||||
def _as_bool(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
|
||||
self.use_transcoding_address: bool = _as_bool(cfg.get("use_transcoding_address", False))
|
||||
self.only_list_video_file: bool = _as_bool(cfg.get("only_list_video_file", False))
|
||||
|
||||
if not self.cookie:
|
||||
raise ValueError("Quark 适配器需要 cookie 配置")
|
||||
@@ -284,6 +290,11 @@ class QuarkAdapter:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def get_video_transcoding_url(self, fid: str) -> Optional[str]:
|
||||
if not self.use_transcoding_address:
|
||||
return None
|
||||
return await self._get_transcoding_url(fid)
|
||||
|
||||
def _is_video_name(self, name: str) -> bool:
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
return bool(mime and mime.startswith("video/"))
|
||||
@@ -310,6 +321,29 @@ class QuarkAdapter:
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
url = await self._get_download_url(it["fid"])
|
||||
headers = dict(self._download_headers())
|
||||
headers["Range"] = f"bytes={start}-" if end is None else f"bytes={start}-{end}"
|
||||
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
if resp.status_code == 416:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
@@ -419,6 +453,159 @@ class QuarkAdapter:
|
||||
yield data
|
||||
return await self.write_file_stream(root, rel, gen())
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = filename or rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
|
||||
md5 = hashlib.md5()
|
||||
sha1 = hashlib.sha1()
|
||||
total = 0
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
while True:
|
||||
chunk = file_obj.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
md5.update(chunk)
|
||||
sha1.update(chunk)
|
||||
|
||||
md5_hex = md5.hexdigest()
|
||||
sha1_hex = sha1.hexdigest()
|
||||
|
||||
# 预上传,拿到上传信息
|
||||
pre_resp = await self._upload_pre(name, total, parent_fid)
|
||||
pre_data = pre_resp.get("data", {})
|
||||
|
||||
# hash 秒传
|
||||
hash_body = {"md5": md5_hex, "sha1": sha1_hex, "task_id": pre_data.get("task_id")}
|
||||
hash_resp = await self._request("POST", "/file/update/hash", json=hash_body)
|
||||
if (hash_resp.get("data") or {}).get("finish") is True:
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return {"size": total}
|
||||
|
||||
# 分片上传
|
||||
part_size = int((pre_resp.get("metadata") or {}).get("part_size") or 0)
|
||||
if part_size <= 0:
|
||||
raise HTTPException(502, detail="Invalid part_size from Quark")
|
||||
|
||||
bucket = pre_data.get("bucket")
|
||||
obj_key = pre_data.get("obj_key")
|
||||
upload_id = pre_data.get("upload_id")
|
||||
upload_url = pre_data.get("upload_url")
|
||||
if not (bucket and obj_key and upload_id and upload_url):
|
||||
raise HTTPException(502, detail="Upload pre missing fields")
|
||||
|
||||
try:
|
||||
upload_host = upload_url.split("://", 1)[1]
|
||||
except Exception:
|
||||
upload_host = upload_url
|
||||
base_url = f"https://{bucket}.{upload_host}/{obj_key}"
|
||||
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
etags: List[str] = []
|
||||
oss_ua = "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
part_number = 1
|
||||
left = total
|
||||
while left > 0:
|
||||
sz = min(part_size, left)
|
||||
data_bytes = file_obj.read(sz)
|
||||
if len(data_bytes) != sz:
|
||||
raise IOError("Failed to read part bytes")
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta = (
|
||||
"PUT\n\n"
|
||||
f"{self._guess_mime(name)}\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?partNumber={part_number}&uploadId={upload_id}"
|
||||
)
|
||||
auth_req_body = {"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta, "task_id": pre_data.get("task_id")}
|
||||
auth_resp = await self._request("POST", "/file/upload/auth", json=auth_req_body)
|
||||
auth_key = (auth_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key:
|
||||
raise HTTPException(502, detail="upload/auth missing auth_key")
|
||||
|
||||
put_headers = {
|
||||
"Authorization": auth_key,
|
||||
"Content-Type": self._guess_mime(name),
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
put_url = f"{base_url}?partNumber={part_number}&uploadId={upload_id}"
|
||||
put_resp = await client.put(put_url, headers=put_headers, content=data_bytes)
|
||||
if put_resp.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload part failed status={put_resp.status_code} text={put_resp.text}")
|
||||
etag = put_resp.headers.get("Etag", "")
|
||||
etags.append(etag)
|
||||
left -= sz
|
||||
part_number += 1
|
||||
|
||||
parts_xml = [f"<Part>\n<PartNumber>{i+1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n" for i in range(len(etags))]
|
||||
body_xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n" + "".join(parts_xml) + "</CompleteMultipartUpload>"
|
||||
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
|
||||
callback = pre_data.get("callback") or {}
|
||||
try:
|
||||
import json as _json
|
||||
callback_b64 = base64.b64encode(_json.dumps(callback).encode("utf-8")).decode("ascii")
|
||||
except Exception:
|
||||
callback_b64 = ""
|
||||
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta_commit = (
|
||||
"POST\n"
|
||||
f"{content_md5}\n"
|
||||
"application/xml\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-callback:{callback_b64}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?uploadId={upload_id}"
|
||||
)
|
||||
auth_commit_resp = await self._request("POST", "/file/upload/auth", json={"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta_commit, "task_id": pre_data.get("task_id")})
|
||||
auth_key_commit = (auth_commit_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key_commit:
|
||||
raise HTTPException(502, detail="upload/auth(commit) missing auth_key")
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
commit_headers = {
|
||||
"Authorization": auth_key_commit,
|
||||
"Content-MD5": content_md5,
|
||||
"Content-Type": "application/xml",
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-callback": callback_b64,
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
commit_url = f"{base_url}?uploadId={upload_id}"
|
||||
r = await client.post(commit_url, headers=commit_headers, content=body_xml.encode("utf-8"))
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload commit failed status={r.status_code} text={r.text}")
|
||||
|
||||
await self._request("POST", "/file/upload/finish", json={"obj_key": obj_key, "task_id": pre_data.get("task_id")})
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception:
|
||||
pass
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return {"size": total}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
@@ -711,13 +898,13 @@ class QuarkAdapter:
|
||||
return it["fid"]
|
||||
|
||||
|
||||
ADAPTER_TYPE = "Quark"
|
||||
ADAPTER_TYPE = "quark"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "cookie", "label": "Cookie", "type": "password", "required": True, "placeholder": "从 pan.quark.cn 复制"},
|
||||
{"key": "root_fid", "label": "根 FID", "type": "string", "required": False, "default": "0"},
|
||||
{"key": "use_transcoding_address", "label": "视频转码直链", "type": "checkbox", "required": False, "default": False},
|
||||
{"key": "only_list_video_file", "label": "仅列出视频文件", "type": "checkbox", "required": False, "default": False},
|
||||
{"key": "use_transcoding_address", "label": "视频转码直链", "type": "boolean", "required": False, "default": False},
|
||||
{"key": "only_list_video_file", "label": "仅列出视频文件", "type": "boolean", "required": False, "default": False},
|
||||
]
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter) -> BaseAdapter:
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from datetime import datetime
|
||||
@@ -10,7 +9,6 @@ from botocore.exceptions import ClientError
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from models import StorageAdapter
|
||||
from services.logging import LogService
|
||||
|
||||
|
||||
class S3Adapter:
|
||||
@@ -127,11 +125,6 @@ class S3Adapter:
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Wrote file to {rel}",
|
||||
details={"adapter_id": self.record.id,
|
||||
"bucket": self.bucket_name, "key": key, "size": len(data)}
|
||||
)
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
key = self._get_s3_key(rel)
|
||||
@@ -193,10 +186,6 @@ class S3Adapter:
|
||||
)
|
||||
raise IOError(f"S3 stream upload failed: {e}") from e
|
||||
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Wrote file stream to {rel}",
|
||||
details={"adapter_id": self.record.id, "bucket": self.bucket_name, "key": key, "size": total_size}
|
||||
)
|
||||
return total_size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
@@ -205,11 +194,6 @@ class S3Adapter:
|
||||
key += "/"
|
||||
async with self._get_client() as s3:
|
||||
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"")
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Created directory {rel}",
|
||||
details={"adapter_id": self.record.id,
|
||||
"bucket": self.bucket_name, "key": key}
|
||||
)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
@@ -237,20 +221,9 @@ class S3Adapter:
|
||||
else:
|
||||
await s3.delete_object(Bucket=self.bucket_name, Key=key)
|
||||
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Deleted {rel}",
|
||||
details={"adapter_id": self.record.id,
|
||||
"bucket": self.bucket_name, "key": key}
|
||||
)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.copy(root, src_rel, dst_rel, overwrite=True)
|
||||
await self.delete(root, src_rel)
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Moved {src_rel} to {dst_rel}",
|
||||
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
|
||||
"src_key": self._get_s3_key(src_rel), "dst_key": self._get_s3_key(dst_rel)}
|
||||
)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
@@ -270,11 +243,6 @@ class S3Adapter:
|
||||
|
||||
copy_source = {"Bucket": self.bucket_name, "Key": src_key}
|
||||
await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key)
|
||||
await LogService.info(
|
||||
"adapter:s3", f"Copied {src_rel} to {dst_rel}",
|
||||
details={"adapter_id": self.record.id, "bucket": self.bucket_name,
|
||||
"src_key": src_key, "dst_key": dst_key}
|
||||
)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
@@ -353,13 +321,12 @@ class S3Adapter:
|
||||
while chunk := await body.read(65536):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
LogService.error(
|
||||
"adapter:s3", f"Error streaming file {key}: {e}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "S3"
|
||||
ADAPTER_TYPE = "s3"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "bucket_name", "label": "Bucket 名称",
|
||||
473
domain/adapters/providers/sftp.py
Normal file
473
domain/adapters/providers/sftp.py
Normal file
@@ -0,0 +1,473 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import stat as statmod
|
||||
from typing import List, Dict, Tuple, AsyncIterator, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import paramiko
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _join_remote(root: str, rel: str) -> str:
|
||||
root = (root or "/").rstrip("/") or "/"
|
||||
rel = (rel or "").lstrip("/")
|
||||
if not rel:
|
||||
return root
|
||||
return f"{root}/{rel}"
|
||||
|
||||
|
||||
class SFTPAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.host: str = cfg.get("host")
|
||||
self.port: int = int(cfg.get("port", 22))
|
||||
self.username: str | None = cfg.get("username")
|
||||
self.password: str | None = cfg.get("password")
|
||||
self.timeout: int = int(cfg.get("timeout", 15))
|
||||
self.root_path: str = cfg.get("root") # 必填
|
||||
self.allow_unknown_host: bool = bool(cfg.get("allow_unknown_host", True))
|
||||
|
||||
if not self.host:
|
||||
raise ValueError("SFTP adapter requires 'host'")
|
||||
if not self.username or not self.password:
|
||||
raise ValueError("SFTP adapter requires 'username' and 'password'")
|
||||
if not self.root_path:
|
||||
raise ValueError("SFTP adapter requires 'root'")
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = self.root_path.rstrip("/") or "/"
|
||||
if sub_path:
|
||||
return _join_remote(base, sub_path)
|
||||
return base
|
||||
|
||||
def _connect(self) -> paramiko.SFTPClient:
|
||||
ssh = paramiko.SSHClient()
|
||||
if self.allow_unknown_host:
|
||||
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
ssh.connect(
|
||||
hostname=self.host,
|
||||
port=self.port,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
timeout=self.timeout,
|
||||
allow_agent=False,
|
||||
look_for_keys=False,
|
||||
)
|
||||
return ssh.open_sftp()
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_list() -> List[Dict]:
|
||||
sftp = self._connect()
|
||||
try:
|
||||
attrs = sftp.listdir_attr(path)
|
||||
entries: List[Dict] = []
|
||||
for a in attrs:
|
||||
name = a.filename
|
||||
is_dir = statmod.S_ISDIR(a.st_mode)
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(a.st_size or 0),
|
||||
"mtime": int(a.st_mtime or 0),
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
return entries
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
entries = await asyncio.to_thread(_do_list)
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif f == "size":
|
||||
key += (item.get("size", 0),)
|
||||
elif f == "mtime":
|
||||
key += (item.get("mtime", 0),)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(entries)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return entries[start:end], total
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_read() -> bytes:
|
||||
sftp = self._connect()
|
||||
try:
|
||||
with sftp.open(path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except IOError as e:
|
||||
if getattr(e, "errno", None) == 2:
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_read)
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
# likely exists
|
||||
pass
|
||||
|
||||
def _do_write():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(sftp, parent)
|
||||
with sftp.open(path, "wb") as f:
|
||||
f.write(data)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def _do_upload():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(sftp, parent)
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with sftp.open(path, "wb") as f:
|
||||
import shutil
|
||||
shutil.copyfileobj(file_obj, f)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_upload)
|
||||
return {"size": file_size or 0}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
buf = bytearray()
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
buf.extend(chunk)
|
||||
await self.write_file(root, rel, bytes(buf))
|
||||
return len(buf)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_mkdir():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_mkdir)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_delete():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
# Try file remove first
|
||||
try:
|
||||
sftp.remove(path)
|
||||
return
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def _rm_tree(dp: str):
|
||||
try:
|
||||
for a in sftp.listdir_attr(dp):
|
||||
child = _join_remote(dp, a.filename)
|
||||
if statmod.S_ISDIR(a.st_mode):
|
||||
_rm_tree(child)
|
||||
else:
|
||||
try:
|
||||
sftp.remove(child)
|
||||
except Exception:
|
||||
pass
|
||||
sftp.rmdir(dp)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
_rm_tree(path)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_delete)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
def _do_move():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
# ensure dst parent exists
|
||||
parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
|
||||
parts = [p for p in parent.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
sftp.rename(src, dst)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_move)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
def _is_dir() -> bool:
|
||||
sftp = self._connect()
|
||||
try:
|
||||
st = sftp.stat(src)
|
||||
return statmod.S_ISDIR(st.st_mode)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if await asyncio.to_thread(_is_dir):
|
||||
await self.mkdir(root, dst_rel)
|
||||
|
||||
children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
|
||||
for ent in children:
|
||||
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
|
||||
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
|
||||
await self.copy(root, child_src, child_dst, overwrite)
|
||||
return
|
||||
|
||||
# file copy
|
||||
data = await self.read_file(root, src_rel)
|
||||
if not overwrite:
|
||||
try:
|
||||
await self.stat_file(root, dst_rel)
|
||||
raise FileExistsError(dst_rel)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
await self.write_file(root, dst_rel, data)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_stat():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
st = sftp.stat(path)
|
||||
is_dir = statmod.S_ISDIR(st.st_mode)
|
||||
info = {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(st.st_size or 0),
|
||||
"mtime": int(st.st_mtime or 0),
|
||||
"type": "dir" if is_dir else "file",
|
||||
"path": path,
|
||||
}
|
||||
return info
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except IOError as e:
|
||||
if getattr(e, "errno", None) == 2:
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_stat)
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
await self.stat_file(root, rel)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _get_stat():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
st = sftp.stat(path)
|
||||
return int(st.st_size or 0)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
file_size = await asyncio.to_thread(_get_stat)
|
||||
if file_size is None:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
s, e = (range_header.removeprefix("bytes=").split("-", 1))
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if start >= file_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
|
||||
queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)
|
||||
|
||||
def _worker():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
with sftp.open(path, "rb") as f:
|
||||
f.seek(start)
|
||||
remaining = end - start + 1
|
||||
chunk_size = 64 * 1024
|
||||
while remaining > 0:
|
||||
to_read = chunk_size if remaining > chunk_size else remaining
|
||||
data = f.read(to_read)
|
||||
if not data:
|
||||
break
|
||||
try:
|
||||
queue.put_nowait(data)
|
||||
except Exception:
|
||||
break
|
||||
remaining -= len(data)
|
||||
try:
|
||||
queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def agen():
|
||||
worker_fut = asyncio.to_thread(_worker)
|
||||
try:
|
||||
while True:
|
||||
chunk = await queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
try:
|
||||
await worker_fut
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "sftp"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "host", "label": "主机", "type": "string", "required": True, "placeholder": "sftp.example.com"},
|
||||
{"key": "port", "label": "端口", "type": "number", "required": False, "default": 22},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": True},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": True},
|
||||
{"key": "root", "label": "根路径", "type": "string", "required": True, "placeholder": "/data"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 15},
|
||||
{"key": "allow_unknown_host", "label": "允许未知主机指纹", "type": "boolean", "required": False, "default": True},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter):
|
||||
return SFTPAdapter(rec)
|
||||
@@ -1,14 +1,52 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import struct
|
||||
from models import StorageAdapter
|
||||
from telethon import TelegramClient
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon.sessions import StringSession
|
||||
from telethon.tl import types
|
||||
import socks
|
||||
|
||||
_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _get_session_lock(session_string: str) -> asyncio.Lock:
|
||||
lock = _SESSION_LOCKS.get(session_string)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_SESSION_LOCKS[session_string] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class _NamedFile:
|
||||
def __init__(self, file_obj, name: str):
|
||||
self._file = file_obj
|
||||
self.name = name
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
return self._file.read(*args, **kwargs)
|
||||
|
||||
def seek(self, *args, **kwargs):
|
||||
return self._file.seek(*args, **kwargs)
|
||||
|
||||
def tell(self):
|
||||
return self._file.tell()
|
||||
|
||||
def seekable(self):
|
||||
return self._file.seekable()
|
||||
|
||||
def close(self):
|
||||
return self._file.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._file, name)
|
||||
|
||||
# 适配器类型标识
|
||||
ADAPTER_TYPE = "Telegram"
|
||||
ADAPTER_TYPE = "telegram"
|
||||
|
||||
# 适配器配置项定义
|
||||
CONFIG_SCHEMA = [
|
||||
@@ -55,9 +93,93 @@ class TelegramAdapter:
|
||||
if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
|
||||
raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")
|
||||
|
||||
@staticmethod
|
||||
def _parse_legacy_session_string(value: str) -> StringSession:
|
||||
"""
|
||||
兼容旧版 session_string 格式:
|
||||
- version(1B char) + base64(data)
|
||||
- data: dc_id(1B) + ip_len(2B) + ip(ASCII, ip_len bytes) + port(2B) + auth_key(256B)
|
||||
"""
|
||||
s = (value or "").strip()
|
||||
if not s:
|
||||
raise ValueError("session_string 为空")
|
||||
|
||||
body = s[1:] if s.startswith("1") else s
|
||||
raw = base64.urlsafe_b64decode(body)
|
||||
if len(raw) < 1 + 2 + 2 + 256:
|
||||
raise ValueError("legacy session 数据长度不足")
|
||||
|
||||
dc_id = raw[0]
|
||||
ip_len = struct.unpack(">H", raw[1:3])[0]
|
||||
expected_len = 1 + 2 + ip_len + 2 + 256
|
||||
if len(raw) != expected_len:
|
||||
raise ValueError("legacy session 数据长度不匹配")
|
||||
|
||||
ip_start = 3
|
||||
ip_end = ip_start + ip_len
|
||||
ip = raw[ip_start:ip_end].decode("utf-8")
|
||||
port = struct.unpack(">H", raw[ip_end : ip_end + 2])[0]
|
||||
key = raw[ip_end + 2 : ip_end + 2 + 256]
|
||||
|
||||
sess = StringSession()
|
||||
sess.set_dc(dc_id, ip, port)
|
||||
sess.auth_key = AuthKey(key)
|
||||
return sess
|
||||
|
||||
@staticmethod
|
||||
def _pick_photo_thumb(thumbs: list | None):
|
||||
if not thumbs:
|
||||
return None
|
||||
|
||||
cached = []
|
||||
others = []
|
||||
for t in thumbs:
|
||||
if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize)):
|
||||
cached.append(t)
|
||||
elif isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive)):
|
||||
if not isinstance(t, types.PhotoSizeEmpty):
|
||||
others.append(t)
|
||||
|
||||
if cached:
|
||||
cached.sort(key=lambda x: len(getattr(x, "bytes", b"") or b""))
|
||||
return cached[-1]
|
||||
|
||||
if others:
|
||||
def _sz(x):
|
||||
if isinstance(x, types.PhotoSizeProgressive):
|
||||
return max(x.sizes or [0])
|
||||
return int(getattr(x, "size", 0) or 0)
|
||||
|
||||
others.sort(key=_sz)
|
||||
return others[-1]
|
||||
|
||||
return None
|
||||
|
||||
def _build_session(self) -> StringSession:
|
||||
s = (self.session_string or "").strip()
|
||||
if not s:
|
||||
raise ValueError("Telegram 适配器 session_string 为空")
|
||||
|
||||
try:
|
||||
return StringSession(s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 少数工具可能去掉了 version 前缀,这里做一次兼容
|
||||
if not s.startswith("1"):
|
||||
try:
|
||||
return StringSession("1" + s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return self._parse_legacy_session_string(s)
|
||||
except Exception as exc:
|
||||
raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc
|
||||
|
||||
def _get_client(self) -> TelegramClient:
|
||||
"""创建一个新的 TelegramClient 实例"""
|
||||
return TelegramClient(StringSession(self.session_string), self.api_id, self.api_hash, proxy=self.proxy)
|
||||
return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
return ""
|
||||
@@ -74,33 +196,32 @@ class TelegramAdapter:
|
||||
for message in messages:
|
||||
if not message:
|
||||
continue
|
||||
|
||||
|
||||
media = message.document or message.video or message.photo
|
||||
if not media:
|
||||
continue
|
||||
|
||||
filename = None
|
||||
size = 0
|
||||
|
||||
if message.photo:
|
||||
photo_size = message.photo.sizes[-1]
|
||||
size = photo_size.size if hasattr(photo_size, 'size') else 0
|
||||
filename = f"photo_{message.id}.jpg"
|
||||
file_meta = message.file
|
||||
if not file_meta:
|
||||
continue
|
||||
|
||||
elif message.document or message.video:
|
||||
size = media.size
|
||||
if hasattr(media, 'attributes'):
|
||||
for attr in media.attributes:
|
||||
if hasattr(attr, 'file_name') and attr.file_name:
|
||||
filename = attr.file_name
|
||||
break
|
||||
|
||||
filename = file_meta.name
|
||||
if not filename:
|
||||
if message.text and '.' in message.text and len(message.text) < 256 and '\n' not in message.text:
|
||||
filename = message.text
|
||||
|
||||
if not filename:
|
||||
filename = f"unknown_{message.id}"
|
||||
else:
|
||||
filename = f"unknown_{message.id}"
|
||||
|
||||
size = file_meta.size
|
||||
if size is None:
|
||||
# 兼容缺失 size 的情况
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
size = 0
|
||||
|
||||
entries.append({
|
||||
"name": f"{message.id}_{filename}",
|
||||
@@ -166,7 +287,48 @@ class TelegramAdapter:
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||||
sent = await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
if message:
|
||||
stored_name = file_like.name
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
return {"rel": actual_rel, "size": len(data)}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
client = self._get_client()
|
||||
name = filename or os.path.basename(rel) or "file"
|
||||
file_like = _NamedFile(file_obj, name)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
sent = await client.send_file(
|
||||
self.chat_id,
|
||||
file_like,
|
||||
caption=file_like.name,
|
||||
file_size=file_size,
|
||||
mime_type=content_type,
|
||||
)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
size = file_size or 0
|
||||
if message:
|
||||
stored_name = file_like.name
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
if file_meta and getattr(file_meta, "size", None):
|
||||
size = int(file_meta.size)
|
||||
return {"rel": actual_rel, "size": size}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
@@ -176,8 +338,9 @@ class TelegramAdapter:
|
||||
client = self._get_client()
|
||||
filename = os.path.basename(rel) or "file"
|
||||
import tempfile
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_path = os.path.join(temp_dir, filename)
|
||||
suffix = os.path.splitext(filename)[1]
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
temp_path = tf.name
|
||||
|
||||
total_size = 0
|
||||
try:
|
||||
@@ -188,18 +351,62 @@ class TelegramAdapter:
|
||||
total_size += len(chunk)
|
||||
|
||||
await client.connect()
|
||||
await client.send_file(self.chat_id, temp_path, caption=filename)
|
||||
sent = await client.send_file(self.chat_id, temp_path, caption=filename)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
if message:
|
||||
stored_name = filename
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
|
||||
finally:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
return total_size
|
||||
return {"rel": actual_rel, "size": total_size}
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
raise NotImplementedError("Telegram 适配器不支持创建目录。")
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
doc = message.document or message.video
|
||||
thumbs = None
|
||||
if doc and getattr(doc, "thumbs", None):
|
||||
thumbs = list(doc.thumbs or [])
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
thumbs = list(message.photo.sizes or [])
|
||||
|
||||
thumb = self._pick_photo_thumb(thumbs)
|
||||
if not thumb:
|
||||
return None
|
||||
|
||||
result = await client.download_media(message, bytes, thumb=thumb)
|
||||
if isinstance(result, (bytes, bytearray)):
|
||||
return bytes(result)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
"""删除一个文件 (即一条消息)"""
|
||||
try:
|
||||
@@ -238,6 +445,8 @@ class TelegramAdapter:
|
||||
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
lock = _get_session_lock(self.session_string)
|
||||
await lock.acquire()
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
@@ -246,13 +455,27 @@ class TelegramAdapter:
|
||||
if not message or not media:
|
||||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||||
|
||||
if message.photo:
|
||||
photo_size = media.sizes[-1]
|
||||
file_size = photo_size.size if hasattr(photo_size, 'size') else 0
|
||||
mime_type = "image/jpeg"
|
||||
else:
|
||||
file_size = media.size
|
||||
mime_type = media.mime_type or "application/octet-stream"
|
||||
file_meta = message.file
|
||||
file_size = file_meta.size if file_meta and file_meta.size is not None else None
|
||||
if file_size is None:
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
file_size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
file_size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
file_size = 0
|
||||
|
||||
mime_type = None
|
||||
if file_meta and getattr(file_meta, "mime_type", None):
|
||||
mime_type = file_meta.mime_type
|
||||
if not mime_type:
|
||||
if hasattr(media, "mime_type") and media.mime_type:
|
||||
mime_type = media.mime_type
|
||||
elif message.photo:
|
||||
mime_type = "image/jpeg"
|
||||
else:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
@@ -261,7 +484,6 @@ class TelegramAdapter:
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": mime_type,
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
if range_header:
|
||||
@@ -273,7 +495,6 @@ class TelegramAdapter:
|
||||
if start >= file_size or end >= file_size or start > end:
|
||||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||||
status = 206
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid Range header")
|
||||
@@ -292,18 +513,28 @@ class TelegramAdapter:
|
||||
if downloaded >= limit:
|
||||
break
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
try:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers)
|
||||
|
||||
except HTTPException:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
@@ -321,11 +552,16 @@ class TelegramAdapter:
|
||||
if not message or not media:
|
||||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||||
|
||||
if message.photo:
|
||||
photo_size = media.sizes[-1]
|
||||
size = photo_size.size if hasattr(photo_size, 'size') else 0
|
||||
else:
|
||||
size = media.size
|
||||
file_meta = message.file
|
||||
size = file_meta.size if file_meta and file_meta.size is not None else None
|
||||
if size is None:
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
size = 0
|
||||
|
||||
return {
|
||||
"name": rel,
|
||||
@@ -339,4 +575,4 @@ class TelegramAdapter:
|
||||
await client.disconnect()
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter) -> TelegramAdapter:
|
||||
return TelegramAdapter(rec)
|
||||
return TelegramAdapter(rec)
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Optional, Tuple, AsyncIterator
|
||||
import httpx
|
||||
from urllib.parse import urljoin, quote
|
||||
@@ -9,7 +8,6 @@ import mimetypes
|
||||
import logging
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from services.logging import LogService
|
||||
|
||||
NS = {"d": "DAV:"}
|
||||
|
||||
@@ -148,15 +146,6 @@ class WebDAVAdapter:
|
||||
async with self._client() as client:
|
||||
resp = await client.put(url, content=data)
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Wrote file to {rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"url": url,
|
||||
"size": len(data),
|
||||
},
|
||||
)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
url = self._build_url(rel.rstrip('/') + '/')
|
||||
@@ -164,11 +153,6 @@ class WebDAVAdapter:
|
||||
resp = await client.request("MKCOL", url)
|
||||
if resp.status_code not in (201, 405):
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Created directory {rel}",
|
||||
details={"adapter_id": self.record.id, "url": url},
|
||||
)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
url = self._build_url(rel)
|
||||
@@ -176,11 +160,6 @@ class WebDAVAdapter:
|
||||
resp = await client.delete(url)
|
||||
if resp.status_code not in (204, 200, 404):
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Deleted {rel}",
|
||||
details={"adapter_id": self.record.id, "url": url},
|
||||
)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_url = self._build_url(src_rel)
|
||||
@@ -188,15 +167,6 @@ class WebDAVAdapter:
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Moved {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src_url": src_url,
|
||||
"dst_url": dst_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_url = self._build_url(src_rel)
|
||||
@@ -204,15 +174,6 @@ class WebDAVAdapter:
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Renamed {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src_url": src_url,
|
||||
"dst_url": dst_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def get_file_size(self, root: str, rel: str) -> int:
|
||||
"""获取文件大小"""
|
||||
@@ -455,8 +416,16 @@ class WebDAVAdapter:
|
||||
info["type"] = "dir" if is_dir else "file"
|
||||
if size_el is not None and size_el.text and size_el.text.isdigit():
|
||||
info["size"] = int(size_el.text)
|
||||
elif info["size"] is None:
|
||||
info["size"] = 0
|
||||
if lm_el is not None and lm_el.text:
|
||||
info["mtime"] = lm_el.text
|
||||
from email.utils import parsedate_to_datetime
|
||||
try:
|
||||
info["mtime"] = int(parsedate_to_datetime(lm_el.text).timestamp())
|
||||
except Exception:
|
||||
info["mtime"] = 0
|
||||
elif info["mtime"] is None:
|
||||
info["mtime"] = 0
|
||||
# exif信息
|
||||
exif = None
|
||||
if not info["is_dir"]:
|
||||
@@ -510,15 +479,6 @@ class WebDAVAdapter:
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(src_rel)
|
||||
resp.raise_for_status()
|
||||
await LogService.info(
|
||||
"adapter:webdav",
|
||||
f"Copied {src_rel} to {dst_rel}",
|
||||
details={
|
||||
"adapter_id": self.record.id,
|
||||
"src_url": src_url,
|
||||
"dst_url": dst_url,
|
||||
},
|
||||
)
|
||||
|
||||
ADAPTER_TYPE = "webdav"
|
||||
CONFIG_SCHEMA = [
|
||||
@@ -1,20 +1,28 @@
|
||||
from typing import Dict, Callable
|
||||
import pkgutil
|
||||
import inspect
|
||||
import pkgutil
|
||||
from importlib import import_module
|
||||
from typing import Callable, Dict
|
||||
|
||||
from .base import BaseAdapter
|
||||
from models import StorageAdapter
|
||||
from .providers.base import BaseAdapter
|
||||
|
||||
AdapterFactory = Callable[[StorageAdapter], object]
|
||||
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
|
||||
|
||||
TYPE_MAP: Dict[str, AdapterFactory] = {}
|
||||
CONFIG_SCHEMAS: Dict[str, list] = {}
|
||||
|
||||
|
||||
def normalize_adapter_type(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = str(value).strip().lower()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def discover_adapters():
|
||||
"""扫描 services.adapters 包, 自动注册适配器类型、工厂与配置 schema。"""
|
||||
from .. import adapters as adapters_pkg
|
||||
"""扫描 domain.adapters.providers 包, 自动注册适配器类型、工厂与配置 schema。"""
|
||||
from . import providers as adapters_pkg
|
||||
|
||||
TYPE_MAP.clear()
|
||||
CONFIG_SCHEMAS.clear()
|
||||
for modinfo in pkgutil.iter_modules(adapters_pkg.__path__):
|
||||
@@ -25,7 +33,28 @@ def discover_adapters():
|
||||
module = import_module(full_name)
|
||||
except Exception:
|
||||
continue
|
||||
adapter_type = getattr(module, "ADAPTER_TYPE", None)
|
||||
|
||||
adapter_types = getattr(module, "ADAPTER_TYPES", None)
|
||||
if isinstance(adapter_types, dict):
|
||||
default_schema = getattr(module, "CONFIG_SCHEMA", None)
|
||||
schema_map = getattr(module, "CONFIG_SCHEMA_MAP", None)
|
||||
if not isinstance(schema_map, dict):
|
||||
schema_map = None
|
||||
|
||||
for adapter_type, factory in adapter_types.items():
|
||||
normalized_type = normalize_adapter_type(adapter_type)
|
||||
if not normalized_type:
|
||||
continue
|
||||
if not callable(factory):
|
||||
continue
|
||||
TYPE_MAP[normalized_type] = factory
|
||||
|
||||
schema = schema_map.get(normalized_type) if schema_map else default_schema
|
||||
if isinstance(schema, list):
|
||||
CONFIG_SCHEMAS[normalized_type] = schema
|
||||
continue
|
||||
|
||||
adapter_type = normalize_adapter_type(getattr(module, "ADAPTER_TYPE", None))
|
||||
schema = getattr(module, "CONFIG_SCHEMA", None)
|
||||
factory = getattr(module, "ADAPTER_FACTORY", None)
|
||||
|
||||
@@ -57,22 +86,31 @@ def get_config_schema(adapter_type: str):
|
||||
|
||||
class RuntimeRegistry:
|
||||
def __init__(self):
|
||||
self._instances: Dict[int, object] = {}
|
||||
self._instances: Dict[int, BaseAdapter] = {}
|
||||
|
||||
async def refresh(self):
|
||||
discover_adapters()
|
||||
self._instances.clear()
|
||||
adapters = await StorageAdapter.filter(enabled=True)
|
||||
for rec in adapters:
|
||||
factory = TYPE_MAP.get(rec.type)
|
||||
normalized_type = normalize_adapter_type(rec.type)
|
||||
if not normalized_type:
|
||||
continue
|
||||
if normalized_type != rec.type:
|
||||
rec.type = normalized_type
|
||||
try:
|
||||
await rec.save(update_fields=["type"])
|
||||
except Exception:
|
||||
continue
|
||||
factory = TYPE_MAP.get(normalized_type)
|
||||
if not factory:
|
||||
continue
|
||||
try:
|
||||
self._instances[rec.id] = factory(rec)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
|
||||
def get(self, adapter_id: int):
|
||||
def get(self, adapter_id: int) -> BaseAdapter | None:
|
||||
return self._instances.get(adapter_id)
|
||||
|
||||
def snapshot(self) -> Dict[int, BaseAdapter]:
|
||||
@@ -88,11 +126,22 @@ class RuntimeRegistry:
|
||||
if not rec.enabled:
|
||||
self.remove(rec.id)
|
||||
return
|
||||
|
||||
factory = TYPE_MAP.get(rec.type)
|
||||
|
||||
normalized_type = normalize_adapter_type(rec.type)
|
||||
if not normalized_type:
|
||||
self.remove(rec.id)
|
||||
return
|
||||
if normalized_type != rec.type:
|
||||
rec.type = normalized_type
|
||||
try:
|
||||
await rec.save(update_fields=["type"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
factory = TYPE_MAP.get(normalized_type)
|
||||
if not factory:
|
||||
discover_adapters()
|
||||
factory = TYPE_MAP.get(rec.type)
|
||||
factory = TYPE_MAP.get(normalized_type)
|
||||
if not factory:
|
||||
return
|
||||
|
||||
116
domain/adapters/service.py
Normal file
116
domain/adapters/service.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.auth import User
|
||||
from .registry import (
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from .types import AdapterCreate, AdapterOut
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
class AdapterService:
|
||||
@classmethod
|
||||
def _validate_and_normalize_config(cls, adapter_type: str, cfg):
|
||||
schemas = get_config_schemas()
|
||||
adapter_type = normalize_adapter_type(adapter_type)
|
||||
if not adapter_type:
|
||||
raise HTTPException(400, detail="不支持的适配器类型")
|
||||
if not isinstance(cfg, dict):
|
||||
raise HTTPException(400, detail="config 必须是对象")
|
||||
schema = schemas.get(adapter_type)
|
||||
if not schema:
|
||||
raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
|
||||
out = {}
|
||||
missing = []
|
||||
for f in schema:
|
||||
k = f["key"]
|
||||
if k in cfg and cfg[k] not in (None, ""):
|
||||
out[k] = cfg[k]
|
||||
elif "default" in f:
|
||||
out[k] = f["default"]
|
||||
elif f.get("required"):
|
||||
missing.append(k)
|
||||
if missing:
|
||||
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
|
||||
if adapter_type in ("alist", "openlist"):
|
||||
username = out.get("username")
|
||||
password = out.get("password")
|
||||
if (username and not password) or (password and not username):
|
||||
raise HTTPException(400, detail="用户名和密码必须同时填写或同时留空")
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
async def create_adapter(cls, data: AdapterCreate, current_user: Optional[User]):
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
exists = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if exists:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
adapter_fields = {
|
||||
"name": data.name,
|
||||
"type": data.type,
|
||||
"config": cls._validate_and_normalize_config(data.type, data.config or {}),
|
||||
"enabled": data.enabled,
|
||||
"path": norm_path,
|
||||
"sub_path": data.sub_path,
|
||||
}
|
||||
|
||||
rec = await StorageAdapter.create(**adapter_fields)
|
||||
await runtime_registry.upsert(rec)
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def list_adapters(cls):
|
||||
adapters = await StorageAdapter.all()
|
||||
return [AdapterOut.model_validate(a) for a in adapters]
|
||||
|
||||
@classmethod
|
||||
async def available_adapter_types(cls):
|
||||
data = []
|
||||
for adapter_type, fields in get_config_schemas().items():
|
||||
data.append({
|
||||
"type": adapter_type,
|
||||
"config_schema": fields,
|
||||
})
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def get_adapter(cls, adapter_id: int):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def update_adapter(cls, adapter_id: int, data: AdapterCreate, current_user: Optional[User]):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
existing = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if existing and existing.id != adapter_id:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
rec.name = data.name
|
||||
rec.type = data.type
|
||||
rec.config = cls._validate_and_normalize_config(data.type, data.config or {})
|
||||
rec.enabled = data.enabled
|
||||
rec.path = norm_path
|
||||
rec.sub_path = data.sub_path
|
||||
await rec.save()
|
||||
|
||||
await runtime_registry.upsert(rec)
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def delete_adapter(cls, adapter_id: int, current_user: Optional[User]):
|
||||
deleted = await StorageAdapter.filter(id=adapter_id).delete()
|
||||
if not deleted:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
runtime_registry.remove(adapter_id)
|
||||
return {"deleted": True}
|
||||
@@ -1,15 +1,29 @@
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AdapterBase(BaseModel):
|
||||
name: str
|
||||
type: str = Field(pattern=r"^[a-zA-Z0-9_]+$")
|
||||
type: str = Field(pattern=r"^[a-z0-9_]+$")
|
||||
config: Dict = Field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
path: str = None
|
||||
sub_path: Optional[str] = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_type(cls, v: str):
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("type required")
|
||||
normalized = v.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("type required")
|
||||
if not re.fullmatch(r"[a-z0-9_]+", normalized):
|
||||
raise ValueError("type must be lowercase alphanumeric or underscore")
|
||||
return normalized
|
||||
|
||||
|
||||
class AdapterCreate(AdapterBase):
|
||||
@staticmethod
|
||||
9
domain/agent/__init__.py
Normal file
9
domain/agent/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .service import AgentService
|
||||
from .types import AgentChatContext, AgentChatRequest, PendingToolCall
|
||||
|
||||
__all__ = [
|
||||
"AgentService",
|
||||
"AgentChatContext",
|
||||
"AgentChatRequest",
|
||||
"PendingToolCall",
|
||||
]
|
||||
38
domain/agent/api.py
Normal file
38
domain/agent/api.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import AgentService
|
||||
from .types import AgentChatRequest
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = await AgentService.chat(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"])
|
||||
async def chat_stream(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
return StreamingResponse(
|
||||
AgentService.chat_stream(payload, current_user),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
472
domain/agent/service.py
Normal file
472
domain/agent/service.py
Normal file
@@ -0,0 +1,472 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
|
||||
from domain.auth import User
|
||||
from .tools import get_tool, openai_tools, tool_result_to_content
|
||||
from .types import AgentChatRequest, PendingToolCall
|
||||
|
||||
|
||||
def _normalize_path(p: Optional[str]) -> Optional[str]:
|
||||
if not p:
|
||||
return None
|
||||
s = str(p).strip()
|
||||
if not s:
|
||||
return None
|
||||
s = s.replace("\\", "/")
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _build_system_prompt(current_path: Optional[str]) -> str:
|
||||
lines = [
|
||||
"你是 Foxel 的 AI 助手。",
|
||||
"你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。",
|
||||
"",
|
||||
"可用工具:",
|
||||
"- time:获取服务器当前时间(精确到秒,英文星期),支持 year/month/day/hour/minute/second 偏移。",
|
||||
"- web_fetch:抓取网页(HTTP 请求),支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS,返回状态/标题/正文/链接等。",
|
||||
"- vfs_list_dir:浏览目录(列出 entries + pagination)。",
|
||||
"- vfs_stat:查看文件/目录信息。",
|
||||
"- vfs_read_text:读取文本文件内容(不支持二进制)。",
|
||||
"- vfs_search:搜索文件(vector/filename)。",
|
||||
"- vfs_write_text:写入文本文件内容(覆盖)。",
|
||||
"- vfs_mkdir:创建目录。",
|
||||
"- vfs_delete:删除文件或目录。",
|
||||
"- vfs_move:移动路径。",
|
||||
"- vfs_copy:复制路径。",
|
||||
"- vfs_rename:重命名路径。",
|
||||
"- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。",
|
||||
"- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。",
|
||||
"",
|
||||
"规则:",
|
||||
"1) 读操作(web_fetch/vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。",
|
||||
"2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。",
|
||||
"3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
|
||||
"4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。",
|
||||
"5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
|
||||
"6) 回答语言跟随用户;用户用英文则用英文,用户用中文则用中文。回答尽量简洁。",
|
||||
]
|
||||
if current_path:
|
||||
lines.append("")
|
||||
lines.append(f"当前文件管理目录:{current_path}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list):
|
||||
return message
|
||||
|
||||
changed = False
|
||||
for idx, call in enumerate(tool_calls):
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = call.get("id")
|
||||
if isinstance(call_id, str) and call_id.strip():
|
||||
continue
|
||||
call["id"] = f"call_{idx}"
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
message["tool_calls"] = tool_calls
|
||||
return message
|
||||
|
||||
|
||||
def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall:
|
||||
call_id = str(tool_call.get("id") or "")
|
||||
fn = tool_call.get("function") or {}
|
||||
name = str((fn.get("name") if isinstance(fn, dict) else None) or "")
|
||||
raw_args = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
arguments: Dict[str, Any] = {}
|
||||
if isinstance(raw_args, str) and raw_args.strip():
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
if isinstance(parsed, dict):
|
||||
arguments = parsed
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return PendingToolCall(
|
||||
id=call_id,
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
|
||||
|
||||
def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[idx]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list) and tool_calls:
|
||||
return idx, msg
|
||||
raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
|
||||
|
||||
|
||||
def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id.strip():
|
||||
ids.add(tool_call_id)
|
||||
return ids
|
||||
|
||||
|
||||
async def _choose_chat_ability() -> str:
|
||||
tools_model = await AIProviderService.get_default_model("tools")
|
||||
return "tools" if tools_model else "chat"
|
||||
|
||||
|
||||
def _sse(event: str, data: Any) -> bytes:
|
||||
payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
|
||||
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
def _format_exc(exc: BaseException) -> str:
|
||||
text = str(exc)
|
||||
return text if text else exc.__class__.__name__
|
||||
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
|
||||
history: List[Dict[str, Any]] = list(req.messages or [])
|
||||
current_path = _normalize_path(req.context.current_path if req.context else None)
|
||||
|
||||
system_prompt = _build_system_prompt(current_path)
|
||||
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
|
||||
|
||||
new_messages: List[Dict[str, Any]] = []
|
||||
pending: List[PendingToolCall] = []
|
||||
|
||||
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
|
||||
if approved_ids or rejected_ids:
|
||||
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
|
||||
last_call_msg = _ensure_tool_call_ids(last_call_msg)
|
||||
tool_calls = last_call_msg.get("tool_calls") or []
|
||||
call_map: Dict[str, Dict[str, Any]] = {
|
||||
str(c.get("id")): c
|
||||
for c in tool_calls
|
||||
if isinstance(c, dict) and isinstance(c.get("id"), str)
|
||||
}
|
||||
|
||||
existing_ids = _existing_tool_result_ids(internal_messages)
|
||||
for call_id in approved_ids | rejected_ids:
|
||||
if call_id in existing_ids:
|
||||
continue
|
||||
tool_call = call_map.get(call_id)
|
||||
if not tool_call:
|
||||
continue
|
||||
fn = tool_call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if call_id in rejected_ids:
|
||||
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
|
||||
tools_schema = openai_tools()
|
||||
ability = await _choose_chat_ability()
|
||||
max_loops = 4
|
||||
|
||||
for _ in range(max_loops):
|
||||
try:
|
||||
assistant = await chat_completion(
|
||||
internal_messages,
|
||||
ability=ability,
|
||||
tools=tools_schema,
|
||||
tool_choice="auto",
|
||||
timeout=60.0,
|
||||
)
|
||||
except MissingModelError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc
|
||||
|
||||
assistant = _ensure_tool_call_ids(assistant)
|
||||
internal_messages.append(assistant)
|
||||
new_messages.append(assistant)
|
||||
|
||||
tool_calls = assistant.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
break
|
||||
|
||||
pending = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = str(call.get("id") or "")
|
||||
fn = call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
if spec.requires_confirmation and not req.auto_execute:
|
||||
pending.append(_extract_pending(call, True))
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
|
||||
if pending:
|
||||
break
|
||||
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
|
||||
history: List[Dict[str, Any]] = list(req.messages or [])
|
||||
current_path = _normalize_path(req.context.current_path if req.context else None)
|
||||
|
||||
system_prompt = _build_system_prompt(current_path)
|
||||
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
|
||||
|
||||
new_messages: List[Dict[str, Any]] = []
|
||||
pending: List[PendingToolCall] = []
|
||||
|
||||
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
|
||||
try:
|
||||
if approved_ids or rejected_ids:
|
||||
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
|
||||
last_call_msg = _ensure_tool_call_ids(last_call_msg)
|
||||
tool_calls = last_call_msg.get("tool_calls") or []
|
||||
call_map: Dict[str, Dict[str, Any]] = {
|
||||
str(c.get("id")): c
|
||||
for c in tool_calls
|
||||
if isinstance(c, dict) and isinstance(c.get("id"), str)
|
||||
}
|
||||
|
||||
existing_ids = _existing_tool_result_ids(internal_messages)
|
||||
for call_id in approved_ids | rejected_ids:
|
||||
if call_id in existing_ids:
|
||||
continue
|
||||
tool_call = call_map.get(call_id)
|
||||
if not tool_call:
|
||||
continue
|
||||
fn = tool_call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if call_id in rejected_ids:
|
||||
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
|
||||
|
||||
tools_schema = openai_tools()
|
||||
ability = await _choose_chat_ability()
|
||||
max_loops = 4
|
||||
|
||||
for _ in range(max_loops):
|
||||
assistant_event_id = uuid.uuid4().hex
|
||||
yield _sse("assistant_start", {"id": assistant_event_id})
|
||||
|
||||
assistant_message: Dict[str, Any] | None = None
|
||||
try:
|
||||
async for event in chat_completion_stream(
|
||||
internal_messages,
|
||||
ability=ability,
|
||||
tools=tools_schema,
|
||||
tool_choice="auto",
|
||||
timeout=60.0,
|
||||
):
|
||||
if event.get("type") == "delta":
|
||||
delta = event.get("delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
|
||||
elif event.get("type") == "message":
|
||||
msg = event.get("message")
|
||||
if isinstance(msg, dict):
|
||||
assistant_message = msg
|
||||
except MissingModelError as exc:
|
||||
raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc
|
||||
|
||||
if not assistant_message:
|
||||
assistant_message = {"role": "assistant", "content": ""}
|
||||
|
||||
assistant_message = _ensure_tool_call_ids(assistant_message)
|
||||
internal_messages.append(assistant_message)
|
||||
new_messages.append(assistant_message)
|
||||
yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})
|
||||
|
||||
tool_calls = assistant_message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
break
|
||||
|
||||
pending = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = str(call.get("id") or "")
|
||||
fn = call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
if spec.requires_confirmation and not req.auto_execute:
|
||||
pending.append(_extract_pending(call, True))
|
||||
continue
|
||||
|
||||
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
|
||||
|
||||
if pending:
|
||||
yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]})
|
||||
break
|
||||
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail
|
||||
content = detail if isinstance(detail, str) else str(detail)
|
||||
if not content.strip():
|
||||
content = f"请求失败({exc.status_code})"
|
||||
new_messages.append({"role": "assistant", "content": content})
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
return
|
||||
37
domain/agent/tools/__init__.py
Normal file
37
domain/agent/tools/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import ToolSpec, tool_result_to_content
|
||||
from .processors import TOOLS as PROCESSOR_TOOLS
|
||||
from .time import TOOLS as TIME_TOOLS
|
||||
from .vfs import TOOLS as VFS_TOOLS
|
||||
from .web_fetch import TOOLS as WEB_FETCH_TOOLS
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {}
|
||||
for group in (TIME_TOOLS, WEB_FETCH_TOOLS, PROCESSOR_TOOLS, VFS_TOOLS):
|
||||
TOOLS.update(group)
|
||||
|
||||
|
||||
def get_tool(name: str) -> Optional[ToolSpec]:
|
||||
return TOOLS.get(name)
|
||||
|
||||
|
||||
def openai_tools() -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
for spec in TOOLS.values():
|
||||
out.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"parameters": spec.parameters,
|
||||
},
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolSpec",
|
||||
"get_tool",
|
||||
"openai_tools",
|
||||
"tool_result_to_content",
|
||||
]
|
||||
149
domain/agent/tools/base.py
Normal file
149
domain/agent/tools/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolSpec:
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
requires_confirmation: bool
|
||||
handler: Callable[[Dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
def _stringify_value(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(value)
|
||||
|
||||
|
||||
def _list_to_view_items(items: List[Any]) -> List[Any]:
|
||||
normalized: List[Any] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
normalized.append({str(k): _stringify_value(v) for k, v in item.items()})
|
||||
else:
|
||||
normalized.append(_stringify_value(item))
|
||||
return normalized
|
||||
|
||||
|
||||
def _dict_to_kv_items(data: Dict[str, Any]) -> List[Dict[str, str]]:
|
||||
return [{"key": str(k), "value": _stringify_value(v)} for k, v in data.items()]
|
||||
|
||||
|
||||
def _first_list_field(data: Dict[str, Any]) -> tuple[Optional[str], Optional[List[Any]]]:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, list):
|
||||
return str(key), value
|
||||
return None, None
|
||||
|
||||
|
||||
def _build_view(data: Any) -> Dict[str, Any]:
|
||||
if data is None:
|
||||
return {"type": "kv", "items": []}
|
||||
if isinstance(data, str):
|
||||
return {"type": "text", "text": data}
|
||||
if isinstance(data, list):
|
||||
return {"type": "list", "items": _list_to_view_items(data)}
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content")
|
||||
if isinstance(content, str):
|
||||
meta = {k: _stringify_value(v) for k, v in data.items() if k != "content"}
|
||||
view: Dict[str, Any] = {"type": "text", "text": content}
|
||||
if meta:
|
||||
view["meta"] = meta
|
||||
return view
|
||||
list_key, list_val = _first_list_field(data)
|
||||
if list_key and isinstance(list_val, list):
|
||||
meta = {k: _stringify_value(v) for k, v in data.items() if k != list_key}
|
||||
view = {"type": "list", "title": list_key, "items": _list_to_view_items(list_val)}
|
||||
if meta:
|
||||
view["meta"] = meta
|
||||
return view
|
||||
return {"type": "kv", "items": _dict_to_kv_items(data)}
|
||||
return {"type": "text", "text": _stringify_value(data)}
|
||||
|
||||
|
||||
def _build_summary(view: Dict[str, Any]) -> str:
|
||||
view_type = str(view.get("type") or "")
|
||||
if view_type == "text":
|
||||
text = view.get("text")
|
||||
size = len(text) if isinstance(text, str) else 0
|
||||
return f"chars: {size}" if size else "text"
|
||||
if view_type == "list":
|
||||
items = view.get("items")
|
||||
count = len(items) if isinstance(items, list) else 0
|
||||
title = str(view.get("title") or "items")
|
||||
return f"{title}: {count}"
|
||||
if view_type == "kv":
|
||||
items = view.get("items")
|
||||
count = len(items) if isinstance(items, list) else 0
|
||||
return f"fields: {count}"
|
||||
if view_type == "error":
|
||||
return str(view.get("message") or "error")
|
||||
return ""
|
||||
|
||||
|
||||
def _build_error_payload(code: str, message: str, detail: Any = None) -> Dict[str, Any]:
|
||||
summary = "Canceled" if code == "canceled" else message or "error"
|
||||
view = {"type": "error", "message": summary}
|
||||
payload: Dict[str, Any] = {
|
||||
"ok": False,
|
||||
"summary": summary,
|
||||
"view": view,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
if detail is not None:
|
||||
payload["error"]["detail"] = detail
|
||||
return payload
|
||||
|
||||
|
||||
def _normalize_tool_result(result: Any) -> Dict[str, Any]:
|
||||
if isinstance(result, dict) and "ok" in result:
|
||||
payload = dict(result)
|
||||
if payload.get("ok") is False:
|
||||
error = payload.get("error")
|
||||
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
|
||||
payload.setdefault("summary", message or "error")
|
||||
payload.setdefault("view", {"type": "error", "message": payload["summary"]})
|
||||
return payload
|
||||
data = payload.get("data")
|
||||
if payload.get("view") is None:
|
||||
payload["view"] = _build_view(data)
|
||||
if not payload.get("summary"):
|
||||
payload["summary"] = _build_summary(payload["view"])
|
||||
return payload
|
||||
|
||||
if isinstance(result, dict) and result.get("canceled"):
|
||||
reason = _stringify_value(result.get("reason") or "canceled")
|
||||
return _build_error_payload("canceled", reason, detail=result)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
error = result.get("error")
|
||||
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
|
||||
return _build_error_payload("error", message, detail=error)
|
||||
|
||||
view = _build_view(result)
|
||||
summary = _build_summary(view)
|
||||
return {"ok": True, "summary": summary, "view": view, "data": result}
|
||||
|
||||
|
||||
def tool_result_to_content(result: Any) -> str:
|
||||
payload = _normalize_tool_result(result)
|
||||
try:
|
||||
return json.dumps(payload, ensure_ascii=False, default=str)
|
||||
except TypeError:
|
||||
return json.dumps({"ok": False, "summary": "error", "view": {"type": "error", "message": "error"}}, ensure_ascii=False)
|
||||
96
domain/agent/tools/processors.py
Normal file
96
domain/agent/tools/processors.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from domain.processors import ProcessDirectoryRequest, ProcessRequest, ProcessorService
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {"processors": ProcessorService.list_processors()}
|
||||
|
||||
|
||||
async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = str(args.get("path") or "")
|
||||
processor_type = str(args.get("processor_type") or "")
|
||||
config = args.get("config")
|
||||
if not isinstance(config, dict):
|
||||
config = {}
|
||||
|
||||
save_to = args.get("save_to")
|
||||
save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None
|
||||
|
||||
max_depth = args.get("max_depth")
|
||||
max_depth_value: Optional[int] = None
|
||||
if max_depth is not None:
|
||||
try:
|
||||
max_depth_value = int(max_depth)
|
||||
except (TypeError, ValueError):
|
||||
max_depth_value = None
|
||||
|
||||
suffix = args.get("suffix")
|
||||
suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None
|
||||
|
||||
overwrite_value = args.get("overwrite")
|
||||
overwrite = bool(overwrite_value) if overwrite_value is not None else None
|
||||
|
||||
is_dir = await VirtualFSService.path_is_directory(path)
|
||||
if is_dir and (max_depth_value is not None or suffix_value is not None):
|
||||
req = ProcessDirectoryRequest(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config,
|
||||
overwrite=True if overwrite is None else overwrite,
|
||||
max_depth=max_depth_value,
|
||||
suffix=suffix_value,
|
||||
)
|
||||
result = await ProcessorService.process_directory(req)
|
||||
return {"mode": "directory", **result}
|
||||
|
||||
req = ProcessRequest(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config,
|
||||
save_to=save_to,
|
||||
overwrite=False if overwrite is None else overwrite,
|
||||
)
|
||||
result = await ProcessorService.process_file(req)
|
||||
return {"mode": "file", **result}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"processors_list": ToolSpec(
|
||||
name="processors_list",
|
||||
description="获取可用处理器列表(type/name/config_schema 等)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_processors_list,
|
||||
),
|
||||
"processors_run": ToolSpec(
|
||||
name="processors_run",
|
||||
description=(
|
||||
"运行处理器处理文件或目录。"
|
||||
" 对目录可选 max_depth/suffix;对文件可选 overwrite/save_to。"
|
||||
" 返回任务 id(去任务队列查看进度)。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件或目录路径(绝对路径,如 /foo/bar)"},
|
||||
"processor_type": {"type": "string", "description": "处理器类型(例如 image_watermark)"},
|
||||
"config": {"type": "object", "description": "处理器配置,按 processors_list 返回的 config_schema 填写"},
|
||||
"overwrite": {"type": "boolean", "description": "是否覆盖原文件/目录内文件"},
|
||||
"save_to": {"type": "string", "description": "保存到指定路径(仅文件模式,且 overwrite=false 时使用)"},
|
||||
"max_depth": {"type": "integer", "description": "目录遍历深度(仅目录模式)"},
|
||||
"suffix": {"type": "string", "description": "目录批处理时的输出后缀(仅 produces_file 且 overwrite=false)"},
|
||||
},
|
||||
"required": ["path", "processor_type"],
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_processors_run,
|
||||
),
|
||||
}
|
||||
92
domain/agent/tools/time.py
Normal file
92
domain/agent/tools/time.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import calendar
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
def _parse_offset(args: Dict[str, Any], key: str) -> int:
|
||||
value = args.get(key)
|
||||
if value is None:
|
||||
return 0
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _add_months(dt: datetime, months: int) -> datetime:
|
||||
if months == 0:
|
||||
return dt
|
||||
total = dt.year * 12 + (dt.month - 1) + months
|
||||
year = total // 12
|
||||
month = total % 12 + 1
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
async def _time(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
now = datetime.now()
|
||||
year_offset = _parse_offset(args, "year")
|
||||
month_offset = _parse_offset(args, "month")
|
||||
day_offset = _parse_offset(args, "day")
|
||||
hour_offset = _parse_offset(args, "hour")
|
||||
minute_offset = _parse_offset(args, "minute")
|
||||
second_offset = _parse_offset(args, "second")
|
||||
|
||||
dt = _add_months(now, year_offset * 12 + month_offset)
|
||||
dt = dt + timedelta(days=day_offset, hours=hour_offset, minutes=minute_offset, seconds=second_offset)
|
||||
|
||||
weekday_names = [
|
||||
"Monday",
|
||||
"Tuesday",
|
||||
"Wednesday",
|
||||
"Thursday",
|
||||
"Friday",
|
||||
"Saturday",
|
||||
"Sunday",
|
||||
]
|
||||
weekday = weekday_names[dt.weekday()]
|
||||
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"{dt_str} · {weekday}",
|
||||
"data": {
|
||||
"datetime": dt_str,
|
||||
"weekday": weekday,
|
||||
"offset": {
|
||||
"year": year_offset,
|
||||
"month": month_offset,
|
||||
"day": day_offset,
|
||||
"hour": hour_offset,
|
||||
"minute": minute_offset,
|
||||
"second": second_offset,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"time": ToolSpec(
|
||||
name="time",
|
||||
description=(
|
||||
"获取服务器当前时间(精确到秒,含英文星期)。"
|
||||
" 支持 year/month/day/hour/minute/second 偏移(可为负数)。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"year": {"type": "integer", "description": "年偏移(可为负数)"},
|
||||
"month": {"type": "integer", "description": "月偏移(可为负数)"},
|
||||
"day": {"type": "integer", "description": "日偏移(可为负数)"},
|
||||
"hour": {"type": "integer", "description": "时偏移(可为负数)"},
|
||||
"minute": {"type": "integer", "description": "分偏移(可为负数)"},
|
||||
"second": {"type": "integer", "description": "秒偏移(可为负数)"},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_time,
|
||||
),
|
||||
}
|
||||
287
domain/agent/tools/vfs.py
Normal file
287
domain/agent/tools/vfs.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from domain.virtual_fs.search import VirtualFSSearchService
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
def _normalize_vfs_path(value: Any) -> str:
|
||||
s = str(value or "").strip().replace("\\", "/")
|
||||
if not s:
|
||||
return ""
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _require_vfs_path(value: Any, field: str) -> str:
|
||||
path = _normalize_vfs_path(value)
|
||||
if not path:
|
||||
raise ValueError(f"missing_{field}")
|
||||
return path
|
||||
|
||||
|
||||
async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _normalize_vfs_path(args.get("path") or "/") or "/"
|
||||
page = int(args.get("page") or 1)
|
||||
page_size = int(args.get("page_size") or 50)
|
||||
sort_by = str(args.get("sort_by") or "name")
|
||||
sort_order = str(args.get("sort_order") or "asc")
|
||||
return await VirtualFSService.list_directory(path, page, page_size, sort_by, sort_order)
|
||||
|
||||
|
||||
async def _vfs_stat(args: Dict[str, Any]) -> Any:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.stat(path)
|
||||
|
||||
|
||||
async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
encoding = str(args.get("encoding") or "utf-8")
|
||||
max_chars = int(args.get("max_chars") or 8000)
|
||||
|
||||
data = await VirtualFSService.read_file(path)
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
try:
|
||||
text = bytes(data).decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
return {"error": "binary_or_invalid_text", "path": path}
|
||||
elif isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
text = str(data)
|
||||
|
||||
original_len = len(text)
|
||||
truncated = original_len > max_chars
|
||||
if truncated:
|
||||
text = text[:max_chars]
|
||||
return {
|
||||
"path": path,
|
||||
"encoding": encoding,
|
||||
"content": text,
|
||||
"truncated": truncated,
|
||||
"length": original_len,
|
||||
}
|
||||
|
||||
|
||||
async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
if path == "/":
|
||||
raise ValueError("invalid_path")
|
||||
encoding = str(args.get("encoding") or "utf-8")
|
||||
content = str(args.get("content") or "")
|
||||
data = content.encode(encoding)
|
||||
await VirtualFSService.write_file(path, data)
|
||||
return {"written": True, "path": path, "encoding": encoding, "bytes": len(data)}
|
||||
|
||||
|
||||
async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.mkdir(path)
|
||||
|
||||
|
||||
async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.delete(path)
|
||||
|
||||
|
||||
async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.move(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.copy(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.rename(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
q = str(args.get("q") or "").strip()
|
||||
if not q:
|
||||
raise ValueError("missing_q")
|
||||
mode = str(args.get("mode") or "vector")
|
||||
top_k = int(args.get("top_k") or 10)
|
||||
page = int(args.get("page") or 1)
|
||||
page_size = int(args.get("page_size") or 10)
|
||||
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"vfs_list_dir": ToolSpec(
|
||||
name="vfs_list_dir",
|
||||
description="浏览目录(列出 entries + pagination)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
|
||||
"page": {"type": "integer", "description": "页码(从 1 开始)"},
|
||||
"page_size": {"type": "integer", "description": "每页条数"},
|
||||
"sort_by": {"type": "string", "description": "排序字段:name/size/mtime"},
|
||||
"sort_order": {"type": "string", "description": "排序顺序:asc/desc"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_list_dir,
|
||||
),
|
||||
"vfs_stat": ToolSpec(
|
||||
name="vfs_stat",
|
||||
description="查看文件/目录信息(size/mtime/is_dir/has_thumbnail/vector_index 等)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar.txt)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_stat,
|
||||
),
|
||||
"vfs_read_text": ToolSpec(
|
||||
name="vfs_read_text",
|
||||
description="读取文本文件内容(解码失败视为二进制,返回 error)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
|
||||
"encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
|
||||
"max_chars": {"type": "integer", "description": "最多返回的字符数(默认 8000)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_read_text,
|
||||
),
|
||||
"vfs_write_text": ToolSpec(
|
||||
name="vfs_write_text",
|
||||
description="写入文本文件内容(会覆盖目标文件)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
|
||||
"content": {"type": "string", "description": "要写入的文本内容"},
|
||||
"encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_write_text,
|
||||
),
|
||||
"vfs_mkdir": ToolSpec(
|
||||
name="vfs_mkdir",
|
||||
description="创建目录。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_mkdir,
|
||||
),
|
||||
"vfs_delete": ToolSpec(
|
||||
name="vfs_delete",
|
||||
description="删除文件或目录(由底层适配器决定是否递归)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar 或 /foo/bar.txt)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_delete,
|
||||
),
|
||||
"vfs_move": ToolSpec(
|
||||
name="vfs_move",
|
||||
description="移动路径(可能进入任务队列)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_move,
|
||||
),
|
||||
"vfs_copy": ToolSpec(
|
||||
name="vfs_copy",
|
||||
description="复制路径(可能进入任务队列)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_copy,
|
||||
),
|
||||
"vfs_rename": ToolSpec(
|
||||
name="vfs_rename",
|
||||
description="重命名路径(本质是同目录 move)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_rename,
|
||||
),
|
||||
"vfs_search": ToolSpec(
|
||||
name="vfs_search",
|
||||
description="搜索文件(mode=vector 或 filename)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"q": {"type": "string", "description": "搜索关键词"},
|
||||
"mode": {"type": "string", "description": "搜索模式:vector/filename(默认 vector)"},
|
||||
"top_k": {"type": "integer", "description": "返回数量(vector 模式使用,默认 10)"},
|
||||
"page": {"type": "integer", "description": "页码(filename 模式使用,默认 1)"},
|
||||
"page_size": {"type": "integer", "description": "分页大小(filename 模式使用,默认 10)"},
|
||||
},
|
||||
"required": ["q"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_search,
|
||||
),
|
||||
}
|
||||
182
domain/agent/tools/web_fetch.py
Normal file
182
domain/agent/tools/web_fetch.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
class _HtmlTextExtractor(HTMLParser):
|
||||
def __init__(self, base_url: str):
|
||||
super().__init__()
|
||||
self.base_url = base_url
|
||||
self.links: List[str] = []
|
||||
self._link_set: set[str] = set()
|
||||
self._title_parts: List[str] = []
|
||||
self._text_parts: List[str] = []
|
||||
self._in_title = False
|
||||
self._skip_text = False
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: List[tuple[str, str | None]]):
|
||||
tag = tag.lower()
|
||||
if tag == "title":
|
||||
self._in_title = True
|
||||
if tag in ("script", "style", "noscript"):
|
||||
self._skip_text = True
|
||||
if tag != "a":
|
||||
return
|
||||
href = ""
|
||||
for key, value in attrs:
|
||||
if key.lower() == "href":
|
||||
href = str(value or "").strip()
|
||||
break
|
||||
if not href or href.startswith("#"):
|
||||
return
|
||||
lower = href.lower()
|
||||
if lower.startswith(("javascript:", "mailto:", "tel:", "data:")):
|
||||
return
|
||||
resolved = urljoin(self.base_url, href)
|
||||
if resolved in self._link_set:
|
||||
return
|
||||
self._link_set.add(resolved)
|
||||
self.links.append(resolved)
|
||||
|
||||
def handle_endtag(self, tag: str):
|
||||
tag = tag.lower()
|
||||
if tag == "title":
|
||||
self._in_title = False
|
||||
if tag in ("script", "style", "noscript"):
|
||||
self._skip_text = False
|
||||
|
||||
def handle_data(self, data: str):
|
||||
if not data:
|
||||
return
|
||||
if self._in_title:
|
||||
self._title_parts.append(data)
|
||||
if self._skip_text:
|
||||
return
|
||||
if data.strip():
|
||||
self._text_parts.append(data)
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return " ".join(part.strip() for part in self._title_parts if part and part.strip()).strip()
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
if not self._text_parts:
|
||||
return ""
|
||||
text = " ".join(part.strip() for part in self._text_parts if part and part.strip())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
async def _web_fetch(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
url = str(args.get("url") or "").strip()
|
||||
if not url:
|
||||
raise ValueError("missing_url")
|
||||
|
||||
method = str(args.get("method") or "GET").upper()
|
||||
allowed_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
|
||||
if method not in allowed_methods:
|
||||
raise ValueError("invalid_method")
|
||||
|
||||
headers_raw = args.get("headers")
|
||||
headers = {str(k): str(v) for k, v in headers_raw.items() if v is not None} if isinstance(headers_raw, dict) else None
|
||||
params_raw = args.get("params")
|
||||
params = {str(k): str(v) for k, v in params_raw.items() if v is not None} if isinstance(params_raw, dict) else None
|
||||
json_body = args.get("json") if "json" in args else None
|
||||
body = args.get("body")
|
||||
|
||||
request_kwargs: Dict[str, Any] = {}
|
||||
if headers:
|
||||
request_kwargs["headers"] = headers
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if json_body is not None:
|
||||
request_kwargs["json"] = json_body
|
||||
elif body is not None:
|
||||
request_kwargs["content"] = str(body)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, **request_kwargs)
|
||||
|
||||
content_type = resp.headers.get("content-type") or ""
|
||||
text = resp.text or ""
|
||||
is_html = "html" in content_type.lower()
|
||||
if not is_html:
|
||||
probe = text.lstrip()[:200].lower()
|
||||
if "<html" in probe or "<!doctype html" in probe:
|
||||
is_html = True
|
||||
|
||||
html = ""
|
||||
title = ""
|
||||
links: List[str] = []
|
||||
extracted_text = text
|
||||
|
||||
if is_html and text:
|
||||
html = text
|
||||
parser = _HtmlTextExtractor(str(resp.url))
|
||||
parser.feed(text)
|
||||
title = parser.title
|
||||
links = parser.links
|
||||
extracted_text = parser.text
|
||||
|
||||
data = {
|
||||
"url": url,
|
||||
"method": method,
|
||||
"final_url": str(resp.url),
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"title": title,
|
||||
"html": html,
|
||||
"text": extracted_text,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
summary_parts = [method, str(resp.status_code)]
|
||||
if title:
|
||||
summary_parts.append(title)
|
||||
summary_parts.append(f"{len(links)} links")
|
||||
summary = " · ".join(summary_parts)
|
||||
|
||||
view = {
|
||||
"type": "text",
|
||||
"text": extracted_text,
|
||||
"meta": {
|
||||
"url": url,
|
||||
"final_url": str(resp.url),
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"title": title,
|
||||
"method": method,
|
||||
"links": len(links),
|
||||
},
|
||||
}
|
||||
return {"ok": True, "summary": summary, "view": view, "data": data}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"web_fetch": ToolSpec(
|
||||
name="web_fetch",
|
||||
description=(
|
||||
"抓取网页内容,返回状态、标题、正文、HTML、链接等信息。"
|
||||
" 支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "目标 URL"},
|
||||
"method": {"type": "string", "description": "请求方法(默认 GET)"},
|
||||
"headers": {"type": "object", "description": "请求头", "additionalProperties": {"type": "string"}},
|
||||
"params": {"type": "object", "description": "查询参数", "additionalProperties": {"type": "string"}},
|
||||
"json": {"type": "object", "description": "JSON 请求体"},
|
||||
"body": {"type": "string", "description": "原始请求体"},
|
||||
},
|
||||
"required": ["url"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_web_fetch,
|
||||
),
|
||||
}
|
||||
23
domain/agent/types.py
Normal file
23
domain/agent/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentChatContext(BaseModel):
|
||||
current_path: Optional[str] = None
|
||||
|
||||
|
||||
class AgentChatRequest(BaseModel):
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
auto_execute: bool = False
|
||||
approved_tool_call_ids: List[str] = Field(default_factory=list)
|
||||
rejected_tool_call_ids: List[str] = Field(default_factory=list)
|
||||
context: Optional[AgentChatContext] = None
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict)
|
||||
requires_confirmation: bool = True
|
||||
|
||||
67
domain/ai/__init__.py
Normal file
67
domain/ai/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from .inference import (
|
||||
MissingModelError,
|
||||
chat_completion,
|
||||
chat_completion_stream,
|
||||
describe_image_base64,
|
||||
get_text_embedding,
|
||||
provider_service,
|
||||
rerank_texts,
|
||||
)
|
||||
from .service import (
|
||||
AIProviderService,
|
||||
FILE_COLLECTION_NAME,
|
||||
VECTOR_COLLECTION_NAME,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
VectorDBConfigManager,
|
||||
VectorDBService,
|
||||
)
|
||||
from .types import (
|
||||
ABILITIES,
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
normalize_capabilities,
|
||||
)
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
MilvusLiteProvider,
|
||||
MilvusServerProvider,
|
||||
QdrantProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MissingModelError",
|
||||
"chat_completion",
|
||||
"chat_completion_stream",
|
||||
"describe_image_base64",
|
||||
"get_text_embedding",
|
||||
"provider_service",
|
||||
"rerank_texts",
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"VECTOR_COLLECTION_NAME",
|
||||
"FILE_COLLECTION_NAME",
|
||||
"BaseVectorProvider",
|
||||
"MilvusLiteProvider",
|
||||
"MilvusServerProvider",
|
||||
"QdrantProvider",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"ABILITIES",
|
||||
"normalize_capabilities",
|
||||
"AIDefaultsUpdate",
|
||||
"AIModelCreate",
|
||||
"AIModelUpdate",
|
||||
"AIProviderCreate",
|
||||
"AIProviderUpdate",
|
||||
"VectorDBConfigPayload",
|
||||
]
|
||||
304
domain/ai/api.py
Normal file
304
domain/ai/api.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from typing import Annotated, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import AIProviderService, VectorDBConfigManager, VectorDBService
|
||||
from .types import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
)
|
||||
from .vector_providers import get_provider_class, get_provider_entry, list_providers
|
||||
|
||||
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
|
||||
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取 AI 提供商列表")
|
||||
@router_ai.get("/providers")
|
||||
async def list_providers_endpoint(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
providers = await AIProviderService.list_providers()
|
||||
return success({"providers": providers})
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建 AI 提供商",
|
||||
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.post("/providers")
|
||||
async def create_provider(
|
||||
request: Request,
|
||||
payload: AIProviderCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
provider = await AIProviderService.create_provider(payload.dict())
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取 AI 提供商详情")
|
||||
@router_ai.get("/providers/{provider_id}")
|
||||
async def get_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
provider = await AIProviderService.get_provider(provider_id, with_models=True)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新 AI 提供商",
|
||||
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.put("/providers/{provider_id}")
|
||||
async def update_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIProviderUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
provider = await AIProviderService.update_provider(provider_id, data)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(action=AuditAction.DELETE, description="删除 AI 提供商")
|
||||
@router_ai.delete("/providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await AIProviderService.delete_provider(provider_id)
|
||||
return success({"id": provider_id})
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="同步模型列表")
|
||||
@router_ai.post("/providers/{provider_id}/sync-models")
|
||||
async def sync_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
result = await AIProviderService.sync_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success(result)
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取远程模型列表")
|
||||
@router_ai.get("/providers/{provider_id}/remote-models")
|
||||
async def fetch_remote_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
models = await AIProviderService.fetch_remote_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取模型列表")
|
||||
@router_ai.get("/providers/{provider_id}/models")
|
||||
async def list_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
models = await AIProviderService.list_models(provider_id)
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建模型",
|
||||
body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
|
||||
)
|
||||
@router_ai.post("/providers/{provider_id}/models")
|
||||
async def create_model(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
model = await AIProviderService.create_model(provider_id, payload.dict())
|
||||
return success(model)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新模型",
|
||||
body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
|
||||
)
|
||||
@router_ai.put("/models/{model_id}")
|
||||
async def update_model(
|
||||
request: Request,
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
model = await AIProviderService.update_model(model_id, data)
|
||||
return success(model)
|
||||
|
||||
|
||||
@audit(action=AuditAction.DELETE, description="删除模型")
|
||||
@router_ai.delete("/models/{model_id}")
|
||||
async def delete_model(
|
||||
request: Request,
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await AIProviderService.delete_model(model_id)
|
||||
return success({"id": model_id})
|
||||
|
||||
|
||||
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
|
||||
if not entry:
|
||||
return None
|
||||
value = entry.get("embedding_dimensions")
|
||||
return int(value) if value is not None else None
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取默认模型")
|
||||
@router_ai.get("/defaults")
|
||||
async def get_defaults(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
defaults = await AIProviderService.get_default_models()
|
||||
return success(defaults)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新默认模型",
|
||||
body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
|
||||
)
|
||||
@router_ai.put("/defaults")
|
||||
async def update_defaults(
|
||||
request: Request,
|
||||
payload: AIDefaultsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
previous = await AIProviderService.get_default_models()
|
||||
try:
|
||||
updated = await AIProviderService.set_default_models(payload.as_mapping())
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
prev_dim = _get_embedding_dimension(previous.get("embedding"))
|
||||
next_dim = _get_embedding_dimension(updated.get("embedding"))
|
||||
|
||||
if prev_dim and next_dim and prev_dim != next_dim:
|
||||
try:
|
||||
await VectorDBService().clear_all_data()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
|
||||
|
||||
return success(updated)
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="清空向量数据库")
|
||||
@router_vector_db.post("/clear-all", summary="清空向量数据库")
|
||||
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
return success(msg="向量数据库已清空")
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库统计")
|
||||
@router_vector_db.get("/stats", summary="获取向量数据库统计")
|
||||
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
data = await service.get_all_stats()
|
||||
return success(data=data)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
|
||||
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
|
||||
async def list_vector_providers(request: Request):
|
||||
return success(list_providers())
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库配置")
|
||||
@router_vector_db.get("/config", summary="获取当前向量数据库配置")
|
||||
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
|
||||
service = VectorDBService()
|
||||
data = await service.current_provider()
|
||||
return success(data)
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="更新向量数据库配置", body_fields=["type"])
|
||||
@router_vector_db.post("/config", summary="更新向量数据库配置")
|
||||
async def update_vector_db_config(
|
||||
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
|
||||
):
|
||||
entry = get_provider_entry(payload.type)
|
||||
if not entry:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
||||
|
||||
provider_cls = get_provider_class(payload.type)
|
||||
if not provider_cls:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
||||
|
||||
test_provider = provider_cls(payload.config)
|
||||
try:
|
||||
await test_provider.initialize()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
client = getattr(test_provider, "client", None)
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
||||
service = VectorDBService()
|
||||
await service.reload()
|
||||
config_data = await service.current_provider()
|
||||
stats = await service.get_all_stats()
|
||||
return success({"config": config_data, "stats": stats})
|
||||
|
||||
|
||||
__all__ = ["router_ai", "router_vector_db"]
|
||||
1163
domain/ai/inference.py
Normal file
1163
domain/ai/inference.py
Normal file
File diff suppressed because it is too large
Load Diff
501
domain/ai/service.py
Normal file
501
domain/ai/service.py
Normal file
@@ -0,0 +1,501 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config import ConfigService
|
||||
from models.database import AIDefaultModel, AIModel, AIProvider
|
||||
|
||||
from .types import ABILITIES, normalize_capabilities
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
DEFAULT_VECTOR_DIMENSION = 4096
|
||||
VECTOR_COLLECTION_NAME = "vector_collection"
|
||||
FILE_COLLECTION_NAME = "file_collection"
|
||||
|
||||
OPENAI_EMBEDDING_DIMS = {
|
||||
"text-embedding-3-large": 3072,
|
||||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-ada-002": 1536,
|
||||
}
|
||||
|
||||
|
||||
class VectorDBConfigManager:
|
||||
TYPE_KEY = "VECTOR_DB_TYPE"
|
||||
CONFIG_KEY = "VECTOR_DB_CONFIG"
|
||||
DEFAULT_TYPE = "milvus_lite"
|
||||
|
||||
@classmethod
|
||||
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
|
||||
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
|
||||
provider_type = str(raw_type or cls.DEFAULT_TYPE)
|
||||
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
config_dict: Dict[str, Any] = {}
|
||||
if isinstance(raw_config, str) and raw_config:
|
||||
try:
|
||||
config_dict = json.loads(raw_config)
|
||||
except json.JSONDecodeError:
|
||||
config_dict = {}
|
||||
elif isinstance(raw_config, dict):
|
||||
config_dict = raw_config
|
||||
return provider_type, config_dict
|
||||
|
||||
@classmethod
|
||||
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
|
||||
await ConfigService.set(cls.TYPE_KEY, provider_type)
|
||||
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
|
||||
|
||||
@classmethod
|
||||
async def get_type(cls) -> str:
|
||||
provider_type, _ = await cls.load_config()
|
||||
return provider_type
|
||||
|
||||
@classmethod
|
||||
async def get_config(cls) -> Dict[str, Any]:
|
||||
_, config = await cls.load_config()
|
||||
return config
|
||||
|
||||
|
||||
def _normalize_embedding_dim(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
casted = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return casted if casted > 0 else None
|
||||
|
||||
|
||||
def _apply_embedding_dim_to_metadata(
|
||||
data: Dict[str, Any],
|
||||
embedding_dim: Optional[int],
|
||||
base_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
source = base_metadata if isinstance(base_metadata, dict) else {}
|
||||
metadata: Dict[str, Any] = dict(source)
|
||||
override = data.get("metadata")
|
||||
if isinstance(override, dict) and override:
|
||||
metadata.update(override)
|
||||
if embedding_dim is None:
|
||||
metadata.pop("embedding_dimensions", None)
|
||||
else:
|
||||
metadata["embedding_dimensions"] = embedding_dim
|
||||
data["metadata"] = metadata or None
|
||||
return data
|
||||
|
||||
|
||||
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
|
||||
lower = model_id.lower()
|
||||
caps = set()
|
||||
|
||||
if any(keyword in lower for keyword in ["gpt", "chat", "turbo", "o1", "sonnet", "haiku", "thinking"]):
|
||||
caps.update({"chat", "tools"})
|
||||
|
||||
if any(keyword in lower for keyword in ["vision", "gpt-4o", "gpt-4.1", "o1", "vision-preview", "omni"]):
|
||||
caps.add("vision")
|
||||
|
||||
if any(keyword in lower for keyword in ["embed", "embedding"]):
|
||||
caps.add("embedding")
|
||||
|
||||
if "rerank" in lower or "re-rank" in lower:
|
||||
caps.add("rerank")
|
||||
|
||||
if any(keyword in lower for keyword in ["tts", "speech", "audio"]):
|
||||
caps.add("voice")
|
||||
|
||||
embedding_dim = OPENAI_EMBEDDING_DIMS.get(model_id)
|
||||
return normalize_capabilities(caps), embedding_dim
|
||||
|
||||
|
||||
def infer_gemini_capabilities(methods: Iterable[str]) -> List[str]:
|
||||
caps = set()
|
||||
for method in methods:
|
||||
m = method.lower()
|
||||
if m in {"generatecontent", "counttokens"}:
|
||||
caps.update({"chat", "tools", "vision"})
|
||||
if m == "embedcontent":
|
||||
caps.add("embedding")
|
||||
if m in {"generatespeech", "audiogeneration"}:
|
||||
caps.add("voice")
|
||||
if m == "rerank":
|
||||
caps.add("rerank")
|
||||
return normalize_capabilities(caps)
|
||||
|
||||
|
||||
def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"identifier": provider.identifier,
|
||||
"provider_type": provider.provider_type,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"has_api_key": bool(provider.api_key),
|
||||
"logo_url": provider.logo_url,
|
||||
"extra_config": provider.extra_config or {},
|
||||
"created_at": provider.created_at,
|
||||
"updated_at": provider.updated_at,
|
||||
}
|
||||
|
||||
|
||||
def model_to_dict(model: AIModel, provider: Optional[AIProvider] = None) -> Dict[str, Any]:
|
||||
provider_obj = provider or getattr(model, "provider", None)
|
||||
provider_data = serialize_provider(provider_obj) if provider_obj else None
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"name": model.name,
|
||||
"display_name": model.display_name,
|
||||
"description": model.description,
|
||||
"capabilities": normalize_capabilities(model.capabilities),
|
||||
"context_window": model.context_window,
|
||||
"embedding_dimensions": model.embedding_dimensions,
|
||||
"metadata": model.metadata or {},
|
||||
"created_at": model.created_at,
|
||||
"updated_at": model.updated_at,
|
||||
"provider": provider_data,
|
||||
}
|
||||
|
||||
|
||||
def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = None) -> Dict[str, Any]:
|
||||
data = serialize_provider(provider)
|
||||
if models is not None:
|
||||
data["models"] = [model_to_dict(m, provider=provider) for m in models]
|
||||
return data
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
@classmethod
|
||||
async def list_providers(cls) -> List[Dict[str, Any]]:
|
||||
providers = await AIProvider.all().order_by("id").prefetch_related("models")
|
||||
return [provider_to_dict(p, models=list(p.models)) for p in providers]
|
||||
|
||||
@classmethod
|
||||
async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
|
||||
if with_models:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
models = await provider.models.all()
|
||||
return provider_to_dict(provider, models=models)
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data.setdefault("extra_config", {})
|
||||
provider = await AIProvider.create(**data)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
for field, value in payload.items():
|
||||
setattr(provider, field, value)
|
||||
await provider.save()
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def delete_provider(cls, provider_id: int) -> None:
|
||||
await AIProvider.filter(id=provider_id).delete()
|
||||
|
||||
@classmethod
|
||||
async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
|
||||
return [model_to_dict(m) for m in models]
|
||||
|
||||
@classmethod
|
||||
async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data["provider_id"] = provider_id
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
data = _apply_embedding_dim_to_metadata(data, embedding_dim)
|
||||
model = await AIModel.create(**data)
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
@classmethod
|
||||
async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model = await AIModel.get(id=model_id)
|
||||
data = payload.copy()
|
||||
if "capabilities" in data:
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = None
|
||||
if "embedding_dimensions" in data:
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
_apply_embedding_dim_to_metadata(data, embedding_dim, base_metadata=model.metadata)
|
||||
for field, value in data.items():
|
||||
setattr(model, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in payload and embedding_dim is None):
|
||||
model.embedding_dimensions = embedding_dim
|
||||
await model.save()
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
@classmethod
|
||||
async def delete_model(cls, model_id: int) -> None:
|
||||
await AIModel.filter(id=model_id).delete()
|
||||
|
||||
@classmethod
|
||||
async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return await cls._get_remote_models(provider)
|
||||
|
||||
@classmethod
|
||||
async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
if not provider.base_url:
|
||||
raise ValueError("Provider base_url is required for syncing models")
|
||||
|
||||
fmt = (provider.api_format or "").lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
|
||||
|
||||
if fmt == "openai":
|
||||
return await cls._fetch_openai_models(provider)
|
||||
return await cls._fetch_gemini_models(provider)
|
||||
|
||||
@classmethod
|
||||
async def sync_models(cls, provider_id: int) -> Dict[str, int]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
remote_models = await cls._get_remote_models(provider)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
for entry in remote_models:
|
||||
defaults = entry.copy()
|
||||
model_id = defaults.pop("name")
|
||||
defaults["capabilities"] = normalize_capabilities(defaults.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(defaults.pop("embedding_dimensions", None))
|
||||
defaults = _apply_embedding_dim_to_metadata(defaults, embedding_dim)
|
||||
obj, is_created = await AIModel.get_or_create(
|
||||
provider_id=provider.id,
|
||||
name=model_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
if is_created:
|
||||
created += 1
|
||||
continue
|
||||
for field, value in defaults.items():
|
||||
setattr(obj, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in entry and embedding_dim is None):
|
||||
obj.embedding_dimensions = embedding_dim
|
||||
await obj.save()
|
||||
updated += 1
|
||||
|
||||
return {"created": created, "updated": updated}
|
||||
|
||||
@classmethod
|
||||
async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
|
||||
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
|
||||
for item in defaults:
|
||||
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
|
||||
async with in_transaction() as connection:
|
||||
for ability, model_id in normalized.items():
|
||||
record = await AIDefaultModel.get_or_none(ability=ability)
|
||||
if model_id:
|
||||
try:
|
||||
model = await AIModel.get(id=model_id)
|
||||
except DoesNotExist:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
if record:
|
||||
record.model_id = model_id
|
||||
await record.save(using_db=connection)
|
||||
else:
|
||||
await AIDefaultModel.create(ability=ability, model_id=model_id)
|
||||
elif record:
|
||||
await record.delete(using_db=connection)
|
||||
return await cls.get_default_models()
|
||||
|
||||
@classmethod
|
||||
async def get_default_model(cls, ability: str) -> Optional[AIModel]:
|
||||
ability_key = ability.lower()
|
||||
if ability_key not in ABILITIES:
|
||||
return None
|
||||
record = await AIDefaultModel.get_or_none(ability=ability_key)
|
||||
if not record:
|
||||
return None
|
||||
model = await AIModel.get_or_none(id=record.model_id)
|
||||
if model:
|
||||
await model.fetch_related("provider")
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
url = f"{base_url}/models"
|
||||
headers = {}
|
||||
if provider.api_key:
|
||||
headers["Authorization"] = f"Bearer {provider.api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("data", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
capabilities, embedding_dim = infer_openai_capabilities(model_id)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("display_name"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("context_window"),
|
||||
"embedding_dimensions": embedding_dim,
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
@classmethod
|
||||
async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
suffix = "/models"
|
||||
if provider.api_key:
|
||||
suffix += f"?key={provider.api_key}"
|
||||
url = f"{base_url}{suffix}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("models", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("name")
|
||||
if not model_id:
|
||||
continue
|
||||
methods = item.get("supportedGenerationMethods") or []
|
||||
capabilities = infer_gemini_capabilities(methods)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("displayName"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("inputTokenLimit"),
|
||||
"embedding_dimensions": item.get("embeddingDimensions"),
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
class VectorDBService:
|
||||
_instance: Optional["VectorDBService"] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_provider"):
|
||||
self._provider: Optional[BaseVectorProvider] = None
|
||||
self._provider_type: Optional[str] = None
|
||||
self._provider_config: Dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_provider(self) -> BaseVectorProvider:
|
||||
if self._provider is None:
|
||||
await self.reload()
|
||||
assert self._provider is not None
|
||||
return self._provider
|
||||
|
||||
async def reload(self) -> BaseVectorProvider:
|
||||
async with self._lock:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
normalized_config = dict(provider_config or {})
|
||||
if (
|
||||
self._provider
|
||||
and self._provider_type == provider_type
|
||||
and self._provider_config == normalized_config
|
||||
):
|
||||
return self._provider
|
||||
|
||||
entry = get_provider_entry(provider_type)
|
||||
if not entry:
|
||||
raise RuntimeError(f"Unknown vector database provider: {provider_type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")
|
||||
|
||||
provider_cls = get_provider_class(provider_type)
|
||||
if not provider_cls:
|
||||
raise RuntimeError(f"Provider class not found for '{provider_type}'")
|
||||
|
||||
provider = provider_cls(provider_config)
|
||||
await provider.initialize()
|
||||
|
||||
self._provider = provider
|
||||
self._provider_type = provider_type
|
||||
self._provider_config = normalized_config
|
||||
return provider
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.ensure_collection(collection_name, vector, dim)
|
||||
|
||||
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.upsert_vector(collection_name, data)
|
||||
|
||||
async def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.delete_vector(collection_name, path)
|
||||
|
||||
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_vectors(collection_name, query_embedding, top_k)
|
||||
|
||||
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_by_path(collection_name, query_path, top_k)
|
||||
|
||||
async def get_all_stats(self) -> Dict[str, Any]:
|
||||
provider = await self._ensure_provider()
|
||||
return provider.get_all_stats()
|
||||
|
||||
async def clear_all_data(self) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.clear_all_data()
|
||||
|
||||
async def current_provider(self) -> Dict[str, Any]:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
entry = get_provider_entry(provider_type) or {}
|
||||
return {
|
||||
"type": provider_type,
|
||||
"config": provider_config,
|
||||
"label": entry.get("label"),
|
||||
"enabled": entry.get("enabled", True),
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"normalize_capabilities",
|
||||
"ABILITIES",
|
||||
]
|
||||
121
domain/ai/types.py
Normal file
121
domain/ai/types.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
|
||||
|
||||
|
||||
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
if not items:
|
||||
return []
|
||||
normalized: List[str] = []
|
||||
for cap in items:
|
||||
key = str(cap).strip().lower()
|
||||
if key in ABILITIES and key not in normalized:
|
||||
normalized.append(key)
|
||||
return normalized
|
||||
|
||||
|
||||
class AIProviderBase(BaseModel):
|
||||
name: str
|
||||
identifier: str = Field(..., pattern=r"^[a-z0-9_\-\.]+$")
|
||||
provider_type: Optional[str] = None
|
||||
api_format: str
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def normalize_format(cls, value: str) -> str:
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
|
||||
class AIProviderCreate(AIProviderBase):
|
||||
pass
|
||||
|
||||
|
||||
class AIProviderUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
provider_type: Optional[str] = None
|
||||
api_format: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def normalize_format(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
|
||||
class AIModelBase(BaseModel):
|
||||
name: str
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
capabilities: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
embedding_dimensions: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@field_validator("capabilities")
|
||||
@classmethod
|
||||
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if items is None:
|
||||
return None
|
||||
normalized = normalize_capabilities(items)
|
||||
invalid = set(items) - set(normalized)
|
||||
if invalid:
|
||||
raise ValueError(f"Unsupported capabilities: {', '.join(invalid)}")
|
||||
return normalized
|
||||
|
||||
|
||||
class AIModelCreate(AIModelBase):
|
||||
pass
|
||||
|
||||
|
||||
class AIModelUpdate(BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
capabilities: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
embedding_dimensions: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@field_validator("capabilities")
|
||||
@classmethod
|
||||
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if items is None:
|
||||
return None
|
||||
normalized = normalize_capabilities(items)
|
||||
invalid = set(items) - set(normalized)
|
||||
if invalid:
|
||||
raise ValueError(f"Unsupported capabilities: {', '.join(invalid)}")
|
||||
return normalized
|
||||
|
||||
|
||||
class AIDefaultsUpdate(BaseModel):
|
||||
chat: Optional[int] = None
|
||||
vision: Optional[int] = None
|
||||
embedding: Optional[int] = None
|
||||
rerank: Optional[int] = None
|
||||
voice: Optional[int] = None
|
||||
tools: Optional[int] = None
|
||||
|
||||
def as_mapping(self) -> Dict[str, Optional[int]]:
|
||||
return {ability: getattr(self, ability) for ability in ABILITIES}
|
||||
|
||||
|
||||
class VectorDBConfigPayload(BaseModel):
|
||||
type: str = Field(..., description="向量数据库提供者类型")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from .base import BaseVectorProvider
|
||||
@@ -54,3 +52,14 @@ def get_provider_class(provider_type: str) -> Type[BaseVectorProvider] | None:
|
||||
if not entry:
|
||||
return None
|
||||
return entry.get("class") # type: ignore[return-value]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseVectorProvider",
|
||||
"MilvusLiteProvider",
|
||||
"MilvusServerProvider",
|
||||
"QdrantProvider",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
]
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -39,6 +37,35 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
raise RuntimeError("Milvus Lite client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
raw: Dict[str, Any] | None = None
|
||||
if hasattr(hit, "entity"):
|
||||
raw_entity = getattr(hit, "entity")
|
||||
if hasattr(raw_entity, "to_dict"):
|
||||
raw = dict(raw_entity.to_dict())
|
||||
else:
|
||||
raw = dict(raw_entity)
|
||||
elif isinstance(hit, dict):
|
||||
raw = dict(hit)
|
||||
|
||||
if raw:
|
||||
hit_id = hit_id or raw.get("id")
|
||||
distance = distance if distance is not None else raw.get("distance")
|
||||
inner = raw.get("entity")
|
||||
if isinstance(inner, dict):
|
||||
payload = dict(inner)
|
||||
else:
|
||||
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
||||
|
||||
payload.setdefault("path", payload.get("source_path"))
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
return hit_id, distance, payload
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
@@ -50,15 +77,20 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
client = self._get_client()
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
common_fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
||||
]
|
||||
|
||||
if vector:
|
||||
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
||||
if vector_dim <= 0:
|
||||
vector_dim = 4096
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
*common_fields,
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Image vector collection")
|
||||
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
index_params.add_index(
|
||||
@@ -70,38 +102,86 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
)
|
||||
client.create_index(collection_name, index_params=index_params)
|
||||
else:
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Simple file index")
|
||||
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
self._get_client().upsert(collection_name, data)
|
||||
payload = dict(data)
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
payload.setdefault("vector_id", payload.get("path"))
|
||||
self._get_client().upsert(collection_name, data=[payload])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
self._get_client().delete(collection_name, ids=[path])
|
||||
client = self._get_client()
|
||||
escaped = path.replace('"', '\\"')
|
||||
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
search_params = {"metric_type": "COSINE"}
|
||||
return self._get_client().search(
|
||||
output_fields = [
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
]
|
||||
raw_results = self._get_client().search(
|
||||
collection_name,
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
search_params=search_params,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=output_fields,
|
||||
)
|
||||
formatted: List[List[Dict[str, Any]]] = []
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
hit_id, distance, entity = self._extract_hit_payload(hit)
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
"entity": entity,
|
||||
})
|
||||
formatted.append(bucket)
|
||||
return formatted
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
filter_expr = f"path like '%{query_path}%'" if query_path else "path like '%%'"
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like \"%{escaped}%\"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
collection_name,
|
||||
filter=filter_expr,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=[
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
],
|
||||
)
|
||||
return [[{"id": r["path"], "distance": 1.0, "entity": {"path": r["path"]}} for r in results]]
|
||||
formatted = []
|
||||
for row in results:
|
||||
entity = dict(row)
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
formatted.append({
|
||||
"id": entity.get("path"),
|
||||
"distance": 1.0,
|
||||
"entity": entity,
|
||||
})
|
||||
return [formatted]
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
@@ -150,7 +230,7 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
detail = client.describe_index(name, index_name) or {}
|
||||
detail = client.describe_index(name) or {}
|
||||
except Exception:
|
||||
detail = {}
|
||||
indexes.append(
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||
@@ -47,6 +45,35 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
raise RuntimeError("Milvus Server client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
raw: Dict[str, Any] | None = None
|
||||
if hasattr(hit, "entity"):
|
||||
raw_entity = getattr(hit, "entity")
|
||||
if hasattr(raw_entity, "to_dict"):
|
||||
raw = dict(raw_entity.to_dict())
|
||||
else:
|
||||
raw = dict(raw_entity)
|
||||
elif isinstance(hit, dict):
|
||||
raw = dict(hit)
|
||||
|
||||
if raw:
|
||||
hit_id = hit_id or raw.get("id")
|
||||
distance = distance if distance is not None else raw.get("distance")
|
||||
inner = raw.get("entity")
|
||||
if isinstance(inner, dict):
|
||||
payload = dict(inner)
|
||||
else:
|
||||
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
||||
|
||||
payload.setdefault("path", payload.get("source_path"))
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
return hit_id, distance, payload
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
@@ -58,15 +85,19 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
client = self._get_client()
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
common_fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
||||
]
|
||||
if vector:
|
||||
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
||||
if vector_dim <= 0:
|
||||
vector_dim = 4096
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
*common_fields,
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Image vector collection")
|
||||
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
index_params.add_index(
|
||||
@@ -78,38 +109,86 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
)
|
||||
client.create_index(collection_name, index_params=index_params)
|
||||
else:
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Simple file index")
|
||||
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
self._get_client().upsert(collection_name, data)
|
||||
payload = dict(data)
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
payload.setdefault("vector_id", payload.get("path"))
|
||||
self._get_client().upsert(collection_name, data=[payload])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
self._get_client().delete(collection_name, ids=[path])
|
||||
client = self._get_client()
|
||||
escaped = path.replace('"', '\\"')
|
||||
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
search_params = {"metric_type": "COSINE"}
|
||||
return self._get_client().search(
|
||||
output_fields = [
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
]
|
||||
raw_results = self._get_client().search(
|
||||
collection_name,
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
search_params=search_params,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=output_fields,
|
||||
)
|
||||
formatted: List[List[Dict[str, Any]]] = []
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
hit_id, distance, entity = self._extract_hit_payload(hit)
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
"entity": entity,
|
||||
})
|
||||
formatted.append(bucket)
|
||||
return formatted
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
filter_expr = f"path like '%{query_path}%'" if query_path else "path like '%%'"
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like \"%{escaped}%\"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
collection_name,
|
||||
filter=filter_expr,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=[
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
],
|
||||
)
|
||||
return [[{"id": r["path"], "distance": 1.0, "entity": {"path": r["path"]}} for r in results]]
|
||||
formatted = []
|
||||
for row in results:
|
||||
entity = dict(row)
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
formatted.append({
|
||||
"id": entity.get("path"),
|
||||
"distance": 1.0,
|
||||
"entity": entity,
|
||||
})
|
||||
return [formatted]
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
@@ -158,7 +237,7 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
detail = client.describe_index(name, index_name) or {}
|
||||
detail = client.describe_index(name) or {}
|
||||
except Exception:
|
||||
detail = {}
|
||||
indexes.append(
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from uuid import NAMESPACE_URL, uuid5
|
||||
|
||||
@@ -42,7 +40,6 @@ class QdrantProvider(BaseVectorProvider):
|
||||
api_key = (self.config.get("api_key") or None) or None
|
||||
try:
|
||||
client = QdrantClient(url=url, api_key=api_key)
|
||||
# 简单连通性校验
|
||||
client.get_collections()
|
||||
self.client = client
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
@@ -58,29 +55,58 @@ class QdrantProvider(BaseVectorProvider):
|
||||
size = dim if vector and isinstance(dim, int) and dim > 0 else 1
|
||||
return qmodels.VectorParams(size=size, distance=qmodels.Distance.COSINE)
|
||||
|
||||
def _ensure_payload_indexes(self, client: QdrantClient, collection_name: str) -> None:
|
||||
for field in ("path", "source_path"):
|
||||
try:
|
||||
client.create_payload_index(
|
||||
collection_name=collection_name,
|
||||
field_name=field,
|
||||
field_schema="keyword",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
message = str(exc).lower()
|
||||
if "already exists" in message or "index exists" in message:
|
||||
continue
|
||||
raise
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
if client.collection_exists(collection_name):
|
||||
return
|
||||
exists = client.collection_exists(collection_name)
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
raise RuntimeError(f"Failed to check Qdrant collection '{collection_name}': {exc}") from exc
|
||||
|
||||
if exists:
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
vectors_config = self._vector_params(vector, dim)
|
||||
try:
|
||||
client.create_collection(collection_name=collection_name, vectors_config=vectors_config)
|
||||
except Exception as exc: # pragma: no cover
|
||||
if "already exists" in str(exc).lower():
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
raise RuntimeError(f"Failed to create Qdrant collection '{collection_name}': {exc}") from exc
|
||||
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _point_id(path: str) -> str:
|
||||
return str(uuid5(NAMESPACE_URL, path))
|
||||
def _point_id(uid: str) -> str:
|
||||
return str(uuid5(NAMESPACE_URL, uid))
|
||||
|
||||
def _prepare_point(self, data: Dict[str, Any]) -> qmodels.PointStruct:
|
||||
path = data.get("path")
|
||||
if not path:
|
||||
uid = data.get("path")
|
||||
if not uid:
|
||||
raise ValueError("Qdrant upsert requires 'path' in data")
|
||||
|
||||
embedding = data.get("embedding")
|
||||
@@ -89,8 +115,11 @@ class QdrantProvider(BaseVectorProvider):
|
||||
else:
|
||||
vector = [float(x) for x in embedding]
|
||||
|
||||
payload = {"path": path}
|
||||
return qmodels.PointStruct(id=self._point_id(path), vector=vector, payload=payload)
|
||||
payload = {k: v for k, v in data.items() if k != "embedding"}
|
||||
payload.setdefault("vector_id", uid)
|
||||
source_path = payload.get("source_path") or payload.get("path")
|
||||
payload["path"] = source_path
|
||||
return qmodels.PointStruct(id=self._point_id(str(uid)), vector=vector, payload=payload)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
client = self._get_client()
|
||||
@@ -99,7 +128,12 @@ class QdrantProvider(BaseVectorProvider):
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
client = self._get_client()
|
||||
selector = qmodels.PointIdsList(points=[self._point_id(path)])
|
||||
condition = qmodels.FieldCondition(
|
||||
key="path",
|
||||
match=qmodels.MatchValue(value=path),
|
||||
)
|
||||
flt = qmodels.Filter(must=[condition])
|
||||
selector = qmodels.FilterSelector(filter=flt)
|
||||
client.delete(collection_name=collection_name, points_selector=selector, wait=True)
|
||||
|
||||
def _format_search_results(self, points: Sequence[qmodels.ScoredPoint]):
|
||||
@@ -107,7 +141,7 @@ class QdrantProvider(BaseVectorProvider):
|
||||
{
|
||||
"id": point.id,
|
||||
"distance": point.score,
|
||||
"entity": {"path": (point.payload or {}).get("path")},
|
||||
"entity": point.payload or {},
|
||||
}
|
||||
for point in points
|
||||
]
|
||||
@@ -141,11 +175,11 @@ class QdrantProvider(BaseVectorProvider):
|
||||
break
|
||||
|
||||
for record in records:
|
||||
path = (record.payload or {}).get("path")
|
||||
if query_path and path:
|
||||
if query_path not in path:
|
||||
continue
|
||||
results.append({"id": record.id, "distance": 1.0, "entity": {"path": path}})
|
||||
payload = record.payload or {}
|
||||
path = payload.get("path")
|
||||
if query_path and path and query_path not in path:
|
||||
continue
|
||||
results.append({"id": record.id, "distance": 1.0, "entity": payload})
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
4
domain/audit/__init__.py
Normal file
4
domain/audit/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .decorator import audit
|
||||
from .types import AuditAction
|
||||
|
||||
__all__ = ["audit", "AuditAction"]
|
||||
69
domain/audit/api.py
Normal file
69
domain/audit/api.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from api import response
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import AuditService
|
||||
from .types import AuditAction
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
router = APIRouter(prefix="/api/audit", tags=["Audit"])
|
||||
|
||||
|
||||
def _parse_iso(value: Optional[str], field: str):
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
dt = datetime.fromisoformat(normalized)
|
||||
if dt.tzinfo:
|
||||
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
return dt
|
||||
except ValueError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail=f"invalid {field}") from exc
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
@require_system_permission(SystemPermission.AUDIT_VIEW)
|
||||
async def list_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
page_num: int = Query(1, ge=1, alias="page", description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=200, description="每页条数"),
|
||||
action: AuditAction | None = Query(None, description="操作类型"),
|
||||
success: bool | None = Query(None, description="是否成功"),
|
||||
username: str | None = Query(None, description="用户名模糊匹配"),
|
||||
path: str | None = Query(None, description="路径模糊匹配"),
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
|
||||
):
|
||||
start_dt = _parse_iso(start_time, "start_time")
|
||||
end_dt = _parse_iso(end_time, "end_time")
|
||||
items, total = await AuditService.list_logs(
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
action=str(action) if action else None,
|
||||
success=success,
|
||||
username=username,
|
||||
path=path,
|
||||
start_time=start_dt,
|
||||
end_time=end_dt,
|
||||
)
|
||||
return response.success(response.page(items, total, page_num, page_size))
|
||||
|
||||
|
||||
@router.delete("/logs")
|
||||
@require_system_permission(SystemPermission.AUDIT_VIEW)
|
||||
async def clear_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
|
||||
):
|
||||
start_dt = _parse_iso(start_time, "start_time")
|
||||
end_dt = _parse_iso(end_time, "end_time")
|
||||
deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
|
||||
return response.success({"deleted_count": deleted_count})
|
||||
204
domain/audit/decorator.py
Normal file
204
domain/audit/decorator.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import inspect
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.auth import ALGORITHM
|
||||
from domain.config import ConfigService
|
||||
from models.database import UserAccount
|
||||
from .service import AuditService
|
||||
from .types import AuditAction
|
||||
|
||||
|
||||
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
|
||||
for value in bound_args.values():
|
||||
if isinstance(value, Request):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_user(request: Request | None, user_obj: Any | None) -> tuple[Optional[int], Optional[str]]:
|
||||
user_id: int | None = None
|
||||
username: str | None = None
|
||||
|
||||
if request:
|
||||
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1]
|
||||
try:
|
||||
payload = jwt.decode(token, await ConfigService.get_secret_key("SECRET_KEY"), algorithms=[ALGORITHM])
|
||||
username = payload.get("sub") or payload.get("username")
|
||||
if username:
|
||||
user = await UserAccount.get_or_none(username=username)
|
||||
user_id = user.id if user else None
|
||||
except (InvalidTokenError, Exception):
|
||||
pass
|
||||
|
||||
if user_id is None and username is None and user_obj is not None:
|
||||
user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
|
||||
username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
|
||||
if isinstance(user_obj, dict):
|
||||
user_id = user_obj.get("id", user_obj.get("user_id", user_id))
|
||||
username = user_obj.get("username", user_obj.get("name", username))
|
||||
|
||||
return user_id, username
|
||||
|
||||
|
||||
def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
|
||||
if not body_fields:
|
||||
return None
|
||||
body: Dict[str, Any] = {}
|
||||
redacts = set(redact_fields or [])
|
||||
for value in bound_args.values():
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
data = value.model_dump()
|
||||
except Exception:
|
||||
data = None
|
||||
elif hasattr(value, "dict"):
|
||||
try:
|
||||
data = value.dict()
|
||||
except Exception:
|
||||
data = None
|
||||
elif isinstance(value, dict):
|
||||
data = value
|
||||
elif hasattr(value, "__dict__"):
|
||||
data = dict(value.__dict__)
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
for field in body_fields:
|
||||
if field in data and field not in body:
|
||||
body[field] = data[field]
|
||||
if not body:
|
||||
return None
|
||||
for field in redacts:
|
||||
if field in body:
|
||||
body[field] = "<redacted>"
|
||||
return body
|
||||
|
||||
|
||||
def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
|
||||
if not request:
|
||||
return None
|
||||
params: Dict[str, Any] = {}
|
||||
query = dict(request.query_params)
|
||||
if query:
|
||||
params["query"] = query
|
||||
path_params = dict(request.path_params or {})
|
||||
if path_params:
|
||||
params["path"] = path_params
|
||||
return params or None
|
||||
|
||||
|
||||
def _get_client_ip(request: Request | None) -> str | None:
|
||||
if not request:
|
||||
return None
|
||||
cf_connecting_ip = request.headers.get("cf-connecting-ip") or request.headers.get("CF-Connecting-IP")
|
||||
if cf_connecting_ip:
|
||||
ip = cf_connecting_ip.strip()
|
||||
if ip:
|
||||
return ip
|
||||
x_real_ip = request.headers.get("x-real-ip") or request.headers.get("X-Real-IP")
|
||||
if x_real_ip:
|
||||
ip = x_real_ip.strip()
|
||||
if ip:
|
||||
return ip
|
||||
x_forwarded_for = request.headers.get("x-forwarded-for") or request.headers.get("X-Forwarded-For")
|
||||
if x_forwarded_for:
|
||||
for part in x_forwarded_for.split(","):
|
||||
ip = part.strip()
|
||||
if ip and ip.lower() != "unknown":
|
||||
return ip
|
||||
return request.client.host if request.client else None
|
||||
|
||||
|
||||
def _status_code_from_response(response: Any) -> int:
|
||||
if hasattr(response, "status_code"):
|
||||
try:
|
||||
return int(getattr(response, "status_code"))
|
||||
except Exception:
|
||||
pass
|
||||
return 200
|
||||
|
||||
|
||||
def audit(
|
||||
*,
|
||||
action: AuditAction,
|
||||
description: str | None = None,
|
||||
body_fields: list[str] | None = None,
|
||||
redact_fields: list[str] | None = None,
|
||||
user_kw: str = "current_user",
|
||||
):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
request = _extract_request(bound.arguments)
|
||||
start = time.perf_counter()
|
||||
user_info = bound.arguments.get(user_kw)
|
||||
user_id, username = await _resolve_user(request, user_info)
|
||||
request_params = _build_request_params(request)
|
||||
request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
status_code = _status_code_from_response(result)
|
||||
success = True
|
||||
error = None
|
||||
except Exception as exc: # noqa: BLE001
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
success = False
|
||||
error = str(exc)
|
||||
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
||||
try:
|
||||
await AuditService.log(
|
||||
action=action,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=_get_client_ip(request),
|
||||
method=request.method if request else "",
|
||||
path=request.url.path if request else func.__name__,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
||||
try:
|
||||
await AuditService.log(
|
||||
action=action,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=_get_client_ip(request),
|
||||
method=request.method if request else "",
|
||||
path=request.url.path if request else func.__name__,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
124
domain/audit/service.py
Normal file
124
domain/audit/service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from models.database import AuditLog
|
||||
|
||||
from .types import AuditAction
|
||||
|
||||
|
||||
class AuditService:
|
||||
@classmethod
|
||||
async def log(
|
||||
cls,
|
||||
*,
|
||||
action: AuditAction | str,
|
||||
description: Optional[str],
|
||||
user_id: Optional[int],
|
||||
username: Optional[str],
|
||||
client_ip: Optional[str],
|
||||
method: str,
|
||||
path: str,
|
||||
status_code: int,
|
||||
duration_ms: Optional[float],
|
||||
success: bool,
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
await AuditLog.create(
|
||||
action=str(action),
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=client_ip,
|
||||
method=method,
|
||||
path=path,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _serialize(cls, log: AuditLog) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": log.id,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"action": log.action,
|
||||
"description": log.description,
|
||||
"user_id": log.user_id,
|
||||
"username": log.username,
|
||||
"client_ip": log.client_ip,
|
||||
"method": log.method,
|
||||
"path": log.path,
|
||||
"status_code": log.status_code,
|
||||
"duration_ms": log.duration_ms,
|
||||
"success": log.success,
|
||||
"request_params": log.request_params,
|
||||
"request_body": log.request_body,
|
||||
"error": log.error,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _apply_filters(
|
||||
cls,
|
||||
*,
|
||||
action: str | None = None,
|
||||
success: bool | None = None,
|
||||
username: str | None = None,
|
||||
path: str | None = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
qs = AuditLog.all()
|
||||
if action:
|
||||
qs = qs.filter(action=action)
|
||||
if success is not None:
|
||||
qs = qs.filter(success=success)
|
||||
if username:
|
||||
qs = qs.filter(username__icontains=username)
|
||||
if path:
|
||||
qs = qs.filter(path__icontains=path)
|
||||
if start_time:
|
||||
qs = qs.filter(created_at__gte=start_time)
|
||||
if end_time:
|
||||
qs = qs.filter(created_at__lte=end_time)
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
async def list_logs(
|
||||
cls,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
action: str | None = None,
|
||||
success: bool | None = None,
|
||||
username: str | None = None,
|
||||
path: str | None = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
qs = cls._apply_filters(
|
||||
action=action,
|
||||
success=success,
|
||||
username=username,
|
||||
path=path,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
total = await qs.count()
|
||||
offset = (page - 1) * page_size
|
||||
items = await qs.order_by("-created_at").offset(offset).limit(page_size)
|
||||
return [cls._serialize(log) for log in items], total
|
||||
|
||||
@classmethod
|
||||
async def clear_logs(
|
||||
cls,
|
||||
*,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
) -> int:
|
||||
qs = cls._apply_filters(start_time=start_time, end_time=end_time)
|
||||
deleted_count = await qs.delete()
|
||||
return deleted_count
|
||||
16
domain/audit/types.py
Normal file
16
domain/audit/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AuditAction(StrEnum):
|
||||
LOGIN = "login"
|
||||
LOGOUT = "logout"
|
||||
REGISTER = "register"
|
||||
READ = "read"
|
||||
CREATE = "create"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
RESET_PASSWORD = "reset_password"
|
||||
SHARE = "share"
|
||||
DOWNLOAD = "download"
|
||||
UPLOAD = "upload"
|
||||
OTHER = "other"
|
||||
49
domain/auth/__init__.py
Normal file
49
domain/auth/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from .service import (
|
||||
ALGORITHM,
|
||||
AuthService,
|
||||
authenticate_user_db,
|
||||
create_access_token,
|
||||
get_current_active_user,
|
||||
get_current_user,
|
||||
get_password_hash,
|
||||
has_users,
|
||||
register_user,
|
||||
request_password_reset,
|
||||
reset_password_with_token,
|
||||
verify_password,
|
||||
verify_password_reset_token,
|
||||
)
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
TokenData,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ALGORITHM",
|
||||
"AuthService",
|
||||
"authenticate_user_db",
|
||||
"create_access_token",
|
||||
"get_current_active_user",
|
||||
"get_current_user",
|
||||
"get_password_hash",
|
||||
"has_users",
|
||||
"register_user",
|
||||
"request_password_reset",
|
||||
"reset_password_with_token",
|
||||
"verify_password",
|
||||
"verify_password_reset_token",
|
||||
"PasswordResetConfirm",
|
||||
"PasswordResetRequest",
|
||||
"RegisterRequest",
|
||||
"Token",
|
||||
"TokenData",
|
||||
"UpdateMeRequest",
|
||||
"User",
|
||||
"UserInDB",
|
||||
]
|
||||
90
domain/auth/api.py
Normal file
90
domain/auth/api.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from .service import AuthService, get_current_active_user
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", summary="注册用户(首个用户为管理员)")
|
||||
@audit(
|
||||
action=AuditAction.REGISTER,
|
||||
description="注册用户",
|
||||
body_fields=["username", "email", "full_name"],
|
||||
redact_fields=["password"],
|
||||
)
|
||||
async def register(request: Request, data: RegisterRequest):
|
||||
user = await AuthService.register_user(data)
|
||||
return success({"username": user.username}, msg="注册成功")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
@audit(action=AuditAction.LOGIN, description="用户登录", body_fields=["username"], redact_fields=["password"])
|
||||
async def login_for_access_token(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
return await AuthService.login(form_data)
|
||||
|
||||
|
||||
@router.get("/me", summary="获取当前登录用户信息")
|
||||
@audit(action=AuditAction.READ, description="获取当前用户信息")
|
||||
async def get_me(
|
||||
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
profile = AuthService.get_profile(current_user)
|
||||
return success(profile)
|
||||
|
||||
|
||||
@router.put("/me", summary="更新当前登录用户信息")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新当前用户信息",
|
||||
body_fields=["email", "full_name"],
|
||||
redact_fields=["old_password", "new_password"],
|
||||
)
|
||||
async def update_me(
|
||||
request: Request,
|
||||
payload: UpdateMeRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
profile = await AuthService.update_me(payload, current_user)
|
||||
return success(profile)
|
||||
|
||||
|
||||
@router.post("/password-reset/request", summary="请求密码重置邮件")
|
||||
@audit(action=AuditAction.RESET_PASSWORD, description="请求密码重置邮件", body_fields=["email"])
|
||||
async def password_reset_request_endpoint(request: Request, payload: PasswordResetRequest):
|
||||
await AuthService.request_password_reset(payload)
|
||||
return success(msg="如果邮箱存在,将发送重置邮件")
|
||||
|
||||
|
||||
@router.get("/password-reset/verify", summary="校验密码重置令牌")
|
||||
@audit(action=AuditAction.RESET_PASSWORD, description="校验密码重置令牌", redact_fields=["token"])
|
||||
async def password_reset_verify(request: Request, token: str):
|
||||
user = await AuthService.verify_password_reset_token(token)
|
||||
return success({"username": user.username, "email": user.email})
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", summary="使用令牌重置密码")
|
||||
@audit(
|
||||
action=AuditAction.RESET_PASSWORD,
|
||||
description="重置密码",
|
||||
body_fields=["token"],
|
||||
redact_fields=["token", "password"],
|
||||
)
|
||||
async def password_reset_confirm(request: Request, payload: PasswordResetConfirm):
|
||||
await AuthService.reset_password_with_token(payload)
|
||||
return success(msg="密码已重置")
|
||||
415
domain/auth/service.py
Normal file
415
domain/auth/service.py
Normal file
@@ -0,0 +1,415 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.config import ConfigService
|
||||
from models.database import Role, UserAccount, UserRole
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
TokenData,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
|
||||
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PasswordResetEntry:
|
||||
user_id: int
|
||||
email: str
|
||||
username: str
|
||||
expires_at: datetime
|
||||
used: bool = False
|
||||
|
||||
|
||||
class PasswordResetStore:
|
||||
_tokens: dict[str, PasswordResetEntry] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def _cleanup(cls):
|
||||
now = _now()
|
||||
for token, record in list(cls._tokens.items()):
|
||||
if record.used or record.expires_at < now:
|
||||
cls._tokens.pop(token, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, user: UserAccount) -> str:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user.id:
|
||||
cls._tokens.pop(key, None)
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
|
||||
cls._tokens[token] = PasswordResetEntry(
|
||||
user_id=user.id,
|
||||
email=user.email or "",
|
||||
username=user.username,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
async def get(cls, token: str) -> PasswordResetEntry | None:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
record = cls._tokens.get(token)
|
||||
if not record or record.used:
|
||||
return None
|
||||
return record
|
||||
|
||||
@classmethod
|
||||
async def mark_used(cls, token: str) -> None:
|
||||
async with cls._lock:
|
||||
record = cls._tokens.get(token)
|
||||
if record:
|
||||
record.used = True
|
||||
cls._cleanup()
|
||||
|
||||
@classmethod
|
||||
async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
|
||||
async with cls._lock:
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user_id and key != except_token:
|
||||
cls._tokens.pop(key, None)
|
||||
cls._cleanup()
|
||||
|
||||
|
||||
class AuthService:
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
algorithm = ALGORITHM
|
||||
access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
@staticmethod
|
||||
def _to_bytes(value: str) -> bytes:
|
||||
return value.encode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def get_secret_key(cls) -> str:
|
||||
return await ConfigService.get_secret_key("SECRET_KEY", None)
|
||||
|
||||
@classmethod
|
||||
def _normalize_email(cls, email: str | None) -> str:
|
||||
return (email or "").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
|
||||
try:
|
||||
return bcrypt.checkpw(cls._to_bytes(plain_password), hashed_password.encode("utf-8"))
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_password_hash(cls, password: str) -> str:
|
||||
encoded = cls._to_bytes(password)
|
||||
if len(encoded) > 72:
|
||||
raise HTTPException(status_code=400, detail="密码过长")
|
||||
return bcrypt.hashpw(encoded, bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
|
||||
user = await UserAccount.get_or_none(username=username_or_email)
|
||||
if not user:
|
||||
user = await UserAccount.get_or_none(email=username_or_email)
|
||||
if user:
|
||||
return UserInDB(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
disabled=user.disabled,
|
||||
is_admin=user.is_admin,
|
||||
hashed_password=user.hashed_password,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def authenticate_user_db(cls, username_or_email: str, password: str) -> UserInDB | None:
|
||||
user = await cls.get_user_db(username_or_email)
|
||||
if not user:
|
||||
return None
|
||||
if not cls.verify_password(password, user.hashed_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def has_users(cls) -> bool:
|
||||
user_count = await UserAccount.all().count()
|
||||
return user_count > 0
|
||||
|
||||
@classmethod
|
||||
async def register_user(cls, payload: RegisterRequest):
|
||||
has_users = await cls.has_users()
|
||||
normalized_email = cls._normalize_email(payload.email)
|
||||
if not normalized_email:
|
||||
raise HTTPException(status_code=400, detail="邮箱不能为空")
|
||||
|
||||
if has_users:
|
||||
allow_register = str(await ConfigService.get("AUTH_ALLOW_REGISTER", "false") or "").strip().lower()
|
||||
if allow_register not in ("1", "true", "yes", "on"):
|
||||
raise HTTPException(status_code=403, detail="系统未开放注册")
|
||||
|
||||
default_role_id_raw = str(await ConfigService.get("AUTH_DEFAULT_REGISTER_ROLE_ID", "") or "").strip()
|
||||
if not default_role_id_raw:
|
||||
raise HTTPException(status_code=400, detail="未配置默认注册角色")
|
||||
try:
|
||||
default_role_id = int(default_role_id_raw)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="默认注册角色配置错误") from exc
|
||||
|
||||
role = await Role.get_or_none(id=default_role_id)
|
||||
if not role:
|
||||
raise HTTPException(status_code=400, detail="默认注册角色不存在")
|
||||
|
||||
exists = await UserAccount.get_or_none(username=payload.username)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
|
||||
existing_email = await UserAccount.get_or_none(email=normalized_email)
|
||||
if existing_email:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被使用")
|
||||
|
||||
hashed = cls.get_password_hash(payload.password)
|
||||
|
||||
# 第一个用户自动成为超级管理员(不受开放注册开关影响)
|
||||
if not has_users:
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=normalized_email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
is_admin=True,
|
||||
)
|
||||
return user
|
||||
|
||||
# 系统已初始化:按默认角色创建普通用户
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=normalized_email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
is_admin=False,
|
||||
)
|
||||
await UserRole.create(user_id=user.id, role_id=default_role_id)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def create_access_token(cls, data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if "sub" not in to_encode and "username" in to_encode:
|
||||
to_encode["sub"] = to_encode["username"]
|
||||
expire = _now() + (expires_delta or timedelta(minutes=15))
|
||||
to_encode.update({"exp": expire})
|
||||
secret_key = await cls.get_secret_key()
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=cls.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
@classmethod
|
||||
async def login(cls, form: OAuth2PasswordRequestForm) -> Token:
|
||||
user = await cls.authenticate_user_db(form.username, form.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 更新最后登录时间
|
||||
db_user = await UserAccount.get_or_none(id=user.id)
|
||||
if db_user:
|
||||
db_user.last_login = _now()
|
||||
await db_user.save(update_fields=["last_login"])
|
||||
|
||||
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
|
||||
access_token = await cls.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@classmethod
|
||||
def _build_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
email = cls._normalize_email(getattr(user, "email", None))
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": getattr(user, "email", None),
|
||||
"full_name": getattr(user, "full_name", None),
|
||||
"gravatar_url": gravatar_url,
|
||||
"is_admin": getattr(user, "is_admin", False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
return cls._build_profile(user)
|
||||
|
||||
@classmethod
|
||||
async def update_me(cls, payload: UpdateMeRequest, current_user: User) -> dict:
|
||||
db_user = await UserAccount.get_or_none(id=current_user.id)
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if payload.email is not None:
|
||||
exists = (
|
||||
await UserAccount.filter(email=payload.email)
|
||||
.exclude(id=db_user.id)
|
||||
.exists()
|
||||
)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被占用")
|
||||
db_user.email = payload.email
|
||||
|
||||
if payload.full_name is not None:
|
||||
db_user.full_name = payload.full_name
|
||||
|
||||
if payload.new_password:
|
||||
if not payload.old_password:
|
||||
raise HTTPException(status_code=400, detail="请提供原密码")
|
||||
if not cls.verify_password(payload.old_password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="原密码错误")
|
||||
db_user.hashed_password = cls.get_password_hash(payload.new_password)
|
||||
|
||||
await db_user.save()
|
||||
return cls._build_profile(db_user)
|
||||
|
||||
@classmethod
|
||||
async def request_password_reset(cls, payload: PasswordResetRequest) -> bool:
|
||||
normalized = cls._normalize_email(payload.email)
|
||||
if not normalized:
|
||||
return False
|
||||
user = await UserAccount.get_or_none(email=normalized)
|
||||
if not user or not user.email:
|
||||
return False
|
||||
|
||||
token = await PasswordResetStore.create(user)
|
||||
try:
|
||||
await cls._send_password_reset_email(user, token)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await PasswordResetStore.mark_used(token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
raise HTTPException(status_code=500, detail="邮件发送失败") from exc
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def verify_password_reset_token(cls, token: str) -> UserAccount:
|
||||
record = await PasswordResetStore.get(token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def reset_password_with_token(cls, payload: PasswordResetConfirm) -> None:
|
||||
record = await PasswordResetStore.get(payload.token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user.hashed_password = cls.get_password_hash(payload.password)
|
||||
await user.save(update_fields=["hashed_password"])
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
|
||||
@classmethod
|
||||
async def get_current_user(cls, token: str):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
secret_key = await cls.get_secret_key()
|
||||
payload = jwt.decode(token, secret_key, algorithms=[cls.algorithm])
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
user = await cls.get_user_db(token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def get_current_active_user(cls, current_user: User):
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
@classmethod
|
||||
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
|
||||
from domain.email import EmailService
|
||||
|
||||
app_domain = await ConfigService.get("APP_DOMAIN", None)
|
||||
base_url = (app_domain or "http://localhost:5173").rstrip("/")
|
||||
reset_link = f"{base_url}/reset-password?token={token}"
|
||||
await EmailService.enqueue_email(
|
||||
recipients=[user.email],
|
||||
subject="Foxel 密码重置",
|
||||
template="password_reset",
|
||||
context={
|
||||
"username": user.username,
|
||||
"reset_link": reset_link,
|
||||
"expire_minutes": cls.password_reset_token_expire_minutes,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _current_user_dep(token: Annotated[str, Depends(AuthService.oauth2_scheme)]):
|
||||
return await AuthService.get_current_user(token)
|
||||
|
||||
|
||||
async def _current_active_user_dep(
|
||||
current_user: Annotated[User, Depends(_current_user_dep)],
|
||||
):
|
||||
return await AuthService.get_current_active_user(current_user)
|
||||
|
||||
|
||||
# 方便依赖注入与外部使用
|
||||
get_current_user = _current_user_dep
|
||||
get_current_active_user = _current_active_user_dep
|
||||
authenticate_user_db = AuthService.authenticate_user_db
|
||||
create_access_token = AuthService.create_access_token
|
||||
register_user = AuthService.register_user
|
||||
request_password_reset = AuthService.request_password_reset
|
||||
verify_password_reset_token = AuthService.verify_password_reset_token
|
||||
reset_password_with_token = AuthService.reset_password_with_token
|
||||
has_users = AuthService.has_users
|
||||
verify_password = AuthService.verify_password
|
||||
get_password_hash = AuthService.get_password_hash
|
||||
46
domain/auth/types.py
Normal file
46
domain/auth/types.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool | None = None
|
||||
is_admin: bool = False
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str
|
||||
full_name: str | None = None
|
||||
|
||||
|
||||
class UpdateMeRequest(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
old_password: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
token: str
|
||||
password: str
|
||||
7
domain/backup/__init__.py
Normal file
7
domain/backup/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .service import BackupService
|
||||
from .types import BackupData
|
||||
|
||||
__all__ = [
|
||||
"BackupService",
|
||||
"BackupData",
|
||||
]
|
||||
44
domain/backup/api.py
Normal file
44
domain/backup/api.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import BackupService
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/backup",
|
||||
tags=["Backup & Restore"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/export", summary="导出全站数据")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def export_backup(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
sections: list[str] | None = Query(default=None),
|
||||
):
|
||||
data = await BackupService.export_data(sections=sections)
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
|
||||
return JSONResponse(content=data.model_dump(), headers=headers)
|
||||
|
||||
|
||||
@router.post("/import", summary="导入数据")
|
||||
@audit(action=AuditAction.UPLOAD, description="导入备份")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def import_backup(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
file: UploadFile = File(...),
|
||||
mode: str = Form("replace"),
|
||||
):
|
||||
await BackupService.import_from_bytes(file.filename, await file.read(), mode=mode)
|
||||
return {"message": "数据导入成功。"}
|
||||
339
domain/backup/service.py
Normal file
339
domain/backup/service.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config import VERSION
|
||||
from .types import BackupData
|
||||
from models.database import (
|
||||
AIDefaultModel,
|
||||
AIModel,
|
||||
AIProvider,
|
||||
AutomationTask,
|
||||
Configuration,
|
||||
Plugin,
|
||||
ShareLink,
|
||||
StorageAdapter,
|
||||
UserAccount,
|
||||
)
|
||||
|
||||
|
||||
class BackupService:
|
||||
ALL_SECTIONS = (
|
||||
"storage_adapters",
|
||||
"user_accounts",
|
||||
"automation_tasks",
|
||||
"share_links",
|
||||
"configurations",
|
||||
"ai_providers",
|
||||
"ai_models",
|
||||
"ai_default_models",
|
||||
"plugins",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def export_data(cls, sections: list[str] | None = None) -> BackupData:
|
||||
sections = cls._normalize_sections(sections)
|
||||
section_set = set(sections)
|
||||
async with in_transaction():
|
||||
adapters = (
|
||||
await StorageAdapter.all().values()
|
||||
if "storage_adapters" in section_set
|
||||
else []
|
||||
)
|
||||
users = (
|
||||
await UserAccount.all().values()
|
||||
if "user_accounts" in section_set
|
||||
else []
|
||||
)
|
||||
tasks = (
|
||||
await AutomationTask.all().values()
|
||||
if "automation_tasks" in section_set
|
||||
else []
|
||||
)
|
||||
shares = (
|
||||
await ShareLink.all().values()
|
||||
if "share_links" in section_set
|
||||
else []
|
||||
)
|
||||
configs = (
|
||||
await Configuration.all().values()
|
||||
if "configurations" in section_set
|
||||
else []
|
||||
)
|
||||
providers = (
|
||||
await AIProvider.all().values()
|
||||
if "ai_providers" in section_set
|
||||
else []
|
||||
)
|
||||
models = (
|
||||
await AIModel.all().values() if "ai_models" in section_set else []
|
||||
)
|
||||
default_models = (
|
||||
await AIDefaultModel.all().values()
|
||||
if "ai_default_models" in section_set
|
||||
else []
|
||||
)
|
||||
plugins = (
|
||||
await Plugin.all().values() if "plugins" in section_set else []
|
||||
)
|
||||
|
||||
share_links = cls._serialize_datetime_fields(
|
||||
shares, ["created_at", "expires_at"]
|
||||
)
|
||||
ai_providers = cls._serialize_datetime_fields(
|
||||
providers, ["created_at", "updated_at"]
|
||||
)
|
||||
ai_models = cls._serialize_datetime_fields(
|
||||
models, ["created_at", "updated_at"]
|
||||
)
|
||||
ai_default_models = cls._serialize_datetime_fields(
|
||||
default_models, ["created_at", "updated_at"]
|
||||
)
|
||||
plugin_items = cls._serialize_datetime_fields(
|
||||
plugins, ["created_at", "updated_at"]
|
||||
)
|
||||
|
||||
return BackupData(
|
||||
version=VERSION,
|
||||
sections=sections,
|
||||
storage_adapters=list(adapters),
|
||||
user_accounts=list(users),
|
||||
automation_tasks=list(tasks),
|
||||
share_links=share_links,
|
||||
configurations=list(configs),
|
||||
ai_providers=ai_providers,
|
||||
ai_models=ai_models,
|
||||
ai_default_models=ai_default_models,
|
||||
plugins=plugin_items,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def import_from_bytes(
|
||||
cls, filename: str, content: bytes, mode: str = "replace"
|
||||
) -> None:
|
||||
if not filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
|
||||
try:
|
||||
raw_data = json.loads(content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="无法解析JSON文件")
|
||||
await cls.import_data(BackupData(**raw_data), mode=mode)
|
||||
|
||||
@classmethod
|
||||
async def import_data(cls, payload: BackupData, mode: str = "replace") -> None:
|
||||
sections = cls._normalize_sections(payload.sections)
|
||||
if mode not in {"replace", "merge"}:
|
||||
raise HTTPException(status_code=400, detail="无效的导入模式")
|
||||
|
||||
share_links = (
|
||||
cls._parse_datetime_fields(payload.share_links, ["created_at", "expires_at"])
|
||||
if payload.share_links
|
||||
else []
|
||||
)
|
||||
ai_providers = (
|
||||
cls._parse_datetime_fields(payload.ai_providers, ["created_at", "updated_at"])
|
||||
if payload.ai_providers
|
||||
else []
|
||||
)
|
||||
ai_models = (
|
||||
cls._parse_datetime_fields(payload.ai_models, ["created_at", "updated_at"])
|
||||
if payload.ai_models
|
||||
else []
|
||||
)
|
||||
ai_default_models = (
|
||||
cls._parse_datetime_fields(
|
||||
payload.ai_default_models, ["created_at", "updated_at"]
|
||||
)
|
||||
if payload.ai_default_models
|
||||
else []
|
||||
)
|
||||
plugins = (
|
||||
cls._parse_datetime_fields(payload.plugins, ["created_at", "updated_at"])
|
||||
if payload.plugins
|
||||
else []
|
||||
)
|
||||
|
||||
async with in_transaction() as conn:
|
||||
if mode == "replace":
|
||||
if "share_links" in sections:
|
||||
await ShareLink.all().using_db(conn).delete()
|
||||
if "automation_tasks" in sections:
|
||||
await AutomationTask.all().using_db(conn).delete()
|
||||
if "storage_adapters" in sections:
|
||||
await StorageAdapter.all().using_db(conn).delete()
|
||||
if "user_accounts" in sections:
|
||||
await UserAccount.all().using_db(conn).delete()
|
||||
if "configurations" in sections:
|
||||
await Configuration.all().using_db(conn).delete()
|
||||
if "ai_default_models" in sections:
|
||||
await AIDefaultModel.all().using_db(conn).delete()
|
||||
if "ai_models" in sections:
|
||||
await AIModel.all().using_db(conn).delete()
|
||||
if "ai_providers" in sections:
|
||||
await AIProvider.all().using_db(conn).delete()
|
||||
if "plugins" in sections:
|
||||
await Plugin.all().using_db(conn).delete()
|
||||
|
||||
if "configurations" in sections and payload.configurations:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
Configuration, payload.configurations, conn
|
||||
)
|
||||
else:
|
||||
await Configuration.bulk_create(
|
||||
[Configuration(**config) for config in payload.configurations],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "user_accounts" in sections and payload.user_accounts:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(UserAccount, payload.user_accounts, conn)
|
||||
else:
|
||||
await UserAccount.bulk_create(
|
||||
[UserAccount(**user) for user in payload.user_accounts],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "storage_adapters" in sections and payload.storage_adapters:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
StorageAdapter, payload.storage_adapters, conn
|
||||
)
|
||||
else:
|
||||
await StorageAdapter.bulk_create(
|
||||
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "automation_tasks" in sections and payload.automation_tasks:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
AutomationTask, payload.automation_tasks, conn
|
||||
)
|
||||
else:
|
||||
await AutomationTask.bulk_create(
|
||||
[AutomationTask(**task) for task in payload.automation_tasks],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "share_links" in sections and share_links:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(ShareLink, share_links, conn)
|
||||
else:
|
||||
await ShareLink.bulk_create(
|
||||
[ShareLink(**share) for share in share_links],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "ai_providers" in sections and ai_providers:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(AIProvider, ai_providers, conn)
|
||||
else:
|
||||
await AIProvider.bulk_create(
|
||||
[AIProvider(**item) for item in ai_providers],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "ai_models" in sections and ai_models:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(AIModel, ai_models, conn)
|
||||
else:
|
||||
await AIModel.bulk_create(
|
||||
[AIModel(**item) for item in ai_models],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "ai_default_models" in sections and ai_default_models:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
AIDefaultModel, ai_default_models, conn
|
||||
)
|
||||
else:
|
||||
await AIDefaultModel.bulk_create(
|
||||
[AIDefaultModel(**item) for item in ai_default_models],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "plugins" in sections and plugins:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(Plugin, plugins, conn)
|
||||
else:
|
||||
await Plugin.bulk_create(
|
||||
[Plugin(**item) for item in plugins],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_sections(cls, sections: list[str] | None) -> list[str]:
|
||||
if not sections:
|
||||
return list(cls.ALL_SECTIONS)
|
||||
normalized = [item for item in sections if item]
|
||||
invalid = [item for item in normalized if item not in cls.ALL_SECTIONS]
|
||||
if invalid:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"无效的备份分区: {', '.join(invalid)}"
|
||||
)
|
||||
result: list[str] = []
|
||||
seen = set()
|
||||
for item in normalized:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def _merge_records(model, records: list[dict], using_db) -> None:
|
||||
for record in records:
|
||||
data = dict(record)
|
||||
record_id = data.pop("id", None)
|
||||
if record_id is None:
|
||||
await model.create(using_db=using_db, **data)
|
||||
continue
|
||||
updated = (
|
||||
await model.filter(id=record_id)
|
||||
.using_db(using_db)
|
||||
.update(**data)
|
||||
)
|
||||
if updated == 0:
|
||||
await model.create(using_db=using_db, id=record_id, **data)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_datetime_fields(
|
||||
records: list[dict], fields: list[str]
|
||||
) -> list[dict]:
|
||||
serialized: list[dict] = []
|
||||
for record in records:
|
||||
item = dict(record)
|
||||
for field in fields:
|
||||
value = item.get(field)
|
||||
if isinstance(value, datetime):
|
||||
item[field] = value.isoformat()
|
||||
serialized.append(item)
|
||||
return serialized
|
||||
|
||||
@staticmethod
|
||||
def _parse_datetime_fields(
|
||||
records: list[dict], fields: list[str]
|
||||
) -> list[dict]:
|
||||
parsed: list[dict] = []
|
||||
for record in records:
|
||||
item = dict(record)
|
||||
for field in fields:
|
||||
value = item.get(field)
|
||||
if isinstance(value, str):
|
||||
item[field] = BackupService._from_iso(value)
|
||||
parsed.append(item)
|
||||
return parsed
|
||||
|
||||
@staticmethod
|
||||
def _from_iso(value: str) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
try:
|
||||
return datetime.fromisoformat(normalized)
|
||||
except ValueError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail="无效的日期格式") from exc
|
||||
17
domain/backup/types.py
Normal file
17
domain/backup/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BackupData(BaseModel):
|
||||
version: str | None = None
|
||||
sections: list[str] = Field(default_factory=list)
|
||||
storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
|
||||
user_accounts: list[dict[str, Any]] = Field(default_factory=list)
|
||||
automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
share_links: list[dict[str, Any]] = Field(default_factory=list)
|
||||
configurations: list[dict[str, Any]] = Field(default_factory=list)
|
||||
ai_providers: list[dict[str, Any]] = Field(default_factory=list)
|
||||
ai_models: list[dict[str, Any]] = Field(default_factory=list)
|
||||
ai_default_models: list[dict[str, Any]] = Field(default_factory=list)
|
||||
plugins: list[dict[str, Any]] = Field(default_factory=list)
|
||||
10
domain/config/__init__.py
Normal file
10
domain/config/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import ConfigService, VERSION
|
||||
from .types import ConfigItem, LatestVersionInfo, SystemStatus
|
||||
|
||||
__all__ = [
|
||||
"ConfigService",
|
||||
"VERSION",
|
||||
"ConfigItem",
|
||||
"LatestVersionInfo",
|
||||
"SystemStatus",
|
||||
]
|
||||
83
domain/config/api.py
Normal file
83
domain/config/api.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import ConfigService
|
||||
from .types import ConfigItem
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
PUBLIC_CONFIG_KEYS = [
|
||||
"THEME_MODE",
|
||||
"THEME_PRIMARY_COLOR",
|
||||
"THEME_BORDER_RADIUS",
|
||||
"THEME_CUSTOM_TOKENS",
|
||||
"THEME_CUSTOM_CSS",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取配置")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def get_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str,
|
||||
):
|
||||
value = await ConfigService.get(key)
|
||||
return success(ConfigItem(key=key, value=value).model_dump())
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def set_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str = Form(...),
|
||||
value: str = Form(""),
|
||||
):
|
||||
await ConfigService.set(key, value)
|
||||
return success(ConfigItem(key=key, value=value).model_dump())
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
@audit(action=AuditAction.READ, description="获取全部配置")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def get_all_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
configs = await ConfigService.get_all()
|
||||
return success(configs)
|
||||
|
||||
@router.get("/public")
|
||||
@audit(action=AuditAction.READ, description="获取公开配置")
|
||||
async def get_public_config(
|
||||
request: Request,
|
||||
):
|
||||
data = {}
|
||||
for key in PUBLIC_CONFIG_KEYS:
|
||||
value = await ConfigService.get(key)
|
||||
if value is not None:
|
||||
data[key] = value
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@audit(action=AuditAction.READ, description="获取系统状态")
|
||||
async def get_system_status(request: Request):
|
||||
status_data = await ConfigService.get_system_status()
|
||||
return success(status_data.model_dump())
|
||||
|
||||
|
||||
@router.get("/latest-version")
|
||||
@audit(action=AuditAction.READ, description="获取最新版本")
|
||||
async def get_latest_version(request: Request):
|
||||
info = await ConfigService.get_latest_version()
|
||||
return success(info.model_dump())
|
||||
111
domain/config/service.py
Normal file
111
domain/config/service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .types import LatestVersionInfo, SystemStatus
|
||||
from models.database import Configuration, UserAccount
|
||||
|
||||
load_dotenv(dotenv_path=".env")
|
||||
|
||||
VERSION = "v1.7.4"
|
||||
|
||||
|
||||
class ConfigService:
|
||||
_cache: Dict[str, Any] = {}
|
||||
_latest_version_cache: Dict[str, Any] = {"timestamp": 0.0, "data": None}
|
||||
|
||||
@classmethod
|
||||
async def get(cls, key: str, default: Optional[Any] = None) -> Any:
|
||||
if key in cls._cache:
|
||||
return cls._cache[key]
|
||||
try:
|
||||
config = await Configuration.get_or_none(key=key)
|
||||
if config:
|
||||
cls._cache[key] = config.value
|
||||
return config.value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
env_value = os.getenv(key)
|
||||
if env_value is not None:
|
||||
cls._cache[key] = env_value
|
||||
return env_value
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
async def get_secret_key(cls, key: str, default: Optional[Any] = None) -> bytes:
|
||||
value = await cls.get(key, default)
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.encode("utf-8")
|
||||
if value is None:
|
||||
raise ValueError(f"Secret key '{key}' not found in config or environment.")
|
||||
return str(value).encode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def set(cls, key: str, value: Any):
|
||||
obj, _ = await Configuration.get_or_create(key=key, defaults={"value": value})
|
||||
obj.value = value
|
||||
await obj.save()
|
||||
cls._cache[key] = value
|
||||
|
||||
@classmethod
|
||||
async def get_all(cls) -> Dict[str, Any]:
|
||||
try:
|
||||
configs = await Configuration.all()
|
||||
result = {}
|
||||
for config in configs:
|
||||
result[config.key] = config.value
|
||||
cls._cache[config.key] = config.value
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache.clear()
|
||||
|
||||
@classmethod
|
||||
async def get_system_status(cls) -> SystemStatus:
|
||||
logo = await cls.get("APP_LOGO", "/logo.svg")
|
||||
favicon = await cls.get("APP_FAVICON", logo)
|
||||
user_count = await UserAccount.all().count()
|
||||
return SystemStatus(
|
||||
version=VERSION,
|
||||
title=await cls.get("APP_NAME", "Foxel"),
|
||||
logo=logo,
|
||||
favicon=favicon,
|
||||
is_initialized=user_count > 0,
|
||||
app_domain=await cls.get("APP_DOMAIN"),
|
||||
file_domain=await cls.get("FILE_DOMAIN"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_latest_version(cls) -> LatestVersionInfo:
|
||||
current_time = time.time()
|
||||
cache = cls._latest_version_cache
|
||||
if current_time - cache["timestamp"] < 3600 and cache["data"]:
|
||||
return cache["data"]
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
version_info = LatestVersionInfo(
|
||||
latest_version=data.get("tag_name"),
|
||||
body=data.get("body"),
|
||||
)
|
||||
cache["timestamp"] = current_time
|
||||
cache["data"] = version_info
|
||||
return version_info
|
||||
except httpx.RequestError:
|
||||
if cache["data"]:
|
||||
return cache["data"]
|
||||
return LatestVersionInfo()
|
||||
23
domain/config/types.py
Normal file
23
domain/config/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfigItem(BaseModel):
|
||||
key: str
|
||||
value: Optional[Any] = None
|
||||
|
||||
|
||||
class SystemStatus(BaseModel):
|
||||
version: str
|
||||
title: str
|
||||
logo: str
|
||||
favicon: str
|
||||
is_initialized: bool
|
||||
app_domain: Optional[str] = None
|
||||
file_domain: Optional[str] = None
|
||||
|
||||
|
||||
class LatestVersionInfo(BaseModel):
|
||||
latest_version: Optional[str] = None
|
||||
body: Optional[str] = None
|
||||
20
domain/email/__init__.py
Normal file
20
domain/email/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .service import EmailService, EmailTemplateRenderer
|
||||
from .types import (
|
||||
EmailConfig,
|
||||
EmailSecurity,
|
||||
EmailSendPayload,
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmailService",
|
||||
"EmailTemplateRenderer",
|
||||
"EmailConfig",
|
||||
"EmailSecurity",
|
||||
"EmailSendPayload",
|
||||
"EmailTemplatePreviewPayload",
|
||||
"EmailTemplateUpdate",
|
||||
"EmailTestRequest",
|
||||
]
|
||||
91
domain/email/api.py
Normal file
91
domain/email/api.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import EmailService, EmailTemplateRenderer
|
||||
from .types import (
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/email", tags=["email"])
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
@audit(action=AuditAction.CREATE, description="发送测试邮件")
|
||||
async def trigger_test_email(
|
||||
request: Request,
|
||||
payload: EmailTestRequest,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
task = await EmailService.enqueue_email(
|
||||
recipients=[str(payload.to)],
|
||||
subject=payload.subject,
|
||||
template=payload.template,
|
||||
context=payload.context,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
@audit(action=AuditAction.READ, description="获取邮件模板列表")
|
||||
async def list_email_templates(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
templates = await EmailTemplateRenderer.list_templates()
|
||||
return success({"templates": templates})
|
||||
|
||||
|
||||
@router.get("/templates/{name}")
|
||||
@audit(action=AuditAction.READ, description="查看邮件模板")
|
||||
async def get_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
content = await EmailTemplateRenderer.load(name)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
return success({"name": name, "content": content})
|
||||
|
||||
|
||||
@router.post("/templates/{name}")
|
||||
@audit(action=AuditAction.UPDATE, description="更新邮件模板")
|
||||
async def update_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
payload: EmailTemplateUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
await EmailTemplateRenderer.save(name, payload.content)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
return success({"name": name})
|
||||
|
||||
|
||||
@router.post("/templates/{name}/preview")
|
||||
@audit(action=AuditAction.READ, description="预览邮件模板")
|
||||
async def preview_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
payload: EmailTemplatePreviewPayload,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
html = await EmailTemplateRenderer.render(name, payload.context)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
return success({"html": html})
|
||||
151
domain/email/service.py
Normal file
151
domain/email/service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import asyncio
|
||||
import re
|
||||
import smtplib
|
||||
from email.message import EmailMessage
|
||||
from email.utils import formataddr
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from domain.config import ConfigService
|
||||
from .types import EmailConfig, EmailSecurity, EmailSendPayload
|
||||
|
||||
|
||||
class EmailTemplateRenderer:
|
||||
ROOT = Path("templates/email")
|
||||
|
||||
@classmethod
|
||||
def _resolve_path(cls, template_name: str) -> Path:
|
||||
if not re.fullmatch(r"[A-Za-z0-9_\-]+", template_name):
|
||||
raise ValueError("Invalid template name")
|
||||
return cls.ROOT / f"{template_name}.html"
|
||||
|
||||
@classmethod
|
||||
async def list_templates(cls) -> list[str]:
|
||||
cls.ROOT.mkdir(parents=True, exist_ok=True)
|
||||
return sorted(
|
||||
path.stem for path in cls.ROOT.glob("*.html") if path.is_file()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def load(cls, template_name: str) -> str:
|
||||
path = cls._resolve_path(template_name)
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"Email template '{template_name}' not found")
|
||||
return await asyncio.to_thread(path.read_text, encoding="utf-8")
|
||||
|
||||
@classmethod
|
||||
async def save(cls, template_name: str, content: str) -> None:
|
||||
path = cls._resolve_path(template_name)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
await asyncio.to_thread(path.write_text, content, encoding="utf-8")
|
||||
|
||||
@classmethod
|
||||
async def render(cls, template_name: str, context: Dict[str, Any]) -> str:
|
||||
raw = await cls.load(template_name)
|
||||
context = {k: str(v) for k, v in (context or {}).items()}
|
||||
return Template(raw).safe_substitute(context)
|
||||
|
||||
|
||||
class EmailService:
|
||||
CONFIG_KEY = "EMAIL_CONFIG"
|
||||
|
||||
@classmethod
|
||||
async def _load_config(cls) -> EmailConfig:
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
return EmailConfig.parse_config(raw_config)
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(html: str) -> str:
|
||||
stripped = re.sub(r"<[^>]+>", " ", html)
|
||||
return " ".join(stripped.split())
|
||||
|
||||
@classmethod
|
||||
async def _deliver(cls, config: EmailConfig, payload: EmailSendPayload, html_body: str):
|
||||
message = EmailMessage()
|
||||
message["Subject"] = payload.subject
|
||||
message["From"] = formataddr(
|
||||
(config.sender_name or str(config.sender_email), str(config.sender_email))
|
||||
)
|
||||
message["To"] = ", ".join([str(addr) for addr in payload.recipients])
|
||||
|
||||
plain_body = cls._html_to_text(html_body)
|
||||
message.set_content(plain_body or html_body)
|
||||
message.add_alternative(html_body, subtype="html")
|
||||
|
||||
await asyncio.to_thread(cls._deliver_sync, config, message)
|
||||
|
||||
@staticmethod
|
||||
def _deliver_sync(config: EmailConfig, message: EmailMessage):
|
||||
if config.security == EmailSecurity.SSL:
|
||||
smtp: smtplib.SMTP = smtplib.SMTP_SSL(
|
||||
config.host, config.port, timeout=config.timeout
|
||||
)
|
||||
else:
|
||||
smtp = smtplib.SMTP(config.host, config.port, timeout=config.timeout)
|
||||
|
||||
try:
|
||||
if config.security == EmailSecurity.STARTTLS:
|
||||
smtp.starttls()
|
||||
if config.username and config.password:
|
||||
smtp.login(config.username, config.password)
|
||||
smtp.send_message(message)
|
||||
finally:
|
||||
try:
|
||||
smtp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def enqueue_email(
|
||||
cls,
|
||||
recipients: List[str],
|
||||
subject: str,
|
||||
template: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
from domain.tasks import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(
|
||||
recipients=recipients,
|
||||
subject=subject,
|
||||
template=template,
|
||||
context=context or {},
|
||||
)
|
||||
|
||||
task = await task_queue_service.add_task(
|
||||
"send_email",
|
||||
payload.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(stage="queued", percent=0.0, detail="Waiting to send"),
|
||||
)
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
|
||||
from domain.tasks import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(**data)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task_id,
|
||||
TaskProgress(stage="preparing", percent=10.0, detail="Rendering template"),
|
||||
)
|
||||
|
||||
config = await cls._load_config()
|
||||
html_body = await EmailTemplateRenderer.render(payload.template, payload.context)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task_id,
|
||||
TaskProgress(stage="sending", percent=60.0, detail="Sending message"),
|
||||
)
|
||||
|
||||
await cls._deliver(config, payload, html_body)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task_id,
|
||||
TaskProgress(stage="completed", percent=100.0, detail="Email sent"),
|
||||
)
|
||||
63
domain/email/types.py
Normal file
63
domain/email/types.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, ValidationError
|
||||
|
||||
|
||||
class EmailSecurity(str, Enum):
|
||||
NONE = "none"
|
||||
SSL = "ssl"
|
||||
STARTTLS = "starttls"
|
||||
|
||||
|
||||
class EmailConfig(BaseModel):
|
||||
host: str
|
||||
port: int = Field(..., gt=0)
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
sender_email: EmailStr
|
||||
sender_name: Optional[str] = None
|
||||
security: EmailSecurity = EmailSecurity.NONE
|
||||
timeout: float = Field(default=30.0, gt=0.0)
|
||||
|
||||
@classmethod
|
||||
def parse_config(cls, raw_config: Any) -> "EmailConfig":
|
||||
"""接受字符串或 dict 配置并解析为 EmailConfig。"""
|
||||
if raw_config is None:
|
||||
raise ValueError("Email configuration not found")
|
||||
|
||||
if isinstance(raw_config, str):
|
||||
raw_config = raw_config.strip()
|
||||
data: Any = json.loads(raw_config) if raw_config else {}
|
||||
elif isinstance(raw_config, dict):
|
||||
data = raw_config
|
||||
else:
|
||||
raise ValueError("Invalid email configuration format")
|
||||
|
||||
try:
|
||||
return cls(**data)
|
||||
except ValidationError as exc:
|
||||
raise ValueError(f"Invalid email configuration: {exc}") from exc
|
||||
|
||||
|
||||
class EmailSendPayload(BaseModel):
|
||||
recipients: List[EmailStr] = Field(..., min_length=1)
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(..., min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTestRequest(BaseModel):
|
||||
to: EmailStr
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(default="test", min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTemplateUpdate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class EmailTemplatePreviewPayload(BaseModel):
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
7
domain/offline_downloads/__init__.py
Normal file
7
domain/offline_downloads/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .service import OfflineDownloadService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
__all__ = [
|
||||
"OfflineDownloadService",
|
||||
"OfflineDownloadCreate",
|
||||
]
|
||||
44
domain/offline_downloads/api.py
Normal file
44
domain/offline_downloads/api.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_path_permission
|
||||
from domain.permission.types import PathAction
|
||||
from .service import OfflineDownloadService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/offline-downloads",
|
||||
tags=["OfflineDownloads"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建离线下载任务",
|
||||
body_fields=["url", "dest_dir", "filename"],
|
||||
)
|
||||
@require_path_permission(PathAction.WRITE, "payload.dest_dir")
|
||||
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
|
||||
data = await OfflineDownloadService.create_download(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取离线下载列表")
|
||||
async def list_offline_downloads(request: Request, current_user: CurrentUser):
|
||||
data = OfflineDownloadService.list_downloads()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
@audit(action=AuditAction.READ, description="获取离线下载详情")
|
||||
async def get_offline_download(task_id: str, request: Request, current_user: CurrentUser):
|
||||
data = OfflineDownloadService.get_download(task_id)
|
||||
return success(data)
|
||||
251
domain/offline_downloads/service.py
Normal file
251
domain/offline_downloads/service.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Annotated, AsyncIterator
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from fastapi import Depends, HTTPException
|
||||
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.tasks import Task, TaskProgress, task_queue_service
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
|
||||
class OfflineDownloadService:
|
||||
current_user_dep = Annotated[User, Depends(get_current_active_user)]
|
||||
temp_root = Path("data/tmp/offline_downloads")
|
||||
|
||||
@classmethod
|
||||
async def create_download(cls, payload: OfflineDownloadCreate, current_user: User) -> dict:
|
||||
await cls._ensure_destination(payload.dest_dir)
|
||||
task = await task_queue_service.add_task(
|
||||
"offline_http_download",
|
||||
{
|
||||
"url": str(payload.url),
|
||||
"dest_dir": payload.dest_dir,
|
||||
"filename": payload.filename,
|
||||
},
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="queued",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="Waiting to start",
|
||||
),
|
||||
)
|
||||
|
||||
return {"task_id": task.id}
|
||||
|
||||
@classmethod
|
||||
def list_downloads(cls) -> list[dict]:
|
||||
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
|
||||
return [t.dict() for t in tasks]
|
||||
|
||||
@classmethod
|
||||
def get_download(cls, task_id: str) -> dict:
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task or task.name != "offline_http_download":
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task.dict()
|
||||
|
||||
@classmethod
|
||||
async def run_http_download(cls, task: Task):
|
||||
params = task.task_info
|
||||
url = params.get("url")
|
||||
dest_dir = params.get("dest_dir")
|
||||
filename = params.get("filename")
|
||||
|
||||
if not url or not dest_dir or not filename:
|
||||
raise ValueError("Missing required parameters for offline download")
|
||||
|
||||
cls.temp_root.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir = cls.temp_root / task.id
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_file = temp_dir / "payload"
|
||||
|
||||
bytes_total: int | None = None
|
||||
bytes_done = 0
|
||||
last_update = time.monotonic()
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="downloading",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="HTTP downloading",
|
||||
),
|
||||
)
|
||||
|
||||
async def report_download(delta: int, total: int | None):
|
||||
nonlocal bytes_done, bytes_total, last_update
|
||||
if total is not None:
|
||||
bytes_total = total
|
||||
bytes_done += delta
|
||||
now = time.monotonic()
|
||||
if delta and now - last_update < 0.5:
|
||||
return
|
||||
last_update = now
|
||||
percent = None
|
||||
total_for_display = bytes_total if bytes_total is not None else None
|
||||
if bytes_total:
|
||||
percent = min(100.0, round(bytes_done / bytes_total * 100, 2))
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="downloading",
|
||||
percent=percent,
|
||||
bytes_total=total_for_display,
|
||||
bytes_done=bytes_done,
|
||||
detail="HTTP downloading",
|
||||
),
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=30)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError(f"HTTP {resp.status} for {url}")
|
||||
content_length = resp.headers.get("Content-Length")
|
||||
total_size = int(content_length) if content_length else None
|
||||
bytes_done = 0
|
||||
async with aiofiles.open(temp_file, "wb") as f:
|
||||
async for chunk in resp.content.iter_chunked(512 * 1024):
|
||||
if not chunk:
|
||||
continue
|
||||
await f.write(chunk)
|
||||
await report_download(len(chunk), total_size)
|
||||
await report_download(0, total_size)
|
||||
|
||||
file_size = os.path.getsize(temp_file)
|
||||
bytes_done_transfer = 0
|
||||
|
||||
async def report_transfer(delta: int):
|
||||
nonlocal bytes_done_transfer
|
||||
bytes_done_transfer += delta
|
||||
percent = min(100.0, round(bytes_done_transfer / file_size * 100, 2)) if file_size else None
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="transferring",
|
||||
percent=percent,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=bytes_done_transfer,
|
||||
detail="Saving to storage",
|
||||
),
|
||||
)
|
||||
|
||||
async def chunk_iter() -> AsyncIterator[bytes]:
|
||||
async for chunk in cls._iter_file(temp_file, 512 * 1024, report_transfer):
|
||||
yield chunk
|
||||
|
||||
final_path, resolved_name = await cls._allocate_destination(dest_dir, filename)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="transferring",
|
||||
percent=0.0,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=0,
|
||||
detail="Saving to storage",
|
||||
),
|
||||
)
|
||||
|
||||
await VirtualFSService.write_file_stream(final_path, chunk_iter())
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="completed",
|
||||
percent=100.0,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=file_size,
|
||||
detail="Completed",
|
||||
),
|
||||
)
|
||||
await task_queue_service.update_meta(task.id, {"final_path": final_path, "filename": resolved_name})
|
||||
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
temp_dir.rmdir()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return final_path
|
||||
|
||||
@classmethod
|
||||
async def _ensure_destination(cls, dest_dir: str) -> None:
|
||||
try:
|
||||
is_dir = await VirtualFSService.path_is_directory(dest_dir)
|
||||
except HTTPException:
|
||||
is_dir = False
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Destination directory not found")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_path(path: str) -> str:
|
||||
if not path:
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
@staticmethod
|
||||
async def _path_exists(full_path: str) -> bool:
|
||||
try:
|
||||
await VirtualFSService.stat_file(full_path)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except HTTPException as exc: # noqa: PERF203
|
||||
if exc.status_code == 404:
|
||||
return False
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def _allocate_destination(cls, dest_dir: str, filename: str) -> tuple[str, str]:
|
||||
dest_dir = cls._normalize_path(dest_dir)
|
||||
stem, suffix = cls._split_filename(filename)
|
||||
candidate = filename
|
||||
base = "" if dest_dir == "/" else dest_dir
|
||||
attempt = 0
|
||||
while await cls._path_exists(f"{base}/{candidate}" if base else f"/{candidate}"):
|
||||
attempt += 1
|
||||
if stem:
|
||||
candidate = f"{stem} ({attempt}){suffix}"
|
||||
else:
|
||||
candidate = f"file ({attempt}){suffix}" if suffix else f"file ({attempt})"
|
||||
full_path = f"{base}/{candidate}" if base else f"/{candidate}"
|
||||
return full_path, candidate
|
||||
|
||||
@staticmethod
|
||||
def _split_filename(filename: str) -> tuple[str, str]:
|
||||
if not filename:
|
||||
return "", ""
|
||||
if filename.startswith(".") and filename.count(".") == 1:
|
||||
return filename, ""
|
||||
if "." not in filename:
|
||||
return filename, ""
|
||||
stem, ext = filename.rsplit(".", 1)
|
||||
return stem, f".{ext}"
|
||||
|
||||
@staticmethod
|
||||
async def _iter_file(path: Path, chunk_size: int, report_cb) -> AsyncIterator[bytes]:
|
||||
async with aiofiles.open(path, "rb") as f:
|
||||
while True:
|
||||
chunk = await f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
await report_cb(len(chunk))
|
||||
yield chunk
|
||||
10
domain/permission/__init__.py
Normal file
10
domain/permission/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import PermissionService
|
||||
from .matcher import PathMatcher
|
||||
from .decorator import require_path_permission, require_system_permission
|
||||
|
||||
__all__ = [
|
||||
"PermissionService",
|
||||
"PathMatcher",
|
||||
"require_system_permission",
|
||||
"require_path_permission",
|
||||
]
|
||||
41
domain/permission/api.py
Normal file
41
domain/permission/api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from .service import PermissionService
|
||||
from .types import (
|
||||
PathPermissionCheck,
|
||||
PathPermissionResult,
|
||||
UserPermissions,
|
||||
PermissionInfo,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["permissions"])
|
||||
|
||||
|
||||
@router.get("/permissions", response_model=list[PermissionInfo])
|
||||
async def get_all_permissions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> list[PermissionInfo]:
|
||||
"""获取所有权限定义"""
|
||||
return await PermissionService.get_all_permissions()
|
||||
|
||||
|
||||
@router.get("/me/permissions", response_model=UserPermissions)
|
||||
async def get_my_permissions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> UserPermissions:
|
||||
"""获取当前用户的有效权限"""
|
||||
return await PermissionService.get_user_permissions(current_user.id)
|
||||
|
||||
|
||||
@router.post("/me/check-path", response_model=PathPermissionResult)
|
||||
async def check_path_permission(
|
||||
data: PathPermissionCheck,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> PathPermissionResult:
|
||||
"""检查当前用户对某路径的权限"""
|
||||
return await PermissionService.check_path_permission_detailed(
|
||||
current_user.id, data.path, data.action
|
||||
)
|
||||
103
domain/permission/decorator.py
Normal file
103
domain/permission/decorator.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .service import PermissionService
|
||||
|
||||
|
||||
def _get_user_id(user: Any) -> int | None:
|
||||
if user is None:
|
||||
return None
|
||||
if isinstance(user, Mapping):
|
||||
raw = user.get("id") or user.get("user_id")
|
||||
return int(raw) if isinstance(raw, int) else None
|
||||
value = getattr(user, "id", None) or getattr(user, "user_id", None)
|
||||
return int(value) if isinstance(value, int) else None
|
||||
|
||||
|
||||
def _resolve_expr(bound_args: Mapping[str, Any], expr: str) -> Any:
|
||||
parts = [p for p in (expr or "").split(".") if p]
|
||||
if not parts:
|
||||
return None
|
||||
cur: Any = bound_args.get(parts[0])
|
||||
for part in parts[1:]:
|
||||
if cur is None:
|
||||
return None
|
||||
if isinstance(cur, Mapping):
|
||||
cur = cur.get(part)
|
||||
else:
|
||||
cur = getattr(cur, part, None)
|
||||
return cur
|
||||
|
||||
|
||||
def require_system_permission(permission_code: str, *, user_kw: str = "current_user"):
|
||||
"""
|
||||
在 endpoint 内部执行系统/适配器权限校验。
|
||||
|
||||
设计目标:
|
||||
- 保持和当前“在函数体内手写 require_*”一致的行为:失败会被外层 @audit 捕获记录
|
||||
- 不依赖 FastAPI dependencies(避免权限失败发生在 endpoint 之外)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
user_id = _get_user_id(bound.arguments.get(user_kw))
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
await PermissionService.require_system_permission(user_id, permission_code)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_path_permission(action: str, path_expr: str, *, user_kw: str = "current_user"):
|
||||
"""
|
||||
在 endpoint 内部执行路径权限校验。
|
||||
|
||||
path_expr 支持:
|
||||
- "full_path"
|
||||
- "body.src" / "body.dst"
|
||||
- "payload.paths"(list[str] 会逐个检查)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
user_id = _get_user_id(bound.arguments.get(user_kw))
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
value = _resolve_expr(bound.arguments, path_expr)
|
||||
paths: Iterable[Any]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
paths = value
|
||||
else:
|
||||
paths = [value]
|
||||
|
||||
for path in paths:
|
||||
if path is None:
|
||||
raise HTTPException(status_code=400, detail="Missing path")
|
||||
await PermissionService.require_path_permission(user_id, str(path), action)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
158
domain/permission/matcher.py
Normal file
158
domain/permission/matcher.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import re
|
||||
import fnmatch
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class PathMatcher:
|
||||
"""路径匹配器,支持精确匹配、通配符匹配和正则匹配"""
|
||||
|
||||
@classmethod
|
||||
def normalize_path(cls, path: str) -> str:
|
||||
"""规范化路径"""
|
||||
if not path:
|
||||
return "/"
|
||||
# 确保以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
# 移除末尾的 /(除了根路径)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def get_parent_path(cls, path: str) -> str | None:
|
||||
"""获取父目录路径"""
|
||||
path = cls.normalize_path(path)
|
||||
if path == "/":
|
||||
return None
|
||||
parent = "/".join(path.rsplit("/", 1)[:-1])
|
||||
return parent if parent else "/"
|
||||
|
||||
@classmethod
|
||||
def match_pattern(cls, path: str, pattern: str, is_regex: bool = False) -> bool:
|
||||
"""
|
||||
匹配路径和模式
|
||||
|
||||
Args:
|
||||
path: 要匹配的路径
|
||||
pattern: 匹配模式
|
||||
is_regex: 是否为正则表达式
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
path = cls.normalize_path(path)
|
||||
pattern = cls.normalize_path(pattern)
|
||||
|
||||
if is_regex:
|
||||
return cls._match_regex(path, pattern)
|
||||
else:
|
||||
return cls._match_glob(path, pattern)
|
||||
|
||||
@classmethod
|
||||
def _match_regex(cls, path: str, pattern: str) -> bool:
|
||||
"""正则表达式匹配"""
|
||||
try:
|
||||
# 限制正则表达式的复杂度,防止 ReDoS 攻击
|
||||
if len(pattern) > 500:
|
||||
return False
|
||||
regex = re.compile(pattern)
|
||||
return bool(regex.match(path))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _match_glob(cls, path: str, pattern: str) -> bool:
|
||||
"""
|
||||
通配符匹配
|
||||
|
||||
支持的语法:
|
||||
- * : 匹配单层目录中的任意字符
|
||||
- ** : 匹配任意层级目录
|
||||
- ? : 匹配单个字符
|
||||
"""
|
||||
# 精确匹配
|
||||
if pattern == path:
|
||||
return True
|
||||
|
||||
# 处理 ** 通配符
|
||||
if "**" in pattern:
|
||||
return cls._match_double_star(path, pattern)
|
||||
|
||||
# 使用 fnmatch 进行标准通配符匹配
|
||||
return fnmatch.fnmatch(path, pattern)
|
||||
|
||||
@classmethod
|
||||
def _match_double_star(cls, path: str, pattern: str) -> bool:
|
||||
"""处理 ** 通配符匹配"""
|
||||
# 将 ** 替换为特殊标记
|
||||
parts = pattern.split("**")
|
||||
|
||||
if len(parts) == 2:
|
||||
prefix, suffix = parts
|
||||
# 移除 prefix 末尾的 / 和 suffix 开头的 /
|
||||
prefix = prefix.rstrip("/") if prefix else ""
|
||||
suffix = suffix.lstrip("/") if suffix else ""
|
||||
|
||||
# 检查前缀匹配
|
||||
if prefix and not path.startswith(prefix):
|
||||
return False
|
||||
|
||||
# 如果没有后缀,只需要前缀匹配
|
||||
if not suffix:
|
||||
return True
|
||||
|
||||
# 检查后缀匹配
|
||||
remaining = path[len(prefix):].lstrip("/") if prefix else path.lstrip("/")
|
||||
|
||||
# 后缀可以出现在任意位置
|
||||
if "*" in suffix or "?" in suffix:
|
||||
# 后缀包含通配符,逐层检查
|
||||
path_parts = remaining.split("/")
|
||||
suffix_parts = suffix.split("/")
|
||||
|
||||
# 简化处理:检查路径的最后几层是否与后缀匹配
|
||||
if len(path_parts) >= len(suffix_parts):
|
||||
tail = "/".join(path_parts[-len(suffix_parts):])
|
||||
return fnmatch.fnmatch(tail, suffix)
|
||||
return False
|
||||
else:
|
||||
# 后缀是精确字符串
|
||||
return remaining.endswith(suffix) or ("/" + suffix) in remaining or remaining == suffix
|
||||
|
||||
# 多个 ** 的情况,使用简化匹配
|
||||
regex_pattern = pattern.replace("**", ".*").replace("*", "[^/]*").replace("?", ".")
|
||||
try:
|
||||
return bool(re.match(f"^{regex_pattern}$", path))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_pattern_specificity(cls, pattern: str, is_regex: bool = False) -> int:
|
||||
"""
|
||||
计算模式的具体程度(用于优先级排序)
|
||||
|
||||
返回值越大表示模式越具体
|
||||
"""
|
||||
pattern = cls.normalize_path(pattern)
|
||||
|
||||
if is_regex:
|
||||
# 正则表达式具体程度较低
|
||||
return len(pattern) // 2
|
||||
|
||||
# 精确路径最具体
|
||||
if "*" not in pattern and "?" not in pattern:
|
||||
return len(pattern) * 10
|
||||
|
||||
# 计算非通配符部分的长度
|
||||
specificity = 0
|
||||
parts = pattern.split("/")
|
||||
for part in parts:
|
||||
if part == "**":
|
||||
specificity += 1
|
||||
elif "*" in part or "?" in part:
|
||||
specificity += 5
|
||||
else:
|
||||
specificity += 10
|
||||
|
||||
return specificity
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user