Mirror of https://github.com/DrizzleTime/Foxel.git (synced 2026-05-08 19:02:53 +08:00)
Compare commits: 203 commits
.github/FUNDING.yml (vendored, new file, 1 line)

@@ -0,0 +1 @@
custom: https://foxel.cc/sponsor
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, new file, 75 lines)

@@ -0,0 +1,75 @@
name: Bug Report / 缺陷报告
description: Report reproducible defects with clear context / 请提供可复现的缺陷信息
title: "[Bug] "
labels:
  - bug
body:
  - type: markdown
    attributes:
      value: |
        Thanks for helping us improve Foxel! / 感谢你帮助改进 Foxel!
        Please confirm the checklist below before filing. / 在提交前请确认以下事项。
  - type: checkboxes
    id: validations
    attributes:
      label: Pre-flight Check / 提交前检查
      options:
        - label: I searched existing issues and docs / 我已搜索现有 Issue 与文档
          required: true
        - label: This is not a question or feature request / 这不是问题咨询或功能需求
          required: true
  - type: textarea
    id: summary
    attributes:
      label: Bug Summary / 缺陷摘要
      description: Briefly describe what is wrong / 简要说明出现了什么问题
      placeholder: e.g. Upload fails with 500 error / 例如:上传时报 500 错误
    validations:
      required: true
  - type: textarea
    id: steps
    attributes:
      label: Steps to Reproduce / 复现步骤
      description: List numbered steps to trigger the bug / 列出触发问题的步骤
      placeholder: |
        1. ...
        2. ...
        3. ...
    validations:
      required: true
  - type: textarea
    id: expected
    attributes:
      label: Expected Behavior / 预期行为
      description: What should happen instead? / 期望看到什么结果?
    validations:
      required: true
  - type: textarea
    id: actual
    attributes:
      label: Actual Behavior / 实际行为
      description: What actually happens? Include messages or screenshots / 实际发生了什么?可附报错或截图
    validations:
      required: true
  - type: input
    id: version
    attributes:
      label: Version / 版本信息
      description: Git commit, tag, or build number / 提供 Git 提交、标签或构建号
    validations:
      required: false
  - type: textarea
    id: environment
    attributes:
      label: Environment / 运行环境
      description: OS, browser, API server config, etc. / 操作系统、浏览器、服务端配置等
    validations:
      required: false
  - type: textarea
    id: logs
    attributes:
      label: Logs & Attachments / 日志与附件
      description: Paste relevant logs, stack traces, screenshots / 粘贴相关日志、堆栈或截图
      render: shell
    validations:
      required: false
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, new file, 56 lines)

@@ -0,0 +1,56 @@
name: Feature Request / 功能需求
description: Suggest enhancements or new capabilities / 提出改进或新增能力
title: "[Feature] "
labels:
  - enhancement
body:
  - type: markdown
    attributes:
      value: |
        Tell us about your idea! / 欢迎分享你的想法!
        Please complete the sections below so we can evaluate it quickly. / 请完整填写以下信息,便于快速评估。
  - type: checkboxes
    id: prechecks
    attributes:
      label: Pre-flight Check / 提交前检查
      options:
        - label: I searched existing issues and roadmap / 我已搜索现有 Issue 与路线图
          required: true
        - label: This is not a bug report or question / 这不是缺陷或问题咨询
          required: true
  - type: textarea
    id: summary
    attributes:
      label: Feature Summary / 功能概述
      description: What do you want to build? / 希望新增什么能力?
      placeholder: e.g. Support sharing download links / 例如:支持分享下载链接
    validations:
      required: true
  - type: textarea
    id: motivation
    attributes:
      label: Motivation / 背景与价值
      description: Why is this feature important? Who benefits? / 为什么重要?受益者是谁?
    validations:
      required: true
  - type: textarea
    id: scope
    attributes:
      label: Proposed Solution / 建议方案
      description: Outline how the feature might work, including API or UI hints / 描述可能的实现方式,包含 API 或 UI 提示
    validations:
      required: false
  - type: textarea
    id: alternatives
    attributes:
      label: Alternatives / 可选方案
      description: List any alternatives considered / 如有考虑过其他方案请列出
    validations:
      required: false
  - type: textarea
    id: extra
    attributes:
      label: Additional Context / 补充信息
      description: Diagrams, sketches, links, constraints, etc. / 可附上草图、链接或约束
    validations:
      required: false
.github/ISSUE_TEMPLATE/question.yml (vendored, new file, 42 lines)

@@ -0,0 +1,42 @@
name: Question / 问题咨询
description: Ask about usage, configuration, or clarification / 用于使用、配置或澄清问题
title: "[Question] "
labels:
  - question
body:
  - type: markdown
    attributes:
      value: |
        Need help? You're in the right place. / 需要帮助?请按以下提示填写。
        Check the docs before filing. / 提交前请先查阅文档。
  - type: checkboxes
    id: prechecks
    attributes:
      label: Pre-flight Check / 提交前检查
      options:
        - label: I searched existing issues and discussions / 我已搜索现有 Issue 和讨论
          required: true
        - label: I read the relevant documentation / 我已阅读相关文档
          required: true
  - type: textarea
    id: question
    attributes:
      label: Question Details / 问题详情
      description: What do you need help with? Be specific. / 具体说明需要帮助的内容
      placeholder: Describe the scenario, expectation, and blockers / 说明场景、期望结果与阻碍
    validations:
      required: true
  - type: textarea
    id: tried
    attributes:
      label: What You Tried / 已尝试方案
      description: List commands, configs, or steps attempted / 列出尝试过的命令、配置或步骤
    validations:
      required: false
  - type: textarea
    id: context
    attributes:
      label: Additional Context / 补充信息
      description: Environment details, logs, screenshots / 可补充运行环境、日志或截图
    validations:
      required: false
.github/dependabot.yml (vendored, new file, 16 lines)

@@ -0,0 +1,16 @@
version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "monthly"

  - package-ecosystem: "bun"
    directory: "/web"
    schedule:
      interval: "monthly"

  - package-ecosystem: "uv"
    directory: "/"
    schedule:
      interval: "monthly"
.github/workflows/docker-clean.yml (vendored, new file, 51 lines)

@@ -0,0 +1,51 @@
name: Clean dangling Docker images

on:
  workflow_dispatch:

jobs:
  docker-clean:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write

    steps:
      - name: Delete untagged GHCR versions
        shell: bash
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          set -euo pipefail
          OWNER="${GITHUB_REPOSITORY_OWNER}"
          PACKAGE="$(echo "${GITHUB_REPOSITORY##*/}" | tr '[:upper:]' '[:lower:]')"

          OWNER_TYPE="$(gh api "/users/${OWNER}" -q '.type')"
          if [[ "${OWNER_TYPE}" == "Organization" ]]; then
            SCOPE="orgs/${OWNER}"
          else
            SCOPE="users/${OWNER}"
          fi

          BASE_PATH="/${SCOPE}/packages/container/${PACKAGE}"

          if ! gh api "${BASE_PATH}" >/dev/null 2>&1; then
            echo "Package ghcr.io/${OWNER}/${PACKAGE} not found or accessible. Nothing to clean."
            exit 0
          fi

          mapfile -t VERSION_IDS < <(gh api --paginate "${BASE_PATH}/versions?per_page=100" \
            -q '.[] | select(.metadata.container.tags | length == 0) | .id')

          if [[ ${#VERSION_IDS[@]} -eq 0 ]]; then
            echo "No untagged versions to delete."
            exit 0
          fi

          echo "Deleting ${#VERSION_IDS[@]} untagged versions from ghcr.io/${OWNER}/${PACKAGE}..."
          for id in "${VERSION_IDS[@]}"; do
            gh api -X DELETE "${BASE_PATH}/versions/${id}" >/dev/null
            echo "Deleted version ${id}"
          done

          echo "Cleanup complete."
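The workflow above is plain GitHub REST traffic behind the `gh` CLI, so the same pass can be run from any machine. A minimal Python sketch using httpx — the user-scoped package URL and a `GH_TOKEN` variable holding a token with `packages:write` scope are assumptions, not repo-provided setup:

```python
import os
import httpx

API = "https://api.github.com"
HEADERS = {
    # Assumed: a personal-account token exported as GH_TOKEN.
    "Authorization": f"Bearer {os.environ['GH_TOKEN']}",
    "Accept": "application/vnd.github+json",
}


def delete_untagged(owner: str, package: str) -> None:
    base = f"{API}/users/{owner}/packages/container/{package}"
    with httpx.Client(headers=HEADERS, timeout=30.0) as client:
        # Collect IDs first (like the workflow) so deletion does not
        # disturb pagination.
        ids: list[int] = []
        page = 1
        while True:
            resp = client.get(f"{base}/versions",
                              params={"per_page": 100, "page": page})
            resp.raise_for_status()
            versions = resp.json()
            if not versions:
                break
            # Keep only versions carrying no container tags (dangling images).
            ids += [v["id"] for v in versions
                    if not v["metadata"]["container"]["tags"]]
            page += 1
        for version_id in ids:
            client.delete(f"{base}/versions/{version_id}").raise_for_status()
            print(f"Deleted version {version_id}")


if __name__ == "__main__":
    delete_untagged("drizzletime", "foxel")  # hypothetical owner/package
```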
.github/workflows/docker.yml (vendored, 12 lines changed)

@@ -2,6 +2,8 @@ name: Build and Push Docker image

 on:
   push:
+    branches:
+      - main
     tags:
       - 'v*.*.*'
   workflow_dispatch:

@@ -15,7 +17,7 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

@@ -42,10 +44,10 @@ jobs:
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}

-      - name: Build and push Docker image
-        uses: docker/build-push-action@v5
+      - name: Build and push Docker image (multi arch)
+        uses: docker/build-push-action@v6
         with:
           context: .
-          platforms: linux/amd64
+          platforms: linux/amd64,linux/arm64
           push: true
           tags: ${{ env.DOCKER_TAGS }}
.github/workflows/release-drafter.yml (vendored, 2 lines changed)

@@ -10,7 +10,7 @@ jobs:
       contents: write
       pull-requests: write
     steps:
-      - uses: release-drafter/release-drafter@v5
+      - uses: release-drafter/release-drafter@v6
        with:
          config-name: release-drafter.yml
        env:
.gitignore (vendored, 28 lines changed)

@@ -6,4 +6,30 @@ __pycache__/
 .vscode/
 data/
 migrate/
 .env
+AGENTS.md
+
+# Logs
+/web/logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+/web/node_modules
+/web/dist
+/web/dist-ssr
+/web/*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
.python-version (new file, 1 line)

@@ -0,0 +1 @@
3.14
CONTRIBUTING.md (191 lines changed)

@@ -1,149 +1,162 @@
+<div align="right">
+  <b>English</b> | <a href="./CONTRIBUTING_zh.md">简体中文</a>
+</div>
+
 # Contributing to Foxel

-🎉 首先,非常感谢您愿意花时间为 Foxel 做出贡献!
+We appreciate every minute you spend helping Foxel improve. This guide explains the contribution workflow so you can get started quickly.

-我们热烈欢迎各种形式的贡献。无论是报告 Bug、提出新功能建议、完善文档,还是直接提交代码,都将对项目产生积极的影响。
-
-本指南将帮助您顺利地参与到项目中来。
-
-## 目录
+## Table of Contents

-- [如何贡献](#如何贡献)
-  - [🐛 报告 Bug](#-报告-bug)
-  - [✨ 提交功能建议](#-提交功能建议)
-  - [🛠️ 贡献代码](#️-贡献代码)
-- [开发环境搭建](#开发环境搭建)
-  - [依赖准备](#依赖准备)
-  - [后端 (FastAPI)](#后端-fastapi)
-  - [前端 (React + Vite)](#前端-react--vite)
-- [代码贡献指南](#代码贡献指南)
-  - [贡献存储适配器 (Adapter)](#贡献存储适配器-adapter)
-  - [贡献前端应用 (App)](#贡献前端应用-app)
-- [提交规范](#提交规范)
-  - [Git 分支管理](#git-分支管理)
-  - [Commit Message 格式](#commit-message-格式)
-  - [Pull Request 流程](#pull-request-流程)
+- [How to Contribute](#how-to-contribute)
+  - [🐛 Report Bugs](#-report-bugs)
+  - [✨ Suggest Features](#-suggest-features)
+  - [🛠️ Contribute Code](#️-contribute-code)
+- [Development Environment](#development-environment)
+  - [Prerequisites](#prerequisites)
+  - [Backend (FastAPI)](#backend-fastapi)
+  - [Frontend (React + Vite)](#frontend-react--vite)
+- [Contribution Guidelines](#contribution-guidelines)
+  - [Storage Adapters](#storage-adapters)
+  - [Frontend Apps](#frontend-apps)
+- [Submission Rules](#submission-rules)
+  - [Git Branching](#git-branching)
+  - [Commit Message Format](#commit-message-format)
+  - [Pull Request Flow](#pull-request-flow)

 ---

-## 如何贡献
+## How to Contribute

-### 🐛 报告 Bug
+### 🐛 Report Bugs

-如果您在使用的过程中发现了 Bug,请通过 [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) 来报告。请在报告中提供以下信息:
+If you discover a bug, open a ticket via [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) and include:

-- **清晰的标题**:简明扼要地描述问题。
-- **复现步骤**:详细说明如何一步步重现该 Bug。
-- **期望行为** vs **实际行为**:描述您预期的结果和实际发生的情况。
-- **环境信息**:例如操作系统、浏览器版本、Foxel 版本等。
+- **A clear title** that summarises the problem.
+- **Reproduction steps** with enough detail to trigger the bug.
+- **Expected vs actual behaviour** to highlight the gap.
+- **Environment details** such as operating system, browser version, and the Foxel build you used.

-### ✨ 提交功能建议
+### ✨ Suggest Features

-我们欢迎任何关于新功能或改进的建议。请通过 [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) 创建一个 "Feature Request",并详细阐述您的想法:
+To propose a new capability or an improvement, create an Issue and choose the "Feature Request" template. Document:

-- **问题描述**:说明该功能要解决什么问题。
-- **方案设想**:描述您希望该功能如何工作。
-- **相关信息**:提供任何有助于理解您想法的截图、链接或参考。
+- **Problem statement** – what pain point will the feature solve?
+- **Proposed solution** – how you expect it to work.
+- **Supporting material** – screenshots, references, or related links if helpful.

-### 🛠️ 贡献代码
+### 🛠️ Contribute Code

-如果您希望直接贡献代码,请参考下面的开发和提交流程。
+Follow the development setup below before opening a pull request. Keep changes focused and small so they are easier to review.

-## 开发环境搭建
+## Development Environment

-### 依赖准备
+### Prerequisites

-- **Git**: 用于版本控制。
-- **Python**: >= 3.13
-- **Bun**: 用于前端包管理和脚本运行。
+Install the following tooling first:

-### 后端 (FastAPI)
+- **Git** for version control
+- **Python** 3.13 or newer
+- **Bun** for frontend package management and scripts

-后端 API 服务基于 Python 和 FastAPI 构建。
+### Backend (FastAPI)

-1. **克隆仓库**
+1. **Clone the repository**

    ```bash
    git clone https://github.com/DrizzleTime/foxel.git
    cd Foxel
    ```

-2. **创建并激活 Python 虚拟环境**
+2. **Create and activate a virtual environment**
+
+   `uv` is recommended for performance and reproducibility:

    ```bash
-   python3 -m venv .venv
+   uv venv
    source .venv/bin/activate
    # On Windows: .venv\Scripts\activate
    ```

-3. **安装依赖**
+3. **Install dependencies**

    ```bash
-   pip install -r requirements.txt
+   uv sync
    ```

-4. **启动开发服务器**
+4. **Prepare local resources**
+
+   - Create the data directory:
+
+     ```bash
+     mkdir -p data/db
+     ```
+
+     Ensure the application user can read and write to `data/db`.
+
+   - Create an `.env` file in the project root and provide the required secrets. Replace the sample values with your own random strings:
+
+     ```dotenv
+     SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
+     TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
+     ```
+
+5. **Start the development server**

    ```bash
    uvicorn main:app --reload --host 0.0.0.0 --port 8000
    ```

-   API 服务将在 `http://localhost:8000` 上运行,您可以通过 `http://localhost:8000/docs` 访问自动生成的 API 文档。
+   The API is available at `http://localhost:8000`, and the interactive docs live at `http://localhost:8000/docs`.

-### 前端 (React + Vite)
+### Frontend (React + Vite)

-前端应用使用 React, Vite, 和 TypeScript 构建。
-
-1. **进入前端目录**
+1. **Enter the frontend directory**

    ```bash
    cd web
    ```

-2. **安装依赖**
+2. **Install dependencies**

    ```bash
    bun install
    ```

-3. **启动开发服务器**
+3. **Run the dev server**

    ```bash
    bun run dev
    ```

-   前端开发服务器将在 `http://localhost:5173` 运行。它已经配置了代理,会自动将 `/api` 请求转发到后端服务。
+   The Vite dev server runs at `http://localhost:5173` and proxies `/api` requests to the backend.

-## 代码贡献指南
+## Contribution Guidelines

-### 贡献存储适配器 (Adapter)
+### Storage Adapters

-存储适配器是 Foxel 的核心扩展点,用于接入不同的存储后端 (如 S3, FTP, Alist 等)。
+Storage adapters integrate new storage providers (for example S3, FTP, or Alist).

-1. **创建适配器文件**: 在 [`services/adapters/`](services/adapters/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
-2. **实现适配器类**:
-   - 创建一个类,继承自 [`services.adapters.base.BaseAdapter`](services/adapters/base.py)。
-   - 实现 `BaseAdapter` 中定义的所有抽象方法,如 `list_dir`, `get_meta`, `upload`, `download` 等。请仔细阅读基类中的文档注释以理解每个方法的作用和参数。
+1. Create a new module under [`domain/adapters/providers/`](domain/adapters/providers/) (for example `my_new_adapter.py`).
+2. Implement a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.

-### 贡献前端应用 (App)
+### Frontend Apps

-前端应用允许用户在浏览器中直接预览或编辑特定类型的文件。
+Frontend apps enable in-browser previews or editors for specific file types.

-1. **创建应用组件**: 在 [`web/src/apps/`](web/src/apps/) 目录下,为您的应用创建一个新的文件夹,并在其中创建 React 组件。
-2. **定义应用类型**: 您的应用需要实现 [`web/src/apps/types.ts`](web/src/apps/types.ts) 中定义的 `FoxelApp` 接口。
-3. **注册应用**: 在 [`web/src/apps/registry.ts`](web/src/apps/registry.ts) 中,导入您的应用组件,并将其添加到 `APP_REGISTRY`。在注册时,您需要指定该应用可以处理的文件类型(通过 MIME Type 或文件扩展名)。
+1. Add a new folder in [`web/src/apps/`](web/src/apps/) for your app and expose a React component.
+2. Implement the `FoxelApp` interface defined in [`web/src/apps/types.ts`](web/src/apps/types.ts).
+3. Register the app in [`web/src/apps/registry.ts`](web/src/apps/registry.ts) and declare the MIME types or extensions it supports.

-## 提交规范
+## Submission Rules

-### Git 分支管理
+### Git Branching

-- 从最新的 `main` 分支创建您的特性分支。
+Start your work from the latest `main` branch and push feature changes on a dedicated branch.

-### Commit Message 格式
+### Commit Message Format

-我们遵循 [Conventional Commits](https://www.conventionalcommits.org/) 规范。这有助于自动化生成更新日志和版本管理。
-
-Commit Message 格式如下:
+We follow the [Conventional Commits](https://www.conventionalcommits.org/) specification to drive release tooling.

 ```
 <type>(<scope>): <subject>
@@ -153,27 +166,27 @@ Commit Message 格式如下:
 <footer>
 ```

-- **type**: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore` 等。
-- **scope**: (可选) 本次提交影响的范围,例如 `adapter`, `ui`, `api`。
-- **subject**: 简明扼要的描述。
+- **type**: e.g. `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`.
+- **scope** (optional): the area impacted by the change, such as `adapter`, `ui`, or `api`.
+- **subject**: a concise summary written in the imperative mood.

-**示例:**
+**Examples:**

 ```
-feat(adapter): Add support for Alist storage
+feat(adapter): add support for Alist storage
 ```

 ```
-fix(ui): Correct display issue in file list view
+fix(ui): correct display issue in file list view
 ```

-### Pull Request 流程
+### Pull Request Flow

-1. Fork 仓库并克隆到本地。
-2. 创建并切换到您的特性分支。
-3. 完成代码编写和测试。
-4. 将您的分支推送到您的 Fork 仓库。
-5. 在 Foxel 主仓库创建一个 Pull Request,目标分支为 `main`。
-6. 在 PR 描述中清晰地说明您的更改内容、目的和任何相关的 Issue 编号。
+1. Fork the repository and clone it locally.
+2. Create and switch to your feature branch.
+3. Implement the change and run relevant checks.
+4. Push the branch to your fork.
+5. Open a pull request against `main` in the Foxel repository.
+6. Explain the change set, its motivation, and reference related Issues in the PR description.

-项目维护者会尽快审查您的 PR。感谢您的耐心和贡献!
+Maintainers will review your pull request as soon as possible.
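The Storage Adapters section in the diff above names the extension surface but not its shape. A minimal Python sketch of a provider follows; the method names come from the guide (`list_dir`, `get_meta`, `upload`, `download`), while the exact signatures and the `config` constructor argument are assumptions — check `domain/adapters/providers/base.py` for the real contract:

```python
# Hypothetical sketch of a custom adapter; real signatures live in
# domain/adapters/providers/base.py and may differ.
from domain.adapters.providers.base import BaseAdapter


class MyNewAdapter(BaseAdapter):
    def __init__(self, config: dict):
        # config comes from the adapter's config_schema (e.g. a root path).
        self.root = config["root"]

    async def list_dir(self, path: str) -> list[dict]:
        # Return one entry per child: name, size, is_dir, mtime.
        raise NotImplementedError

    async def get_meta(self, path: str) -> dict:
        # Stat a single file or directory.
        raise NotImplementedError

    async def upload(self, path: str, data: bytes) -> None:
        raise NotImplementedError

    async def download(self, path: str) -> bytes:
        raise NotImplementedError
```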
CONTRIBUTING_zh.md (new file, 202 lines)

@@ -0,0 +1,202 @@
<div align="right">
  <a href="./CONTRIBUTING.md">English</a> | <b>简体中文</b>
</div>

# Contributing to Foxel

🎉 首先,非常感谢您愿意花时间为 Foxel 做出贡献!

我们热烈欢迎各种形式的贡献。无论是报告 Bug、提出新功能建议、完善文档,还是直接提交代码,都将对项目产生积极的影响。

本指南将帮助您顺利地参与到项目中来。

## 目录

- [如何贡献](#如何贡献)
  - [🐛 报告 Bug](#-报告-bug)
  - [✨ 提交功能建议](#-提交功能建议)
  - [🛠️ 贡献代码](#️-贡献代码)
- [开发环境搭建](#开发环境搭建)
  - [依赖准备](#依赖准备)
  - [后端 (FastAPI)](#后端-fastapi)
  - [前端 (React + Vite)](#前端-react--vite)
- [代码贡献指南](#代码贡献指南)
  - [贡献存储适配器 (Adapter)](#贡献存储适配器-adapter)
  - [贡献前端应用 (App)](#贡献前端应用-app)
- [提交规范](#提交规范)
  - [Git 分支管理](#git-分支管理)
  - [Commit Message 格式](#commit-message-格式)
  - [Pull Request 流程](#pull-request-流程)

---

## 如何贡献

### 🐛 报告 Bug

如果您在使用的过程中发现了 Bug,请通过 [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) 来报告。请在报告中提供以下信息:

- **清晰的标题**:简明扼要地描述问题。
- **复现步骤**:详细说明如何一步步重现该 Bug。
- **期望行为** vs **实际行为**:描述您预期的结果和实际发生的情况。
- **环境信息**:例如操作系统、浏览器版本、Foxel 版本等。

### ✨ 提交功能建议

我们欢迎任何关于新功能或改进的建议。请通过 [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) 创建一个 "Feature Request",并详细阐述您的想法:

- **问题描述**:说明该功能要解决什么问题。
- **方案设想**:描述您希望该功能如何工作。
- **相关信息**:提供任何有助于理解您想法的截图、链接或参考。

### 🛠️ 贡献代码

如果您希望直接贡献代码,请参考下面的开发和提交流程。

## 开发环境搭建

### 依赖准备

- **Git**: 用于版本控制。
- **Python**: >= 3.13
- **Bun**: 用于前端包管理和脚本运行。

### 后端 (FastAPI)

后端 API 服务基于 Python 和 FastAPI 构建。

1. **克隆仓库**

   ```bash
   git clone https://github.com/DrizzleTime/foxel.git
   cd Foxel
   ```

2. **创建并激活 Python 虚拟环境**

   我们推荐使用 `uv` 来管理虚拟环境,以获得最佳性能。

   ```bash
   uv venv
   source .venv/bin/activate
   # On Windows: .venv\Scripts\activate
   ```

3. **安装依赖**

   ```bash
   uv sync
   ```

4. **初始化环境**

   在启动服务前,请进行以下准备:

   - **创建数据目录**:
     在项目根目录执行 `mkdir -p data/db`。这将创建用于存放数据库等文件的目录。
     > [!IMPORTANT]
     > 请确保应用拥有对 `data/db` 目录的读写权限。

   - **创建 `.env` 配置文件**:
     在项目根目录创建名为 `.env` 的文件,并填入以下内容。这些密钥用于保障应用安全,您可以按需修改。

     ```dotenv
     SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
     TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
     ```

5. **启动开发服务器**

   ```bash
   uvicorn main:app --reload --host 0.0.0.0 --port 8000
   ```

   API 服务将在 `http://localhost:8000` 上运行,您可以通过 `http://localhost:8000/docs` 访问自动生成的 API 文档。

### 前端 (React + Vite)

前端应用使用 React, Vite, 和 TypeScript 构建。

1. **进入前端目录**

   ```bash
   cd web
   ```

2. **安装依赖**

   ```bash
   bun install
   ```

3. **启动开发服务器**

   ```bash
   bun run dev
   ```

   前端开发服务器将在 `http://localhost:5173` 运行。它已经配置了代理,会自动将 `/api` 请求转发到后端服务。

## 代码贡献指南

### 贡献存储适配器 (Adapter)

存储适配器是 Foxel 的核心扩展点,用于接入不同的存储后端 (如 S3, FTP, Alist 等)。

1. **创建适配器文件**: 在 [`domain/adapters/providers/`](domain/adapters/providers/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
2. **实现适配器类**:
   - 创建一个类,继承自 [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py)。
   - 实现 `BaseAdapter` 中定义的所有抽象方法,如 `list_dir`, `get_meta`, `upload`, `download` 等。请仔细阅读基类中的文档注释以理解每个方法的作用和参数。

### 贡献前端应用 (App)

前端应用允许用户在浏览器中直接预览或编辑特定类型的文件。

1. **创建应用组件**: 在 [`web/src/apps/`](web/src/apps/) 目录下,为您的应用创建一个新的文件夹,并在其中创建 React 组件。
2. **定义应用类型**: 您的应用需要实现 [`web/src/apps/types.ts`](web/src/apps/types.ts) 中定义的 `FoxelApp` 接口。
3. **注册应用**: 在 [`web/src/apps/registry.ts`](web/src/apps/registry.ts) 中,导入您的应用组件,并将其添加到 `APP_REGISTRY`。在注册时,您需要指定该应用可以处理的文件类型(通过 MIME Type 或文件扩展名)。

## 提交规范

### Git 分支管理

- 从最新的 `main` 分支创建您的特性分支。

### Commit Message 格式

我们遵循 [Conventional Commits](https://www.conventionalcommits.org/) 规范。这有助于自动化生成更新日志和版本管理。

Commit Message 格式如下:

```
<type>(<scope>): <subject>
<BLANK LINE>
<body>
<BLANK LINE>
<footer>
```

- **type**: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore` 等。
- **scope**: (可选) 本次提交影响的范围,例如 `adapter`, `ui`, `api`。
- **subject**: 简明扼要的描述。

**示例:**

```
feat(adapter): Add support for Alist storage
```

```
fix(ui): Correct display issue in file list view
```

### Pull Request 流程

1. Fork 仓库并克隆到本地。
2. 创建并切换到您的特性分支。
3. 完成代码编写和测试。
4. 将您的分支推送到您的 Fork 仓库。
5. 在 Foxel 主仓库创建一个 Pull Request,目标分支为 `main`。
6. 在 PR 描述中清晰地说明您的更改内容、目的和任何相关的 Issue 编号。

项目维护者会尽快审查您的 PR。感谢您的耐心和贡献!
Dockerfile (27 lines changed)

@@ -9,26 +9,37 @@ COPY web/ ./

 RUN bun run build

-FROM python:3.13-slim
+FROM python:3.14-slim

 WORKDIR /app

-RUN apt-get update && apt-get install -y nginx git && rm -rf /var/lib/apt/lists/*
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends ffmpeg curl ca-certificates \
+    && rm -rf /var/lib/apt/lists/*

-COPY requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt && pip install gunicorn
+RUN pip install uv
+COPY pyproject.toml uv.lock ./
+RUN uv pip install --system . gunicorn \
+    && rm -rf /root/.cache

-RUN git clone https://github.com/DrizzleTime/FoxelUpgrade /app/migrate
+RUN curl -L https://github.com/DrizzleTime/FoxelUpgrade/archive/refs/heads/main.tar.gz -o /tmp/migrate.tgz \
+    && mkdir -p /app/migrate \
+    && tar -xzf /tmp/migrate.tgz --strip-components=1 -C /app/migrate \
+    && rm -rf /tmp/migrate.tgz

 COPY --from=frontend-builder /app/web/dist /app/web/dist

 COPY . .

 COPY nginx.conf /etc/nginx/nginx.conf

 RUN mkdir -p data/db data/mount && \
     chmod 777 data/db data/mount && \
     chmod +x setup/foxel_cli.py && \
     ln -sf /app/setup/foxel_cli.py /usr/local/bin/foxel && \
     rm -rf /var/log/apt /var/cache/apt/archives

 EXPOSE 80

 COPY entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh

 CMD ["/entrypoint.sh"]
README.md (64 lines changed)

@@ -1,8 +1,12 @@
+<div align="right">
+  <b>English</b> | <a href="./README_zh.md">简体中文</a>
+</div>
+
 <div align="center">

 # Foxel

-**一个面向个人和团队的、高度可扩展的私有云盘解决方案,支持 AI 语义搜索。**
+**A highly extensible private cloud storage solution for individuals and teams, featuring AI-powered semantic search.**

 [badges]

@@ -11,32 +15,32 @@

 ---
 <blockquote>
   <em><strong>数据之洋浩瀚无涯,当以洞察之目引航,然其脉络深隐,非表象所能尽窥。</strong></em><br>
   <em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
 </blockquote>
 <img src="https://foxel.cc/image/ad-min.png" alt="UI Screenshot">
 </div>

-## 👀 在线体验
+## 👀 Online Demo

 > [https://demo.foxel.cc](https://demo.foxel.cc)
 >
-> 账号/密码:`admin` / `admin`
+> Account/Password: `admin` / `admin`

-## ✨ 核心功能
+## ✨ Core Features

-- **统一文件管理**:集中管理分布于不同存储后端的文件。
-- **插件化存储后端**:采用可扩展的适配器模式,方便集成多种存储类型。
-- **语义搜索**:支持自然语言描述搜索图片、文档等非结构化数据内容。
-- **内置文件预览**:可直接预览图片、视频、PDF、Office 文档及文本、代码文件,无需下载。
-- **权限与分享**:支持公开或私密分享链接,便于文件共享。
-- **任务处理中心**:支持异步任务处理,如文件索引和数据备份,不影响主应用运行。
+- **Unified File Management**: Centralize management of files distributed across different storage backends.
+- **Pluggable Storage Backends**: Utilizes an extensible adapter pattern to easily integrate various storage types.
+- **Semantic Search**: Supports natural language search for content within unstructured data like images and documents.
+- **Built-in File Preview**: Preview images, videos, PDFs, Office documents, text, and code files directly without downloading.
+- **Permissions and Sharing**: Supports public or private sharing links for easy file distribution.
+- **Task Processing Center**: Supports asynchronous task processing, such as file indexing and data backups, without impacting the main application.

-## 🚀 快速开始
+## 🚀 Quick Start

-使用 Docker Compose 是启动 Foxel 最推荐的方式。
+Using Docker Compose is the recommended way to start Foxel.

-1. **创建数据目录**:
-   新建 `data` 文件夹用于持久化数据:
+1. **Create Data Directories**:
+   Create a `data` folder for persistent data:

 ```bash
 mkdir -p data/db
@@ -44,40 +48,40 @@ mkdir -p data/mount
 chmod 777 data/db data/mount
 ```

-2. **下载 Docker Compose 文件**:
+2. **Download the Docker Compose File**:

 ```bash
 curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
 ```

-   下载完成后,**强烈建议**修改 `compose.yaml` 文件中的环境变量以确保安全:
+   After downloading, it is **strongly recommended** to modify the environment variables in the `compose.yaml` file to ensure security:

-   - 修改 `SECRET_KEY` 和 `TEMP_LINK_SECRET_KEY`:将默认的密钥替换为随机生成的强密钥
+   - Modify `SECRET_KEY` and `TEMP_LINK_SECRET_KEY`: Replace the default keys with randomly generated strong keys.

-3. **启动服务**:
+3. **Start the Services**:

 ```bash
 docker-compose up -d
 ```

-4. **访问应用**:
+4. **Access the Application**:

-   服务启动后,在浏览器中打开页面。
+   Once the services are running, open the page in your browser.

-> 首次启动,请根据引导页面完成管理员账号的初始化设置。
+> On first launch, follow the setup guide to initialize the administrator account.

-## 🤝 如何贡献
+## 🤝 How to Contribute

-我们非常欢迎来自社区的贡献!无论是提交 Bug、建议新功能还是直接贡献代码。
+We welcome contributions from the community! Whether it's submitting bugs, suggesting new features, or contributing code directly.

-在开始之前,请先阅读我们的 [`CONTRIBUTING.md`](CONTRIBUTING.md) 文件,它会指导你如何设置开发环境以及提交流程。
+Before you start, please read our [`CONTRIBUTING.md`](CONTRIBUTING.md) file, which explains the development environment and submission process. A Simplified Chinese translation is available in [`CONTRIBUTING_zh.md`](CONTRIBUTING_zh.md).

-## 🌐 社区
+## 🌐 Community

-加入我们的交流社区:[Telegram 群组](https://t.me/+thDsBfyqJxZkNTU1),与开发者和用户一起讨论!
+Join our community on [Telegram](https://t.me/+thDsBfyqJxZkNTU1) to discuss with developers and other users!

-你也可以加入我们的微信群,获取更多实时交流与支持。请扫描下方二维码加入:
+You can also join our WeChat group for more real-time communication and support. Please scan the QR code below to join:

-<img src="https://foxel.cc/image/wechat.png" alt="微信群二维码" width="180">
+<img src="https://foxel.cc/image/wechat.png" alt="WeChat Group QR Code" width="180">

-> 如果二维码失效,请添加微信号 **drizzle2001**,我们会邀请你加入群聊。
+> If the QR code is invalid, please add WeChat ID **drizzle2001**, and we will invite you to the group.
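Both READMEs tell you to replace the sample `SECRET_KEY` / `TEMP_LINK_SECRET_KEY` values with strong random keys but do not show how to mint one. The samples look like 32 random bytes in base64, so a Python one-liner produces a value of the same shape (the 32-byte length is inferred from the samples, not documented):

```python
import base64
import os

# 32 random bytes, base64-encoded -- same shape as the sample keys above.
print(base64.b64encode(os.urandom(32)).decode())
```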
README_zh.md (new file, 88 lines)

@@ -0,0 +1,88 @@
<div align="right">
  <a href="./README.md">English</a> | <b>简体中文</b>
</div>

<div align="center">

# Foxel

**一个面向个人和团队的、高度可扩展的私有云盘解决方案,支持 AI 语义搜索。**

[badges]

---
<blockquote>
  <em><strong>数据之洋浩瀚无涯,当以洞察之目引航,然其脉络深隐,非表象所能尽窥。</strong></em><br>
  <em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
</blockquote>
<img src="https://foxel.cc/image/ad-min.png" alt="UI Screenshot">
</div>

## 👀 在线体验

> [https://demo.foxel.cc](https://demo.foxel.cc)
>
> 账号/密码:`admin` / `admin`

## ✨ 核心功能

- **统一文件管理**:集中管理分布于不同存储后端的文件。
- **插件化存储后端**:采用可扩展的适配器模式,方便集成多种存储类型。
- **语义搜索**:支持自然语言描述搜索图片、文档等非结构化数据内容。
- **内置文件预览**:可直接预览图片、视频、PDF、Office 文档及文本、代码文件,无需下载。
- **权限与分享**:支持公开或私密分享链接,便于文件共享。
- **任务处理中心**:支持异步任务处理,如文件索引和数据备份,不影响主应用运行。

## 🚀 快速开始

使用 Docker Compose 是启动 Foxel 最推荐的方式。

1. **创建数据目录**:
   新建 `data` 文件夹用于持久化数据:

   ```bash
   mkdir -p data/db
   mkdir -p data/mount
   chmod 777 data/db data/mount
   ```

2. **下载 Docker Compose 文件**:

   ```bash
   curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
   ```

   下载完成后,**强烈建议**修改 `compose.yaml` 文件中的环境变量以确保安全:

   - 修改 `SECRET_KEY` 和 `TEMP_LINK_SECRET_KEY`:将默认的密钥替换为随机生成的强密钥

3. **启动服务**:

   ```bash
   docker-compose up -d
   ```

4. **访问应用**:

   服务启动后,在浏览器中打开页面。

   > 首次启动,请根据引导页面完成管理员账号的初始化设置。

## 🤝 如何贡献

我们非常欢迎来自社区的贡献!无论是提交 Bug、建议新功能还是直接贡献代码。

在开始之前,请先阅读我们的 [`CONTRIBUTING_zh.md`](CONTRIBUTING_zh.md) 文件,它会指导你如何设置开发环境以及提交流程。

## 🌐 社区

加入我们的交流社区:[Telegram 群组](https://t.me/+thDsBfyqJxZkNTU1),与开发者和用户一起讨论!

你也可以加入我们的微信群,获取更多实时交流与支持。请扫描下方二维码加入:

<img src="https://foxel.cc/image/wechat.png" alt="微信群二维码" width="180">

> 如果二维码失效,请添加微信号 **drizzle2001**,我们会邀请你加入群聊。
@@ -1,17 +1,38 @@
 from fastapi import FastAPI

-from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search
+from domain.adapters import api as adapters
+from domain.auth import api as auth
+from domain.backup import api as backup
+from domain.config import api as config
+from domain.email import api as email
+from domain.offline_downloads import api as offline_downloads
+from domain.plugins import api as plugins
+from domain.processors import api as processors
+from domain.share import api as share
+from domain.tasks import api as tasks
+from domain.ai import api as ai
+from domain.virtual_fs import api as virtual_fs
+from domain.virtual_fs.mapping import s3_api, webdav_api
+from domain.virtual_fs.search import search_api
+from domain.audit import router as audit


 def include_routers(app: FastAPI):
     app.include_router(adapters.router)
     app.include_router(virtual_fs.router)
-    app.include_router(search.router)
+    app.include_router(search_api.router)
     app.include_router(auth.router)
     app.include_router(config.router)
     app.include_router(processors.router)
     app.include_router(tasks.router)
-    app.include_router(logs.router)
     app.include_router(share.router)
     app.include_router(share.public_router)
     app.include_router(backup.router)
+    app.include_router(ai.router_vector_db)
+    app.include_router(ai.router_ai)
+    app.include_router(plugins.router)
+    app.include_router(webdav_api.router)
+    app.include_router(s3_api.router)
+    app.include_router(offline_downloads.router)
+    app.include_router(email.router)
+    app.include_router(audit)
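The new wiring assumes every domain package exposes an `api` module with a FastAPI `router` attribute (and occasionally extra routers, such as `share.public_router` or `ai.router_ai`). A minimal sketch of what such a module presumably looks like — the prefix and handler are illustrative, not taken from the repo:

```python
# Hypothetical domain/example/api.py following the pattern above.
from fastapi import APIRouter

router = APIRouter(prefix="/api/example", tags=["example"])


@router.get("")
async def list_examples():
    # Each domain keeps its own routes; include_routers() mounts them all.
    return {"items": []}
```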
@@ -1,149 +0,0 @@
from fastapi import APIRouter, HTTPException, Depends
from tortoise.transactions import in_transaction
from typing import Annotated

from models import StorageAdapter
from schemas import AdapterCreate, AdapterOut
from services.auth import get_current_active_user, User
from services.adapters.registry import runtime_registry, get_config_schemas
from api.response import success
from services.logging import LogService

router = APIRouter(prefix="/api/adapters", tags=["adapters"])


def validate_and_normalize_config(adapter_type: str, cfg):
    schemas = get_config_schemas()
    if not isinstance(cfg, dict):
        raise HTTPException(400, detail="config 必须是对象")
    schema = schemas.get(adapter_type)
    if not schema:
        raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
    out = {}
    missing = []
    for f in schema:
        k = f["key"]
        if k in cfg and cfg[k] not in (None, ""):
            out[k] = cfg[k]
        elif "default" in f:
            out[k] = f["default"]
        elif f.get("required"):
            missing.append(k)
    if missing:
        raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
    return out


@router.post("")
async def create_adapter(
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    norm_path = AdapterCreate.normalize_mount_path(data.path)
    exists = await StorageAdapter.get_or_none(path=norm_path)
    if exists:
        raise HTTPException(400, detail="Mount path already exists")

    adapter_fields = {
        "name": data.name,
        "type": data.type,
        "config": validate_and_normalize_config(data.type, data.config or {}),
        "enabled": data.enabled,
        "path": norm_path,
        "sub_path": data.sub_path,
    }

    rec = await StorageAdapter.create(**adapter_fields)
    await runtime_registry.refresh()
    await LogService.action(
        "route:adapters",
        f"Created adapter {rec.name}",
        details=adapter_fields,
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success(rec)


@router.get("")
async def list_adapters(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    adapters = await StorageAdapter.all()
    out = [AdapterOut.model_validate(a) for a in adapters]
    return success(out)


@router.get("/available")
async def available_adapter_types(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    data = []
    for t, fields in get_config_schemas().items():
        data.append({
            "type": t,
            "name": "本地文件系统" if t == "local" else ("WebDAV" if t == "webdav" else t),
            "config_schema": fields,
        })
    return success(data)


@router.get("/{adapter_id}")
async def get_adapter(
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    rec = await StorageAdapter.get_or_none(id=adapter_id)
    if not rec:
        raise HTTPException(404, detail="Not found")
    return success(AdapterOut.model_validate(rec))


@router.put("/{adapter_id}")
async def update_adapter(
    adapter_id: int,
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    rec = await StorageAdapter.get_or_none(id=adapter_id)
    if not rec:
        raise HTTPException(404, detail="Not found")

    norm_path = AdapterCreate.normalize_mount_path(data.path)
    existing = await StorageAdapter.get_or_none(path=norm_path)
    if existing and existing.id != adapter_id:
        raise HTTPException(400, detail="Mount path already exists")

    rec.name = data.name
    rec.type = data.type
    rec.config = validate_and_normalize_config(data.type, data.config or {})
    rec.enabled = data.enabled
    rec.path = norm_path
    rec.sub_path = data.sub_path
    await rec.save()

    await runtime_registry.refresh()
    await LogService.action(
        "route:adapters",
        f"Updated adapter {rec.name}",
        details=data.model_dump(),
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success(rec)


@router.delete("/{adapter_id}")
async def delete_adapter(
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    deleted = await StorageAdapter.filter(id=adapter_id).delete()
    if not deleted:
        raise HTTPException(404, detail="Not found")
    await runtime_registry.refresh()
    await LogService.action(
        "route:adapters",
        f"Deleted adapter {adapter_id}",
        details={"adapter_id": adapter_id},
        user_id=current_user.id if hasattr(current_user, "id") else None,
    )
    return success({"deleted": True})
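In the deleted routes above, `validate_and_normalize_config` merges user input with schema defaults and rejects missing required keys. A small worked example of that contract — the schema fields are invented for illustration, but the shape matches what the loop reads (`key`, `default`, `required`):

```python
# Hypothetical schema for a "local" adapter, in the shape the loop expects:
schema = [
    {"key": "root", "required": True},
    {"key": "follow_symlinks", "default": False},
]

# Input {"root": "/srv/files"} normalizes to:
#   {"root": "/srv/files", "follow_symlinks": False}
# Input {} raises HTTPException(400, "缺少必填配置字段: root"),
# because "root" has no default and is required.
```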
@@ -1,53 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, HTTPException, Depends, Form
from fastapi.security import OAuth2PasswordRequestForm
from services.auth import (
    authenticate_user_db,
    create_access_token,
    ACCESS_TOKEN_EXPIRE_MINUTES,
    register_user,
    Token,
)
from pydantic import BaseModel
from datetime import timedelta
from api.response import success

router = APIRouter(prefix="/api/auth", tags=["auth"])


class RegisterRequest(BaseModel):
    username: str
    password: str
    email: str | None = None
    full_name: str | None = None


@router.post("/register", summary="注册第一个管理员用户")
async def register(data: RegisterRequest):
    """
    仅当系统中没有用户时,才允许注册。
    """
    user = await register_user(
        username=data.username,
        password=data.password,
        email=data.email,
        full_name=data.full_name,
    )
    return success({"username": user.username}, msg="初始用户注册成功")


@router.post("/login")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
    user = await authenticate_user_db(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = await create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return Token(access_token=access_token, token_type="bearer")
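The login route consumes `OAuth2PasswordRequestForm`, so clients authenticate with form-encoded credentials rather than JSON and receive a bearer token back. A hedged httpx sketch against the old route shape — the base URL and demo credentials are placeholders:

```python
import httpx

# Hypothetical local instance; /api/auth/login expects form fields, not JSON.
resp = httpx.post(
    "http://localhost:8000/api/auth/login",
    data={"username": "admin", "password": "admin"},
)
resp.raise_for_status()
token = resp.json()["access_token"]

# Subsequent authenticated calls carry the bearer token.
headers = {"Authorization": f"Bearer {token}"}
print(httpx.get("http://localhost:8000/api/adapters", headers=headers).json())
```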
@@ -1,50 +0,0 @@
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from services.auth import get_current_active_user
from services.backup import BackupService
from models.database import UserAccount
import json
import datetime

router = APIRouter(
    prefix="/api/backup",
    tags=["Backup & Restore"],
    dependencies=[Depends(get_current_active_user)],
)


@router.get("/export", summary="导出全站数据")
async def export_backup():
    """
    生成并下载一个包含所有关键数据的JSON文件。
    """
    try:
        data = await BackupService.export_data()
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        headers = {
            "Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"
        }
        return JSONResponse(content=data, headers=headers)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/import", summary="导入数据")
async def import_backup(file: UploadFile = File(...)):
    """
    从上传的JSON文件恢复数据。
    **警告**: 这将会覆盖所有现有数据!
    """
    if not file.filename.endswith(".json"):
        raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")

    try:
        contents = await file.read()
        data = json.loads(contents)
    except Exception:
        raise HTTPException(status_code=400, detail="无法解析JSON文件")

    try:
        await BackupService.import_data(data)
        return {"message": "数据导入成功。"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"导入失败: {e}")
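Import expects a multipart upload of the JSON that export produced, under the field name `file`. A sketch of the round trip against the old routes (URL and token are placeholders; note the docstring's warning that import overwrites existing data):

```python
import httpx

headers = {"Authorization": "Bearer <token>"}  # placeholder token
base = "http://localhost:8000/api/backup"

# Export: the response body is the JSON attachment; save it to disk.
backup = httpx.get(f"{base}/export", headers=headers)
backup.raise_for_status()
with open("foxel_backup.json", "wb") as out:
    out.write(backup.content)

# Import: multipart upload of the same file. This overwrites existing data!
with open("foxel_backup.json", "rb") as fh:
    resp = httpx.post(
        f"{base}/import",
        headers=headers,
        files={"file": ("foxel_backup.json", fh, "application/json")},
    )
print(resp.json())
```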
@@ -1,78 +0,0 @@
import httpx
import time
from fastapi import APIRouter, Depends, Form
from typing import Annotated
from services.config import ConfigCenter, VERSION
from services.auth import get_current_active_user, User, has_users
from api.response import success

router = APIRouter(prefix="/api/config", tags=["config"])


@router.get("/")
async def get_config(
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str
):
    value = await ConfigCenter.get(key)
    return success({"key": key, "value": value})


@router.post("/")
async def set_config(
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str = Form(...),
    value: str = Form(...)
):
    await ConfigCenter.set(key, value)
    return success({"key": key, "value": value})


@router.get("/all")
async def get_all_config(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    configs = await ConfigCenter.get_all()
    return success(configs)


@router.get("/status")
async def get_system_status():
    system_info = {
        "version": VERSION,
        "title": await ConfigCenter.get("APP_NAME", "Foxel"),
        "logo": await ConfigCenter.get("APP_LOGO", "/logo.svg"),
        "is_initialized": await has_users()
    }
    return success(system_info)


latest_version_cache = {
    "timestamp": 0,
    "data": None
}


@router.get("/latest-version")
async def get_latest_version():
    current_time = time.time()
    if current_time - latest_version_cache["timestamp"] < 3600 and latest_version_cache["data"]:
        return success(latest_version_cache["data"])
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                "https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
                follow_redirects=True,
            )
            resp.raise_for_status()
            data = resp.json()
            version_info = {
                "latest_version": data.get("tag_name"),
                "body": data.get("body")
            }
            latest_version_cache["timestamp"] = current_time
            latest_version_cache["data"] = version_info
            return success(version_info)
    except httpx.RequestError as e:
        if latest_version_cache["data"]:
            return success(latest_version_cache["data"])
        return success({"latest_version": None, "body": None})
@@ -1,48 +0,0 @@
from typing import Optional
from fastapi import APIRouter, Query
from models.database import Log
from api.response import page, success
from tortoise.expressions import Q
from datetime import datetime

router = APIRouter(prefix="/api/logs", tags=["Logs"])


@router.get("")
async def get_logs(
    page_num: int = Query(1, alias="page"),
    page_size: int = Query(20, alias="page_size"),
    level: Optional[str] = Query(None),
    source: Optional[str] = Query(None),
    start_time: Optional[datetime] = Query(None),
    end_time: Optional[datetime] = Query(None),
):
    """获取日志列表,支持分页和筛选"""
    query = Log.all()
    if level:
        query = query.filter(level=level)
    if source:
        query = query.filter(source__icontains=source)
    if start_time:
        query = query.filter(timestamp__gte=start_time)
    if end_time:
        query = query.filter(timestamp__lte=end_time)

    total = await query.count()
    logs = await query.order_by("-timestamp").offset((page_num - 1) * page_size).limit(page_size)

    return success(page([log for log in logs], total, page_num, page_size))


@router.delete("")
async def clear_logs(
    start_time: Optional[datetime] = Query(None),
    end_time: Optional[datetime] = Query(None),
):
    """清理指定时间范围内的日志"""
    query = Log.all()
    if start_time:
        query = query.filter(timestamp__gte=start_time)
    if end_time:
        query = query.filter(timestamp__lte=end_time)

    deleted_count = await query.delete()
    return success({"deleted_count": deleted_count})
@@ -1,44 +0,0 @@
from fastapi import APIRouter, Depends, Body
from typing import Annotated
from services.processors.registry import get_config_schemas
from services.virtual_fs import process_file
from services.auth import get_current_active_user, User
from api.response import success
from pydantic import BaseModel

router = APIRouter(prefix="/api/processors", tags=["processors"])


@router.get("")
async def list_processors(
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    schemas = get_config_schemas()
    out = []
    for t, meta in schemas.items():
        out.append({
            "type": meta["type"],
            "name": meta["name"],
            "supported_exts": meta.get("supported_exts", []),
            "config_schema": meta["config_schema"],
            "produces_file": meta.get("produces_file", False),
        })
    return success(out)


class ProcessRequest(BaseModel):
    path: str
    processor_type: str
    config: dict
    save_to: str | None = None
    overwrite: bool = False


@router.post("/process")
async def process_file_with_processor(
    current_user: Annotated[User, Depends(get_current_active_user)],
    req: ProcessRequest = Body(...)
):
    save_to = req.path if req.overwrite else req.save_to
    result = await process_file(req.path, req.processor_type, req.config, save_to)
    return success(result)
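The process endpoint resolves the output target before running: `overwrite=True` writes the result back to the source `path`, otherwise `save_to` wins. A hedged client sketch — the processor type and config keys are invented for illustration; real values come from `GET /api/processors`:

```python
import httpx

# Hypothetical processor type and config; query GET /api/processors for
# the actual types and their config_schema.
payload = {
    "path": "/local/photos/cat.jpg",
    "processor_type": "image_resize",
    "config": {"width": 640},
    "save_to": "/local/photos/cat_small.jpg",
    "overwrite": False,  # True would write the result back to "path"
}
resp = httpx.post(
    "http://localhost:8000/api/processors/process",
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # placeholder token
)
print(resp.json())
```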
@@ -1,41 +0,0 @@
from fastapi import APIRouter, Depends, Query
from schemas.fs import SearchResultItem
from services.auth import get_current_active_user, User
from services.ai import get_text_embedding
from services.vector_db import VectorDBService

router = APIRouter(prefix="/api/search", tags=["search"])


async def search_files_by_vector(q: str, top_k: int):
    embedding = await get_text_embedding(q)
    vector_db = VectorDBService()
    results = vector_db.search_vectors("vector_collection", embedding, top_k)
    items = [
        SearchResultItem(id=res["id"], path=res["entity"]["path"], score=res["distance"])
        for res in results[0]
    ]
    return {"items": items, "query": q}


async def search_files_by_name(q: str, top_k: int):
    vector_db = VectorDBService()
    results = vector_db.search_by_path("vector_collection", q, top_k)
    items = [
        SearchResultItem(id=idx, path=res["entity"]["path"], score=res["distance"])
        for idx, res in enumerate(results[0])
    ]
    return {"items": items, "query": q}


@router.get("")
async def search_files(
    q: str = Query(..., description="Search query"),
    top_k: int = Query(10, description="Number of results to return"),
    mode: str = Query("vector", description="Search mode: 'vector' or 'filename'"),
    user: User = Depends(get_current_active_user),
):
    if mode == "vector":
        return await search_files_by_vector(q, top_k)
    elif mode == "filename":
        return await search_files_by_name(q, top_k)
    else:
        return {"items": [], "query": q, "error": "Invalid search mode"}
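A quick sketch of exercising both search modes above (host and auth are assumptions):

# Hedged sketch: vector search vs. filename search against /api/search.
import httpx

vector_hits = httpx.get(
    "http://127.0.0.1:8000/api/search",
    params={"q": "sunset over the sea", "top_k": 5, "mode": "vector"},
).json()

name_hits = httpx.get(
    "http://127.0.0.1:8000/api/search",
    params={"q": "report.pdf", "mode": "filename"},
).json()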
@@ -1,205 +0,0 @@
from typing import List, Optional
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel

from api.response import success
from services.auth import User, get_current_active_user
from services.share import share_service
from services.virtual_fs import stream_file, stat_file
from models.database import ShareLink, UserAccount

public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])

class ShareCreate(BaseModel):
    name: str
    paths: List[str]
    expires_in_days: Optional[int] = 7
    access_type: str = "public"
    password: Optional[str] = None


class ShareInfo(BaseModel):
    id: int
    token: str
    name: str
    paths: List[str]
    created_at: str
    expires_at: Optional[str] = None
    access_type: str

    @classmethod
    def from_orm(cls, obj: ShareLink):
        return cls(
            id=obj.id,
            token=obj.token,
            name=obj.name,
            paths=obj.paths,
            created_at=obj.created_at.isoformat(),
            expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
            access_type=obj.access_type,
        )


class ShareInfoWithPassword(ShareInfo):
    password: Optional[str] = None


# --- Management Routes ---

@router.post("", response_model=ShareInfoWithPassword)
async def create_share(
    payload: ShareCreate,
    current_user: User = Depends(get_current_active_user),
):
    """
    Create a new share link.
    """
    user_account = await UserAccount.get(id=current_user.id)
    share = await share_service.create_share_link(
        user=user_account,
        name=payload.name,
        paths=payload.paths,
        expires_in_days=payload.expires_in_days,
        access_type=payload.access_type,
        password=payload.password,
    )
    share_info_base = ShareInfo.from_orm(share)
    response_data = share_info_base.model_dump()
    if payload.access_type == "password" and payload.password:
        response_data['password'] = payload.password

    return response_data
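An illustrative ShareCreate payload for the route above (all values are placeholders); note the plaintext password is echoed back only for password-protected shares:

# Hedged sketch: creating a password-protected share.
import httpx

payload = {
    "name": "Holiday photos",
    "paths": ["/photos/2024-summer"],
    "expires_in_days": 7,
    "access_type": "password",
    "password": "s3cret",
}
share = httpx.post("http://127.0.0.1:8000/api/shares", json=payload).json()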
@router.get("", response_model=List[ShareInfo])
|
||||
async def get_my_shares(current_user: User = Depends(get_current_active_user)):
|
||||
"""
|
||||
获取当前用户的所有分享链接。
|
||||
"""
|
||||
user_account = await UserAccount.get(id=current_user.id)
|
||||
shares = await share_service.get_user_shares(user=user_account)
|
||||
return [ShareInfo.from_orm(s) for s in shares]
|
||||
|
||||
|
||||
@router.delete("/{share_id}")
|
||||
async def delete_share(
|
||||
share_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""
|
||||
删除一个分享链接。
|
||||
"""
|
||||
await share_service.delete_share_link(user=current_user, share_id=share_id)
|
||||
return success(msg="分享已取消")
|
||||
|
||||
|
||||
# --- Public Routes ---
|
||||
|
||||
class SharePassword(BaseModel):
|
||||
password: str
|
||||
|
||||
@public_router.post("/{token}/verify")
|
||||
async def verify_password(token: str, payload: SharePassword):
|
||||
"""
|
||||
验证分享链接的密码。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
if share.access_type != "password":
|
||||
raise HTTPException(status_code=400, detail="此分享不需要密码")
|
||||
|
||||
if not share_service._verify_password(payload.password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
# 在这里可以考虑返回一个有时效性的token用于后续访问,但为了简单起见,
|
||||
# 我们让前端在每次请求时都带上密码或一个会话标识。
|
||||
# 简单起见,我们只返回成功状态。
|
||||
return success(msg="验证成功")
|
||||
|
||||
|
||||
@public_router.get("/{token}/ls")
|
||||
async def list_share_content(token: str, path: str = "/", password: Optional[str] = None):
|
||||
"""
|
||||
列出分享链接中的文件和目录。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
|
||||
if share.access_type == "password":
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share_service._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
|
||||
content = await share_service.get_shared_item_details(share, path)
|
||||
return success({
|
||||
"path": path,
|
||||
"entries": content.get("items", []),
|
||||
"pagination": {
|
||||
"total": content.get("total", 0),
|
||||
"page": content.get("page", 1),
|
||||
"page_size": content.get("page_size", 1),
|
||||
"pages": content.get("pages", 1),
|
||||
}
|
||||
})
|
||||
|
||||
@public_router.get("/{token}")
|
||||
async def get_share_info(token: str):
|
||||
"""
|
||||
获取分享链接的元数据信息。
|
||||
"""
|
||||
share = await share_service.get_share_by_token(token)
|
||||
return success(ShareInfo.from_orm(share))
|
||||
|
||||
|
||||
|
||||
@public_router.get("/{token}/download")
|
||||
async def download_shared_file(token: str, path: str, request: Request, password: Optional[str] = None):
|
||||
"""
|
||||
下载分享链接中的单个文件。
|
||||
"""
|
||||
if not path or path == "/" or ".." in path.split('/'):
|
||||
raise HTTPException(status_code=400, detail="无效的文件路径")
|
||||
|
||||
share = await share_service.get_share_by_token(token)
|
||||
if share.access_type == "password":
|
||||
if not password:
|
||||
raise HTTPException(status_code=401, detail="需要密码")
|
||||
if not share_service._verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=403, detail="密码错误")
|
||||
base_shared_path = share.paths[0]
|
||||
|
||||
# 判断分享的是文件还是目录
|
||||
is_dir = False
|
||||
try:
|
||||
stat = await stat_file(base_shared_path)
|
||||
if stat and stat.get("is_dir"):
|
||||
is_dir = True
|
||||
except HTTPException as e:
|
||||
if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
|
||||
is_dir = True
|
||||
else:
|
||||
# The shared path itself doesn't exist, which is an issue.
|
||||
raise HTTPException(status_code=404, detail="分享的源文件不存在")
|
||||
|
||||
if is_dir:
|
||||
# 目录分享:拼接路径
|
||||
full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
|
||||
if not full_virtual_path.startswith(base_shared_path):
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
else:
|
||||
# 文件分享:路径应为分享的根路径
|
||||
shared_filename = base_shared_path.split('/')[-1]
|
||||
request_filename = path.lstrip('/')
|
||||
if shared_filename != request_filename:
|
||||
raise HTTPException(status_code=403, detail="无权访问此路径")
|
||||
full_virtual_path = base_shared_path
|
||||
|
||||
range_header = request.headers.get("Range")
|
||||
response = await stream_file(full_virtual_path, range_header)
|
||||
|
||||
# 设置 Content-Disposition 头来强制下载
|
||||
filename = full_virtual_path.split('/')[-1]
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
return response
|
||||
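The public consumption flow for the routes above, as a hedged sketch (token, password, and file name are placeholders):

# Hedged sketch: inspect a share, verify its password, list it, download a file.
import httpx

base = "http://127.0.0.1:8000/api/s"
token = "abc123"  # placeholder share token

info = httpx.get(f"{base}/{token}").json()
httpx.post(f"{base}/{token}/verify", json={"password": "s3cret"})
entries = httpx.get(f"{base}/{token}/ls",
                    params={"path": "/", "password": "s3cret"}).json()
data = httpx.get(f"{base}/{token}/download",
                 params={"path": "cat.jpg", "password": "s3cret"})
open("cat.jpg", "wb").write(data.content)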
@@ -1,84 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated

from models.database import AutomationTask
from schemas.tasks import AutomationTaskCreate, AutomationTaskUpdate
from api.response import success
from services.auth import get_current_active_user, User
from services.logging import LogService

router = APIRouter(
    prefix="/api/tasks",
    tags=["Tasks"],
    dependencies=[Depends(get_current_active_user)],
    responses={404: {"description": "Not found"}},
)


@router.post("/")
async def create_task(
    task_in: AutomationTaskCreate,
    user: User = Depends(get_current_active_user)
):
    task = await AutomationTask.create(**task_in.model_dump())
    await LogService.action(
        "route:tasks",
        f"Created task {task.name}",
        details=task_in.model_dump(),
        user_id=user.id if hasattr(user, "id") else None,
    )
    return success(task)


@router.get("/{task_id}")
async def get_task(task_id: int):
    task = await AutomationTask.get_or_none(id=task_id)
    if not task:
        raise HTTPException(
            status_code=404, detail=f"Task {task_id} not found")
    return success(task)


@router.get("/")
async def list_tasks():
    tasks = await AutomationTask.all()
    return success(tasks)


@router.put("/{task_id}")
async def update_task(
    current_user: Annotated[User, Depends(get_current_active_user)],
    task_id: int, task_in: AutomationTaskUpdate):
    task = await AutomationTask.get_or_none(id=task_id)
    if not task:
        raise HTTPException(
            status_code=404, detail=f"Task {task_id} not found")
    update_data = task_in.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(task, key, value)
    await task.save()
    await LogService.action(
        "route:tasks",
        f"Updated task {task.name}",
        details=task_in.model_dump(),
        user_id=current_user.id,
    )
    return success(task)


@router.delete("/{task_id}")
async def delete_task(
    task_id: int,
    user: User = Depends(get_current_active_user)
):
    deleted_count = await AutomationTask.filter(id=task_id).delete()
    if not deleted_count:
        raise HTTPException(
            status_code=404, detail=f"Task {task_id} not found")
    await LogService.action(
        "route:tasks",
        f"Deleted task {task_id}",
        details={"task_id": task_id},
        user_id=user.id if hasattr(user, "id") else None,
    )
    return success(msg="Task deleted")
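update_task above relies on model_dump(exclude_unset=True) so a partial body only touches the fields the client actually sent; a standalone illustration:

# Standalone pydantic demo of exclude_unset (not Foxel code).
from pydantic import BaseModel

class TaskPatch(BaseModel):
    name: str | None = None
    enabled: bool | None = None

patch = TaskPatch(name="nightly-sync")
print(patch.model_dump())                    # {'name': 'nightly-sync', 'enabled': None}
print(patch.model_dump(exclude_unset=True))  # {'name': 'nightly-sync'}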
@@ -1,334 +0,0 @@
from fastapi import APIRouter, UploadFile, File, HTTPException, Response, Query, Request, Depends
import mimetypes
import re
from typing import Annotated

from services.auth import get_current_active_user, User
from services.virtual_fs import (
    list_virtual_dir,
    read_file,
    write_file,
    make_dir,
    delete_path,
    move_path,
    resolve_adapter_and_rel,
    stream_file,
    generate_temp_link_token,
    verify_temp_link_token,
)
from services.thumbnail import is_image_filename, get_or_create_thumb, is_raw_filename
from schemas import MkdirRequest, MoveRequest
from api.response import success

router = APIRouter(prefix='/api/fs', tags=["virtual-fs"])
@router.get("/file/{full_path:path}")
|
||||
async def get_file(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
|
||||
if is_raw_filename(full_path):
|
||||
import rawpy
|
||||
from PIL import Image
|
||||
import io
|
||||
try:
|
||||
raw_data = await read_file(full_path)
|
||||
with rawpy.imread(io.BytesIO(raw_data)) as raw:
|
||||
rgb = raw.postprocess(use_camera_wb=True, output_bps=8)
|
||||
im = Image.fromarray(rgb)
|
||||
buf = io.BytesIO()
|
||||
im.save(buf, 'JPEG', quality=90)
|
||||
content = buf.getvalue()
|
||||
return Response(content=content, media_type='image/jpeg')
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"RAW file processing failed: {e}")
|
||||
|
||||
try:
|
||||
content = await read_file(full_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
|
||||
if not isinstance(content, (bytes, bytearray)):
|
||||
return Response(content=content, media_type="application/octet-stream")
|
||||
|
||||
content_length = len(content)
|
||||
content_type = mimetypes.guess_type(
|
||||
full_path)[0] or "application/octet-stream"
|
||||
|
||||
range_header = request.headers.get('Range')
|
||||
if range_header:
|
||||
range_match = re.match(r'bytes=(\d+)-(\d*)', range_header)
|
||||
if range_match:
|
||||
start = int(range_match.group(1))
|
||||
end = int(range_match.group(2)) if range_match.group(
|
||||
2) else content_length - 1
|
||||
|
||||
start = max(0, min(start, content_length - 1))
|
||||
end = max(start, min(end, content_length - 1))
|
||||
|
||||
chunk = content[start:end + 1]
|
||||
chunk_size = len(chunk)
|
||||
|
||||
headers = {
|
||||
'Content-Range': f'bytes {start}-{end}/{content_length}',
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Content-Length': str(chunk_size),
|
||||
'Content-Type': content_type,
|
||||
}
|
||||
|
||||
return Response(
|
||||
content=chunk,
|
||||
status_code=206,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
headers = {
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Content-Length': str(content_length),
|
||||
'Content-Type': content_type,
|
||||
}
|
||||
|
||||
if content_type.startswith('video/'):
|
||||
headers['Cache-Control'] = 'public, max-age=3600'
|
||||
|
||||
return Response(content=content, headers=headers)
|
||||
|
||||
|
||||
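The Range branch above clamps the requested window into [0, content_length - 1] before slicing; the same arithmetic in isolation:

# Standalone check of the byte-range clamping used by get_file above.
import re

def clamp_range(header: str, content_length: int) -> tuple[int, int]:
    m = re.match(r'bytes=(\d+)-(\d*)', header)
    start = int(m.group(1))
    end = int(m.group(2)) if m.group(2) else content_length - 1
    start = max(0, min(start, content_length - 1))
    end = max(start, min(end, content_length - 1))
    return start, end

assert clamp_range("bytes=0-499", 1000) == (0, 499)
assert clamp_range("bytes=900-", 1000) == (900, 999)
assert clamp_range("bytes=990-2000", 1000) == (990, 999)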
@router.get("/thumb/{full_path:path}")
|
||||
async def get_thumb(
|
||||
full_path: str,
|
||||
w: int = Query(256, ge=8, le=1024),
|
||||
h: int = Query(256, ge=8, le=1024),
|
||||
fit: str = Query("cover"),
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
if fit not in ("cover", "contain"):
|
||||
raise HTTPException(400, detail="fit must be cover|contain")
|
||||
adapter, mount, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
if not rel or rel.endswith('/'):
|
||||
raise HTTPException(400, detail="Not a file")
|
||||
if not is_image_filename(rel):
|
||||
raise HTTPException(404, detail="Not an image")
|
||||
# type: ignore
|
||||
data, mime, key = await get_or_create_thumb(adapter, mount.id, root, rel, w, h, fit)
|
||||
headers = {
|
||||
'Cache-Control': 'public, max-age=3600',
|
||||
'ETag': key,
|
||||
}
|
||||
return Response(content=data, media_type=mime, headers=headers)
|
||||
|
||||
|
||||
@router.get("/stream/{full_path:path}")
|
||||
async def stream_endpoint(
|
||||
full_path: str,
|
||||
request: Request,
|
||||
):
|
||||
"""支持 Range 的视频/大文件流式读取,优先使用底层适配器 Range 能力。"""
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
range_header = request.headers.get('Range')
|
||||
try:
|
||||
return await stream_file(full_path, range_header)
|
||||
except HTTPException:
|
||||
raise
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"Stream error: {e}")
|
||||
|
||||
|
||||
@router.get("/temp-link/{full_path:path}")
|
||||
async def get_temp_link(
|
||||
full_path: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久")
|
||||
):
|
||||
"""获取文件的临时公开访问令牌"""
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
token = await generate_temp_link_token(full_path, expires_in=expires_in)
|
||||
return success({"token": token, "path": full_path})
|
||||
|
||||
|
||||
@router.get("/public/{token}")
|
||||
async def access_public_file(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""通过令牌公开访问文件,支持 Range 请求"""
|
||||
try:
|
||||
path = await verify_temp_link_token(token)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
range_header = request.headers.get('Range')
|
||||
try:
|
||||
return await stream_file(path, range_header)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, detail="File not found via token")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, detail=f"File access error: {e}")
|
||||
|
||||
|
||||
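A hedged sketch of the temp-link flow above: mint a token while authenticated, then fetch the file anonymously (the auth header and the response envelope are assumptions):

# Hedged sketch: temp-link mint + anonymous fetch.
import httpx

headers = {"Authorization": "Bearer <jwt>"}  # assumed auth scheme
minted = httpx.get(
    "http://127.0.0.1:8000/api/fs/temp-link/photos/cat.jpg",
    params={"expires_in": 600},
    headers=headers,
).json()
token = minted["data"]["token"]  # assumes success() wraps payloads under "data"

public = httpx.get(f"http://127.0.0.1:8000/api/fs/public/{token}")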
@router.get("/stat/{full_path:path}")
|
||||
async def get_file_stat(
|
||||
full_path: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
from services.virtual_fs import stat_file
|
||||
stat = await stat_file(full_path)
|
||||
return success(stat)
|
||||
|
||||
|
||||
@router.post("/file/{full_path:path}")
|
||||
async def put_file(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
data = await file.read()
|
||||
await write_file(full_path, data)
|
||||
return success({"written": True, "path": full_path, "size": len(data)})
|
||||
|
||||
|
||||
@router.post("/mkdir")
|
||||
async def api_mkdir(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MkdirRequest
|
||||
):
|
||||
path = body.path if body.path.startswith('/') else '/' + body.path
|
||||
if not path or path == '/':
|
||||
raise HTTPException(400, detail="Invalid path")
|
||||
await make_dir(path)
|
||||
return success({"created": True, "path": path})
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
async def api_move(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest
|
||||
):
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
await move_path(src, dst)
|
||||
return success({"moved": True, "src": src, "dst": dst})
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
async def api_rename(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
|
||||
debug: bool = Query(False, description="返回调试信息")
|
||||
):
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
from services.virtual_fs import rename_path
|
||||
debug_info = await rename_path(src, dst, overwrite=overwrite, return_debug=debug)
|
||||
return success({
|
||||
"renamed": True,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
**({"debug": debug_info} if debug else {})
|
||||
})
|
||||
|
||||
|
||||
@router.post("/copy")
|
||||
async def api_copy(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
body: MoveRequest,
|
||||
overwrite: bool = Query(False, description="是否覆盖已存在目标"),
|
||||
debug: bool = Query(False, description="返回调试信息")
|
||||
):
|
||||
from services.virtual_fs import copy_path
|
||||
src = body.src if body.src.startswith('/') else '/' + body.src
|
||||
dst = body.dst if body.dst.startswith('/') else '/' + body.dst
|
||||
debug_info = await copy_path(src, dst, overwrite=overwrite, return_debug=debug)
|
||||
return success({
|
||||
"copied": True,
|
||||
"src": src,
|
||||
"dst": dst,
|
||||
"overwrite": overwrite,
|
||||
**({"debug": debug_info} if debug else {})
|
||||
})
|
||||
|
||||
|
||||
@router.post("/upload/{full_path:path}")
|
||||
async def upload_stream(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
file: UploadFile = File(...),
|
||||
overwrite: bool = Query(True, description="是否覆盖已存在文件"),
|
||||
chunk_size: int = Query(1024 * 1024, ge=8 * 1024,
|
||||
le=8 * 1024 * 1024, description="单次读取块大小")
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
if full_path.endswith('/'):
|
||||
raise HTTPException(400, detail="Path must be a file")
|
||||
from services.virtual_fs import write_file_stream, resolve_adapter_and_rel
|
||||
adapter, _m, root, rel = await resolve_adapter_and_rel(full_path)
|
||||
exists_func = getattr(adapter, "exists", None)
|
||||
if not overwrite and callable(exists_func):
|
||||
try:
|
||||
if await exists_func(root, rel):
|
||||
raise HTTPException(409, detail="Destination exists")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def gen():
|
||||
while True:
|
||||
chunk = await file.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
size = await write_file_stream(full_path, gen(), overwrite=overwrite)
|
||||
return success({"uploaded": True, "path": full_path, "size": size, "overwrite": overwrite})
|
||||
|
||||
|
||||
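From the client side the streaming upload above is an ordinary multipart POST; a hedged sketch (host, auth, and paths are placeholders):

# Hedged sketch: multipart upload to the streaming endpoint; the field name
# "file" matches the UploadFile parameter above.
import httpx

with open("video.mp4", "rb") as f:
    resp = httpx.post(
        "http://127.0.0.1:8000/api/fs/upload/videos/video.mp4",
        params={"overwrite": "true", "chunk_size": 1024 * 1024},
        files={"file": ("video.mp4", f, "video/mp4")},
    )
print(resp.json())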
@router.get("/{full_path:path}")
|
||||
async def browse_fs(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str,
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数")
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
result = await list_virtual_dir(full_path, page_num, page_size)
|
||||
return success({
|
||||
"path": full_path,
|
||||
"entries": result["items"],
|
||||
"pagination": {
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
"pages": result["pages"]
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@router.delete("/{full_path:path}")
|
||||
async def api_delete(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
full_path: str
|
||||
):
|
||||
full_path = '/' + full_path if not full_path.startswith('/') else full_path
|
||||
await delete_path(full_path)
|
||||
return success({"deleted": True, "path": full_path})
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root_listing(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
page_num: int = Query(1, alias="page", ge=1, description="页码"),
|
||||
page_size: int = Query(50, ge=1, le=500, description="每页条数")
|
||||
):
|
||||
return await browse_fs("", page_num, page_size)
|
||||
@@ -1,7 +1,7 @@
services:
  foxel:
    image: ghcr.io/drizzletime/foxel:latest
    #image: ghcr.nju.edu.cn/drizzletime/foxel:latest #users in mainland China can use this mirror image
    #image: ghcr.nju.edu.cn/drizzletime/foxel:latest # users in mainland China can use this mirror image
    container_name: foxel
    restart: unless-stopped
    ports:
@@ -1,6 +1,6 @@
from tortoise import Tortoise

from services.adapters.registry import runtime_registry
from domain.adapters.registry import runtime_registry

TORTOISE_ORM = {
    "connections": {"default": "sqlite://data/db/db.sqlite3"},
@@ -12,7 +12,6 @@ TORTOISE_ORM = {
    },
}


async def init_db():
    await Tortoise.init(config=TORTOISE_ORM)
    await Tortoise.generate_schemas()
domain/adapters/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
domain/adapters/api.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.adapters.service import AdapterService
from domain.adapters.types import AdapterCreate
from domain.auth.service import get_current_active_user
from domain.auth.types import User

router = APIRouter(prefix="/api/adapters", tags=["adapters"])


@router.post("")
@audit(
    action=AuditAction.CREATE,
    description="Create storage adapter",
    body_fields=["name", "type", "path", "sub_path", "enabled"],
)
async def create_adapter(
    request: Request,
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    adapter = await AdapterService.create_adapter(data, current_user)
    return success(adapter)


@router.get("")
@audit(action=AuditAction.READ, description="List adapters")
async def list_adapters(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    adapters = await AdapterService.list_adapters()
    return success(adapters)


@router.get("/available")
@audit(action=AuditAction.READ, description="Get available adapter types")
async def available_adapter_types(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    data = await AdapterService.available_adapter_types()
    return success(data)


@router.get("/{adapter_id}")
@audit(action=AuditAction.READ, description="Get adapter details")
async def get_adapter(
    request: Request,
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    adapter = await AdapterService.get_adapter(adapter_id)
    return success(adapter)


@router.put("/{adapter_id}")
@audit(
    action=AuditAction.UPDATE,
    description="Update storage adapter",
    body_fields=["name", "type", "path", "sub_path", "enabled"],
)
async def update_adapter(
    request: Request,
    adapter_id: int,
    data: AdapterCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    adapter = await AdapterService.update_adapter(adapter_id, data, current_user)
    return success(adapter)


@router.delete("/{adapter_id}")
@audit(action=AuditAction.DELETE, description="Delete storage adapter")
async def delete_adapter(
    request: Request,
    adapter_id: int,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    result = await AdapterService.delete_adapter(adapter_id, current_user)
    return success(result)
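The @audit decorator's implementation is not part of this diff; one plausible minimal shape for such a decorator, everything here being an assumption (including the log sink):

# Hedged sketch only: a decorator with the same call shape as @audit above.
# The real domain.audit implementation is not shown in this diff.
import functools
from enum import Enum

class AuditAction(Enum):
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"

def audit(action: AuditAction, description: str = "", body_fields=None):
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            result = await fn(*args, **kwargs)
            # A real implementation would persist an audit entry; print is a stand-in.
            print(f"[audit] {action.value}: {description}")
            return result
        return wrapper
    return decorator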
domain/adapters/providers/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .base import BaseAdapter

__all__ = ["BaseAdapter"]
domain/adapters/providers/alist.py (new file, 487 lines)
@@ -0,0 +1,487 @@
import asyncio
import mimetypes
import re
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Tuple
from urllib.parse import quote, urljoin

import httpx
from fastapi import HTTPException
from fastapi.responses import Response, StreamingResponse

from models import StorageAdapter


def _normalize_fs_path(path: str) -> str:
    path = (path or "").replace("\\", "/").strip()
    if not path or path == "/":
        return "/"
    if not path.startswith("/"):
        path = "/" + path
    path = re.sub(r"/{2,}", "/", path)
    if path != "/" and path.endswith("/"):
        path = path.rstrip("/")
    return path or "/"


def _join_fs_path(base: str, rel: str) -> str:
    base = _normalize_fs_path(base)
    rel = (rel or "").replace("\\", "/").lstrip("/")
    if not rel:
        return base
    if base == "/":
        return "/" + rel
    return f"{base}/{rel}"


def _split_parent_and_name(path: str) -> Tuple[str, str]:
    path = _normalize_fs_path(path)
    if path == "/":
        return "/", ""
    parent, _, name = path.rpartition("/")
    if not parent:
        parent = "/"
    return parent, name


def _parse_iso_to_epoch(value: str | None) -> int:
    if not value:
        return 0
    text = str(value).strip()
    if not text:
        return 0
    try:
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        m = re.match(r"^(.*?)(\.\d+)([+-]\d\d:\d\d)?$", text)
        if m:
            head, frac, tz = m.group(1), m.group(2), m.group(3) or ""
            digits = frac[1:]
            if len(digits) > 6:
                frac = "." + digits[:6]
            text = head + frac + tz
        dt = datetime.fromisoformat(text)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return int(dt.timestamp())
    except Exception:
        return 0
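The helpers' contract, stated as executable checks (these follow directly from the definitions above):

assert _normalize_fs_path("foo\\bar//baz/") == "/foo/bar/baz"
assert _normalize_fs_path("") == "/"
assert _join_fs_path("/", "a/b") == "/a/b"
assert _join_fs_path("/root", "/a") == "/root/a"
assert _split_parent_and_name("/root/a.txt") == ("/root", "a.txt")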
class AListApiAdapterBase:
    def __init__(self, record: StorageAdapter, *, product_name: str):
        self.record = record
        self.product_name = product_name

        cfg = record.config or {}
        self.base_url: str = str(cfg.get("base_url", "")).rstrip("/")
        if not self.base_url.startswith("http"):
            raise ValueError(f"{product_name} requires base_url http/https")
        self.username: str = str(cfg.get("username") or "")
        self.password: str = str(cfg.get("password") or "")
        if not self.username or not self.password:
            raise ValueError(f"{product_name} requires username and password")

        self.timeout: float = float(cfg.get("timeout", 30))
        self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
        self.enable_redirect_307: bool = bool(cfg.get("enable_direct_download_307"))

        self._token: str | None = None
        self._login_lock = asyncio.Lock()

    def get_effective_root(self, sub_path: str | None) -> str:
        base = _normalize_fs_path(self.root_path)
        if sub_path:
            return _join_fs_path(base, sub_path)
        return base

    async def _ensure_token(self) -> str:
        if self._token:
            return self._token
        async with self._login_lock:
            if self._token:
                return self._token
            self._token = await self._login()
            return self._token

    async def _login(self) -> str:
        url = self.base_url + "/api/auth/login"
        body = {"username": self.username, "password": self.password}
        async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
            resp = await client.post(url, json=body)
            resp.raise_for_status()
            payload = resp.json()
            if not isinstance(payload, dict):
                raise HTTPException(502, detail=f"{self.product_name} login: invalid response")
            code = payload.get("code")
            if code not in (0, 200):
                raise HTTPException(502, detail=f"{self.product_name} login failed: {payload.get('message')}")
            data = payload.get("data") or {}
            token = (data.get("token") if isinstance(data, dict) else None) or ""
            token = str(token).strip()
            if not token:
                raise HTTPException(502, detail=f"{self.product_name} login: missing token")
            return token
    async def _api_json(
        self,
        method: str,
        endpoint: str,
        *,
        json: Dict[str, Any] | None = None,
        headers: Dict[str, str] | None = None,
        retry: bool = True,
        files: Any = None,
    ) -> Any:
        token = await self._ensure_token()
        url = self.base_url + endpoint
        req_headers: Dict[str, str] = {"Authorization": token}
        if headers:
            req_headers.update(headers)
        async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
            resp = await client.request(method, url, json=json, headers=req_headers, files=files)
            if resp.status_code == 401 and retry:
                self._token = None
                return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
            resp.raise_for_status()
            payload = resp.json()
            if not isinstance(payload, dict):
                raise HTTPException(502, detail=f"{self.product_name} api: invalid response")

            code = payload.get("code")
            if code in (0, 200):
                return payload.get("data")
            if code in (401, 403) and retry:
                self._token = None
                return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
            if code == 404:
                raise FileNotFoundError(json.get("path") if json else "")
            msg = payload.get("message") or payload.get("msg") or ""
            raise HTTPException(502, detail=f"{self.product_name} api error code={code} msg={msg}")
    def _abs_url(self, url: str) -> str:
        u = (url or "").strip()
        if not u:
            return ""
        if u.startswith("http://") or u.startswith("https://"):
            return u
        return urljoin(self.base_url.rstrip("/") + "/", u.lstrip("/"))

    async def _fs_list(self, path: str) -> Dict[str, Any]:
        body = {"path": path, "password": "", "page": 1, "per_page": 0, "refresh": False}
        data = await self._api_json("POST", "/api/fs/list", json=body)
        return data or {}

    async def _fs_get(self, path: str) -> Dict[str, Any]:
        body = {"path": path, "password": "", "page": 1, "per_page": 0, "refresh": False}
        data = await self._api_json("POST", "/api/fs/get", json=body)
        return data or {}
    async def list_dir(
        self,
        root: str,
        rel: str,
        page_num: int = 1,
        page_size: int = 50,
        sort_by: str = "name",
        sort_order: str = "asc",
    ) -> Tuple[List[Dict], int]:
        path = _join_fs_path(root, rel)
        data = await self._fs_list(path)
        content = data.get("content") or []
        if not isinstance(content, list):
            raise HTTPException(502, detail=f"{self.product_name} list_dir: invalid content")

        entries: List[Dict] = []
        for it in content:
            if not isinstance(it, dict):
                continue
            name = str(it.get("name") or "")
            if not name:
                continue
            is_dir = bool(it.get("is_dir"))
            size = int(it.get("size") or 0) if not is_dir else 0
            mtime = _parse_iso_to_epoch(it.get("modified"))
            entries.append(
                {
                    "name": name,
                    "is_dir": is_dir,
                    "size": size,
                    "mtime": mtime,
                    "type": "dir" if is_dir else "file",
                }
            )

        reverse = sort_order.lower() == "desc"

        def get_sort_key(item: Dict) -> Tuple:
            key = (not item.get("is_dir"),)
            f = sort_by.lower()
            if f == "name":
                key += (str(item.get("name", "")).lower(),)
            elif f == "size":
                key += (int(item.get("size", 0)),)
            elif f == "mtime":
                key += (int(item.get("mtime", 0)),)
            else:
                key += (str(item.get("name", "")).lower(),)
            return key

        entries.sort(key=get_sort_key, reverse=reverse)
        total = len(entries)
        start = (page_num - 1) * page_size
        end = start + page_size
        return entries[start:end], total
    async def stat_file(self, root: str, rel: str):
        path = _join_fs_path(root, rel)
        data = await self._fs_get(path)
        if not data:
            raise FileNotFoundError(rel)
        is_dir = bool(data.get("is_dir"))
        name = str(data.get("name") or (rel.rstrip("/").split("/")[-1] if rel else ""))
        size = int(data.get("size") or 0) if not is_dir else 0
        mtime = _parse_iso_to_epoch(data.get("modified"))
        info = {
            "name": name,
            "is_dir": is_dir,
            "size": size,
            "mtime": mtime,
            "type": "dir" if is_dir else "file",
            "path": path,
        }
        return info

    async def stat_path(self, root: str, rel: str):
        try:
            info = await self.stat_file(root, rel)
            return {"exists": True, "is_dir": bool(info.get("is_dir")), "path": info.get("path")}
        except FileNotFoundError:
            return {"exists": False, "is_dir": None, "path": _join_fs_path(root, rel)}

    async def exists(self, root: str, rel: str) -> bool:
        try:
            await self.stat_file(root, rel)
            return True
        except FileNotFoundError:
            return False
        except Exception:
            return False
    async def get_direct_download_response(self, root: str, rel: str):
        if not self.enable_redirect_307:
            return None
        data = await self._fs_get(_join_fs_path(root, rel))
        if not data:
            raise FileNotFoundError(rel)
        if bool(data.get("is_dir")):
            raise IsADirectoryError(rel)
        raw_url = self._abs_url(str(data.get("raw_url") or ""))
        if not raw_url:
            return None
        return Response(status_code=307, headers={"Location": raw_url})

    async def _get_raw_url_and_meta(self, root: str, rel: str) -> Tuple[str, int, str]:
        data = await self._fs_get(_join_fs_path(root, rel))
        if not data:
            raise FileNotFoundError(rel)
        if bool(data.get("is_dir")):
            raise IsADirectoryError(rel)
        raw_url = self._abs_url(str(data.get("raw_url") or ""))
        if not raw_url:
            raise HTTPException(502, detail=f"{self.product_name} missing raw_url")
        size = int(data.get("size") or 0)
        name = str(data.get("name") or "")
        return raw_url, size, name
    async def read_file(self, root: str, rel: str) -> bytes:
        raw_url, _, _ = await self._get_raw_url_and_meta(root, rel)
        async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
            resp = await client.get(raw_url)
            resp.raise_for_status()
            return resp.content

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        raw_url, file_size, name = await self._get_raw_url_and_meta(root, rel)
        mime, _ = mimetypes.guess_type(name or rel)
        content_type = mime or "application/octet-stream"

        start = 0
        end = max(file_size - 1, 0)
        status = 200
        headers = {
            "Accept-Ranges": "bytes",
            "Content-Type": content_type,
        }
        if file_size >= 0:
            headers["Content-Length"] = str(file_size)

        if range_header and range_header.startswith("bytes="):
            try:
                part = range_header.removeprefix("bytes=")
                s, e = part.split("-", 1)
                if s.strip():
                    start = int(s)
                if e.strip():
                    end = int(e)
                if file_size and start >= file_size:
                    raise HTTPException(416, detail="Requested Range Not Satisfiable")
                if file_size and end >= file_size:
                    end = file_size - 1
                status = 206
            except ValueError:
                raise HTTPException(400, detail="Invalid Range header")
            headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
            headers["Content-Length"] = str(end - start + 1)

        async def agen():
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                req_headers = {"Range": f"bytes={start}-{end}"} if status == 206 else {}
                async with client.stream("GET", raw_url, headers=req_headers) as resp:
                    resp.raise_for_status()
                    async for chunk in resp.aiter_bytes():
                        if chunk:
                            yield chunk

        return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
    async def _upload_file(self, full_path: str, file_path: Path) -> Any:
        token = await self._ensure_token()
        headers = {
            "Authorization": token,
            "File-Path": quote(full_path, safe="/"),
        }
        with file_path.open("rb") as f:
            files = {"file": (file_path.name, f, "application/octet-stream")}
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
                resp.raise_for_status()
                payload = resp.json()
                if not isinstance(payload, dict):
                    raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
                code = payload.get("code")
                if code not in (0, 200):
                    msg = payload.get("message") or payload.get("msg") or ""
                    raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
                return payload.get("data")

    async def write_file(self, root: str, rel: str, data: bytes):
        full_path = _join_fs_path(root, rel)
        suffix = Path(rel).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
            tf.write(data)
            tmp_path = Path(tf.name)
        try:
            await self._upload_file(full_path, tmp_path)
        finally:
            try:
                tmp_path.unlink(missing_ok=True)
            except Exception:
                pass

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        full_path = _join_fs_path(root, rel)
        suffix = Path(rel).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
            tmp_path = Path(tf.name)
        size = 0
        try:
            with tmp_path.open("wb") as f:
                async for chunk in data_iter:
                    if not chunk:
                        continue
                    f.write(chunk)
                    size += len(chunk)
            await self._upload_file(full_path, tmp_path)
            return size
        finally:
            try:
                tmp_path.unlink(missing_ok=True)
            except Exception:
                pass
    async def mkdir(self, root: str, rel: str):
        path = _join_fs_path(root, rel)
        await self._api_json("POST", "/api/fs/mkdir", json={"path": path})

    async def delete(self, root: str, rel: str):
        path = _join_fs_path(root, rel)
        parent, name = _split_parent_and_name(path)
        if not name:
            return
        await self._api_json("POST", "/api/fs/remove", json={"dir": parent, "names": [name]})

    async def move(self, root: str, src_rel: str, dst_rel: str):
        src_path = _join_fs_path(root, src_rel)
        dst_path = _join_fs_path(root, dst_rel)
        src_dir, src_name = _split_parent_and_name(src_path)
        dst_dir, dst_name = _split_parent_and_name(dst_path)
        if not src_name or not dst_name:
            raise HTTPException(400, detail="Invalid move path")

        if src_dir == dst_dir:
            if src_name == dst_name:
                return
            await self._api_json("POST", "/api/fs/rename", json={"path": src_path, "name": dst_name})
            return

        await self._api_json("POST", "/api/fs/move", json={"src_dir": src_dir, "dst_dir": dst_dir, "names": [src_name]})
        if src_name != dst_name:
            await self._api_json("POST", "/api/fs/rename", json={"path": _join_fs_path(dst_dir, src_name), "name": dst_name})

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        await self.move(root, src_rel, dst_rel)

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        src_path = _join_fs_path(root, src_rel)
        dst_path = _join_fs_path(root, dst_rel)
        src_dir, src_name = _split_parent_and_name(src_path)
        dst_dir, dst_name = _split_parent_and_name(dst_path)
        if not src_name or not dst_name:
            raise HTTPException(400, detail="Invalid copy path")

        src_info = await self._fs_get(src_path)
        if not src_info:
            raise FileNotFoundError(src_rel)

        if src_name != dst_name and not bool(src_info.get("is_dir")):
            raw_url, _, _ = await self._get_raw_url_and_meta(root, src_rel)
            async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
                async with client.stream("GET", raw_url) as resp:
                    resp.raise_for_status()

                    async def gen():
                        async for chunk in resp.aiter_bytes():
                            if chunk:
                                yield chunk

                    await self.write_file_stream(root, dst_rel, gen())
            return

        await self._api_json("POST", "/api/fs/copy", json={"src_dir": src_dir, "dst_dir": dst_dir, "names": [src_name]})
        if src_name != dst_name:
            await self._api_json("POST", "/api/fs/rename", json={"path": _join_fs_path(dst_dir, src_name), "name": dst_name})


class AListAdapter(AListApiAdapterBase):
    def __init__(self, record: StorageAdapter):
        super().__init__(record, product_name="AList")


class OpenListAdapter(AListApiAdapterBase):
    def __init__(self, record: StorageAdapter):
        super().__init__(record, product_name="OpenList")


ADAPTER_TYPES = {"alist": AListAdapter, "openlist": OpenListAdapter}

CONFIG_SCHEMA = [
    {"key": "base_url", "label": "Base URL", "type": "string", "required": True, "placeholder": "http://127.0.0.1:5244"},
    {"key": "username", "label": "Username", "type": "string", "required": True},
    {"key": "password", "label": "Password", "type": "password", "required": True},
    {"key": "root", "label": "Root directory", "type": "string", "required": False, "default": "/"},
    {"key": "timeout", "label": "Timeout (seconds)", "type": "number", "required": False, "default": 30},
    {"key": "enable_direct_download_307", "label": "Enable 307 direct-link download", "type": "boolean", "default": False},
]
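An example config dict matching CONFIG_SCHEMA above (values are placeholders):

alist_config = {
    "base_url": "http://127.0.0.1:5244",
    "username": "admin",
    "password": "changeme",
    "root": "/",
    "timeout": 30,
    "enable_direct_download_307": False,
}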
@@ -1,4 +1,3 @@
from __future__ import annotations
from typing import List, Dict, Protocol, runtime_checkable, Tuple, AsyncIterator
from models import StorageAdapter

@@ -10,7 +9,7 @@ from models import StorageAdapter
@runtime_checkable
class BaseAdapter(Protocol):
    record: StorageAdapter
    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]: ...
    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]: ...
    async def read_file(self, root: str, rel: str) -> bytes: ...
    async def write_file(self, root: str, rel: str, data: bytes): ...
    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]): ...
domain/adapters/providers/dropbox.py (new file, 471 lines)
@@ -0,0 +1,471 @@
import asyncio
import base64
import json
import mimetypes
import re
from datetime import datetime, timezone, timedelta
from typing import AsyncIterator, Dict, List, Tuple

import httpx
from fastapi import HTTPException
from fastapi.responses import Response, StreamingResponse

from models import StorageAdapter

DROPBOX_OAUTH_URL = "https://api.dropboxapi.com/oauth2/token"
DROPBOX_API_URL = "https://api.dropboxapi.com/2"
DROPBOX_CONTENT_URL = "https://content.dropboxapi.com/2"


def _normalize_dbx_path(path: str | None) -> str:
    path = (path or "").replace("\\", "/").strip()
    if not path or path == "/":
        return ""
    if not path.startswith("/"):
        path = "/" + path
    path = re.sub(r"/{2,}", "/", path)
    if path.endswith("/"):
        path = path.rstrip("/")
    return path


def _join_dbx_path(base: str, rel: str) -> str:
    base = _normalize_dbx_path(base)
    rel = (rel or "").replace("\\", "/").strip("/")
    if not rel:
        return base
    if not base:
        return "/" + rel
    return f"{base}/{rel}"


def _parse_iso_to_epoch(value: str | None) -> int:
    if not value:
        return 0
    text = str(value).strip()
    if not text:
        return 0
    try:
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        dt = datetime.fromisoformat(text)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return int(dt.timestamp())
    except Exception:
        return 0
class DropboxAdapter:
    def __init__(self, record: StorageAdapter):
        self.record = record
        cfg = record.config or {}

        self.app_key: str = str(cfg.get("app_key") or "").strip()
        self.app_secret: str = str(cfg.get("app_secret") or "").strip()
        self.refresh_token: str = str(cfg.get("refresh_token") or "").strip()
        self.root_path: str = _normalize_dbx_path(str(cfg.get("root") or "/"))
        self.enable_redirect_307: bool = bool(cfg.get("enable_direct_download_307"))
        self.timeout: float = float(cfg.get("timeout", 60))

        if not (self.app_key and self.app_secret and self.refresh_token):
            raise ValueError("The Dropbox adapter requires app_key, app_secret and refresh_token")

        self._access_token: str | None = None
        self._token_expiry: datetime | None = None
        self._token_lock = asyncio.Lock()

    def get_effective_root(self, sub_path: str | None) -> str:
        base = _normalize_dbx_path(self.root_path)
        if sub_path:
            return _join_dbx_path(base, sub_path)
        return base

    async def _get_access_token(self) -> str:
        if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
            return self._access_token

        async with self._token_lock:
            if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
                return self._access_token

            basic = base64.b64encode(f"{self.app_key}:{self.app_secret}".encode("utf-8")).decode("ascii")
            headers = {"Authorization": f"Basic {basic}"}
            data = {"grant_type": "refresh_token", "refresh_token": self.refresh_token}
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                resp = await client.post(DROPBOX_OAUTH_URL, data=data, headers=headers)
                resp.raise_for_status()

            payload = resp.json()
            token = str(payload.get("access_token") or "").strip()
            if not token:
                raise HTTPException(502, detail="Dropbox oauth: missing access_token")
            expires_in = int(payload.get("expires_in") or 3600)
            self._access_token = token
            self._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=max(60, expires_in - 300))
            return token
    async def _api_json(self, endpoint: str, body: Dict) -> httpx.Response:
        token = await self._get_access_token()
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            return await client.post(f"{DROPBOX_API_URL}{endpoint}", json=body, headers=headers)

    async def _content_request(
        self,
        endpoint: str,
        api_arg: Dict,
        *,
        content: bytes | None = None,
        data_iter: AsyncIterator[bytes] | None = None,
        extra_headers: Dict[str, str] | None = None,
    ) -> httpx.Response:
        token = await self._get_access_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Dropbox-API-Arg": json.dumps(api_arg, separators=(",", ":"), ensure_ascii=False),
        }
        if extra_headers:
            headers.update(extra_headers)

        if data_iter is None:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                return await client.post(f"{DROPBOX_CONTENT_URL}{endpoint}", headers=headers, content=content or b"")

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            return await client.post(f"{DROPBOX_CONTENT_URL}{endpoint}", headers=headers, content=data_iter)
    @staticmethod
    def _raise_dbx_error(resp: httpx.Response, *, rel: str):
        try:
            payload = resp.json()
        except Exception:
            payload = None
        summary = ""
        if isinstance(payload, dict):
            summary = str(payload.get("error_summary") or "")
        if "not_found" in summary:
            raise FileNotFoundError(rel)
        if "conflict" in summary or "already_exists" in summary:
            raise FileExistsError(rel)
        if "is_folder" in summary:
            raise IsADirectoryError(rel)
        if "not_folder" in summary:
            raise NotADirectoryError(rel)
        raise HTTPException(502, detail=f"Dropbox API error: {summary or resp.text}")

    def _format_entry(self, entry: Dict) -> Dict:
        tag = entry.get(".tag")
        is_dir = tag == "folder"
        mtime = _parse_iso_to_epoch(entry.get("server_modified") if not is_dir else None)
        return {
            "name": entry.get("name") or "",
            "is_dir": is_dir,
            "size": 0 if is_dir else int(entry.get("size") or 0),
            "mtime": mtime,
            "type": "dir" if is_dir else "file",
        }
    async def list_dir(
        self,
        root: str,
        rel: str,
        page_num: int = 1,
        page_size: int = 50,
        sort_by: str = "name",
        sort_order: str = "asc",
    ) -> Tuple[List[Dict], int]:
        path = _join_dbx_path(root, rel)
        body = {"path": path, "recursive": False, "include_deleted": False, "limit": 2000}
        resp = await self._api_json("/files/list_folder", body)
        if resp.status_code == 409:
            try:
                payload = resp.json()
            except Exception:
                payload = None
            summary = str((payload or {}).get("error_summary") or "")
            if "not_found" in summary:
                return [], 0
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        payload = resp.json()

        all_entries: List[Dict] = []
        all_entries.extend(payload.get("entries") or [])
        cursor = payload.get("cursor")
        has_more = bool(payload.get("has_more"))
        while has_more and cursor:
            resp2 = await self._api_json("/files/list_folder/continue", {"cursor": cursor})
            resp2.raise_for_status()
            p2 = resp2.json()
            all_entries.extend(p2.get("entries") or [])
            cursor = p2.get("cursor")
            has_more = bool(p2.get("has_more"))

        items = [self._format_entry(e) for e in all_entries if isinstance(e, dict)]

        reverse = sort_order.lower() == "desc"

        def get_sort_key(item):
            key = (not item["is_dir"],)
            f = sort_by.lower()
            if f == "name":
                key += (item["name"].lower(),)
            elif f == "size":
                key += (item["size"],)
            elif f == "mtime":
                key += (item["mtime"],)
            else:
                key += (item["name"].lower(),)
            return key

        items.sort(key=get_sort_key, reverse=reverse)

        total = len(items)
        start = (page_num - 1) * page_size
        end = start + page_size
        return items[start:end], total
    async def stat_file(self, root: str, rel: str):
        path = _join_dbx_path(root, rel)
        resp = await self._api_json("/files/get_metadata", {"path": path, "include_deleted": False})
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        meta = resp.json()
        if not isinstance(meta, dict):
            raise HTTPException(502, detail="Dropbox metadata: invalid response")
        return self._format_entry(meta)

    async def exists(self, root: str, rel: str) -> bool:
        try:
            await self.stat_file(root, rel)
            return True
        except FileNotFoundError:
            return False
        except Exception:
            return False

    async def read_file(self, root: str, rel: str) -> bytes:
        path = _join_dbx_path(root, rel)
        resp = await self._content_request("/files/download", {"path": path})
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        return resp.content
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_dbx_path(root, rel)
|
||||
arg = {
|
||||
"path": path,
|
||||
"mode": "overwrite",
|
||||
"autorename": False,
|
||||
"mute": False,
|
||||
"strict_conflict": False,
|
||||
}
|
||||
resp = await self._content_request(
|
||||
"/files/upload",
|
||||
arg,
|
||||
content=data,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        path = _join_dbx_path(root, rel)

        size = 0
        session_id: str | None = None
        offset = 0

        async for chunk in data_iter:
            if not chunk:
                continue
            if session_id is None:
                resp = await self._content_request(
                    "/files/upload_session_start",
                    {"close": False},
                    content=chunk,
                    extra_headers={"Content-Type": "application/octet-stream"},
                )
                resp.raise_for_status()
                payload = resp.json()
                session_id = str(payload.get("session_id") or "")
                if not session_id:
                    raise HTTPException(502, detail="Dropbox upload_session_start: missing session_id")
                offset += len(chunk)
                size += len(chunk)
                continue

            arg = {"cursor": {"session_id": session_id, "offset": offset}, "close": False}
            resp = await self._content_request(
                "/files/upload_session_append_v2",
                arg,
                content=chunk,
                extra_headers={"Content-Type": "application/octet-stream"},
            )
            resp.raise_for_status()
            offset += len(chunk)
            size += len(chunk)

        if session_id is None:
            await self.write_file(root, rel, b"")
            return 0

        finish_arg = {
            "cursor": {"session_id": session_id, "offset": offset},
            "commit": {
                "path": path,
                "mode": "overwrite",
                "autorename": False,
                "mute": False,
                "strict_conflict": False,
            },
        }
        resp = await self._content_request(
            "/files/upload_session_finish",
            finish_arg,
            content=b"",
            extra_headers={"Content-Type": "application/octet-stream"},
        )
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        return size

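    # Usage sketch (illustrative, not part of the adapter): feeding an async
    # iterator into write_file_stream drives Dropbox's three-step
    # upload-session protocol (upload_session_start ->
    # upload_session_append_v2 -> upload_session_finish). The `adapter`
    # instance and the chunk sizes below are hypothetical.
    #
    #     async def gen():
    #         for _ in range(3):
    #             yield b"x" * (4 * 1024 * 1024)  # three 4 MiB chunks
    #
    #     written = await adapter.write_file_stream("/", "backups/big.bin", gen())
    #     assert written == 3 * 4 * 1024 * 1024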
    async def mkdir(self, root: str, rel: str):
        path = _join_dbx_path(root, rel)
        resp = await self._api_json("/files/create_folder_v2", {"path": path, "autorename": False})
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        return True

    async def delete(self, root: str, rel: str):
        path = _join_dbx_path(root, rel)
        resp = await self._api_json("/files/delete_v2", {"path": path})
        if resp.status_code == 409:
            try:
                payload = resp.json()
            except Exception:
                payload = None
            summary = str((payload or {}).get("error_summary") or "")
            if "not_found" in summary:
                return
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        return True

    async def move(self, root: str, src_rel: str, dst_rel: str):
        src = _join_dbx_path(root, src_rel)
        dst = _join_dbx_path(root, dst_rel)
        resp = await self._api_json(
            "/files/move_v2",
            {"from_path": src, "to_path": dst, "autorename": False, "allow_shared_folder": True},
        )
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=src_rel)
        resp.raise_for_status()
        return True

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        return await self.move(root, src_rel, dst_rel)

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        src = _join_dbx_path(root, src_rel)
        dst = _join_dbx_path(root, dst_rel)
        resp = await self._api_json(
            "/files/copy_v2",
            {"from_path": src, "to_path": dst, "autorename": False, "allow_shared_folder": True},
        )
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=dst_rel)
        resp.raise_for_status()
        return True

    async def get_direct_download_response(self, root: str, rel: str):
        if not self.enable_redirect_307:
            return None

        path = _join_dbx_path(root, rel)
        resp = await self._api_json("/files/get_temporary_link", {"path": path})
        if resp.status_code == 409:
            self._raise_dbx_error(resp, rel=rel)
        resp.raise_for_status()
        payload = resp.json()
        link = (payload.get("link") if isinstance(payload, dict) else None) or ""
        link = str(link).strip()
        if not link:
            return None
        return Response(status_code=307, headers={"Location": link})

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        path = _join_dbx_path(root, rel)
        token = await self._get_access_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "Dropbox-API-Arg": json.dumps({"path": path}, separators=(",", ":"), ensure_ascii=False),
        }
        if range_header:
            headers["Range"] = range_header

        client = httpx.AsyncClient(timeout=None)
        stream_cm = client.stream("POST", f"{DROPBOX_CONTENT_URL}/files/download", headers=headers)
        try:
            resp = await stream_cm.__aenter__()
        except Exception:
            await client.aclose()
            raise

        if resp.status_code == 409:
            try:
                content = await resp.aread()
                _ = content
            finally:
                await stream_cm.__aexit__(None, None, None)
                await client.aclose()
            self._raise_dbx_error(resp, rel=rel)

        if resp.status_code >= 400:
            try:
                await resp.aread()
            finally:
                await stream_cm.__aexit__(None, None, None)
                await client.aclose()
            resp.raise_for_status()

        content_type = resp.headers.get("Content-Type") or (mimetypes.guess_type(rel)[0] or "application/octet-stream")
        out_headers = {}
        for key in ("Accept-Ranges", "Content-Range", "Content-Length"):
            value = resp.headers.get(key)
            if value:
                out_headers[key] = value

        async def iterator():
            try:
                async for chunk in resp.aiter_bytes():
                    if chunk:
                        yield chunk
            finally:
                await stream_cm.__aexit__(None, None, None)
                await client.aclose()

        return StreamingResponse(iterator(), status_code=resp.status_code, headers=out_headers, media_type=content_type)


ADAPTER_TYPE = "dropbox"
CONFIG_SCHEMA = [
    {"key": "app_key", "label": "App Key", "type": "string", "required": True},
    {"key": "app_secret", "label": "App Secret", "type": "password", "required": True},
    {"key": "refresh_token", "label": "Refresh Token", "type": "password", "required": True},
    {"key": "root", "label": "Root Path", "type": "string", "required": False, "default": "/", "placeholder": "/ or /Apps/Foxel"},
    {"key": "timeout", "label": "Timeout (seconds)", "type": "number", "required": False, "default": 60},
    {"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
]


def ADAPTER_FACTORY(rec): return DropboxAdapter(rec)
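# Illustrative only: a config dict matching CONFIG_SCHEMA above. The values
# are placeholders, and `record` stands in for whatever StorageAdapter row
# the host application persists.
#
#     record.config = {
#         "app_key": "abc123",
#         "app_secret": "***",
#         "refresh_token": "***",
#         "root": "/Apps/Foxel",
#         "timeout": 60,
#         "enable_direct_download_307": False,
#     }
#     adapter = ADAPTER_FACTORY(record)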
612  domain/adapters/providers/ftp.py  Normal file
@@ -0,0 +1,612 @@
import asyncio
from dataclasses import dataclass
from typing import List, Dict, Tuple, AsyncIterator, Optional

from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from ftplib import FTP, error_perm
import mimetypes

from models import StorageAdapter


def _join_remote(root: str, rel: str) -> str:
    root = (root or "/").rstrip("/") or "/"
    rel = (rel or "").lstrip("/")
    if not rel:
        return root
    return f"{root}/{rel}"


def _parse_mlst_line(line: str) -> Dict[str, str]:
    out: Dict[str, str] = {}
    try:
        facts, _, name = line.partition(" ")
        for part in facts.split(";"):
            if not part or "=" not in part:
                continue
            k, v = part.split("=", 1)
            out[k.strip().lower()] = v.strip()
        if name:
            out["name"] = name.strip()
    except Exception:
        pass
    return out


def _parse_modify_to_epoch(mod: str) -> int:
    # Formats we may see: YYYYMMDDHHMMSS or YYYYMMDDHHMMSS(.sss)
    try:
        mod = mod.strip()
        mod = mod.split(".")[0]
        if len(mod) >= 14:
            y = int(mod[0:4])
            m = int(mod[4:6])
            d = int(mod[6:8])
            hh = int(mod[8:10])
            mm = int(mod[10:12])
            ss = int(mod[12:14])
            import datetime as _dt
            return int(_dt.datetime(y, m, d, hh, mm, ss, tzinfo=_dt.timezone.utc).timestamp())
    except Exception:
        return 0
    return 0
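# Worked example (illustrative): an MDTM/MLSD "modify" fact of
# "20240101000000" parses as 2024-01-01 00:00:00 UTC, i.e. epoch 1704067200;
# fractional-second variants such as "20240101000000.123" drop the suffix
# first, and anything malformed falls back to 0.
#
#     >>> _parse_modify_to_epoch("20240101000000")
#     1704067200
#     >>> _parse_modify_to_epoch("garbage")
#     0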
@dataclass
class _Range:
    start: int
    end: Optional[int]  # inclusive


class FTPAdapter:
    def __init__(self, record: StorageAdapter):
        self.record = record
        cfg = record.config
        self.host: str = cfg.get("host")
        self.port: int = int(cfg.get("port", 21))
        self.username: Optional[str] = cfg.get("username")
        self.password: Optional[str] = cfg.get("password")
        self.passive: bool = bool(cfg.get("passive", True))
        self.timeout: int = int(cfg.get("timeout", 15))
        self.root_path: str = cfg.get("root", "/") or "/"

        if not self.host:
            raise ValueError("FTP adapter requires 'host'")

    def get_effective_root(self, sub_path: str | None) -> str:
        base = self.root_path.rstrip("/") or "/"
        if sub_path:
            return _join_remote(base, sub_path)
        return base

    def _connect(self) -> FTP:
        ftp = FTP()
        ftp.connect(self.host, self.port, timeout=self.timeout)
        if self.username:
            ftp.login(self.username, self.password or "")
        else:
            ftp.login()
        ftp.set_pasv(self.passive)
        return ftp
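    # Pattern note (illustrative): ftplib is blocking, so every method below
    # opens a connection inside a plain function and hands it to
    # asyncio.to_thread, keeping the event loop free. A minimal sketch of the
    # same shape, with a hypothetical `_noop` operation:
    #
    #     async def _noop(self) -> None:
    #         def _do():
    #             ftp = self._connect()
    #             try:
    #                 ftp.voidcmd("NOOP")  # any blocking ftplib call
    #             finally:
    #                 ftp.quit()
    #         await asyncio.to_thread(_do)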
    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        path = _join_remote(root, rel.strip('/'))

        def _do_list() -> List[Dict]:
            ftp = self._connect()
            try:
                ftp.cwd(path)
            except error_perm as e:
                # path may be a file
                ftp.quit()
                raise NotADirectoryError(rel) from e

            entries: List[Dict] = []
            # Try MLSD first
            try:
                for name, facts in ftp.mlsd():
                    if name in (".", ".."):
                        continue
                    is_dir = (facts.get("type") == "dir")
                    size = int(facts.get("size") or 0)
                    mtime = _parse_modify_to_epoch(facts.get("modify") or "")
                    entries.append({
                        "name": name,
                        "is_dir": is_dir,
                        "size": 0 if is_dir else size,
                        "mtime": mtime,
                        "type": "dir" if is_dir else "file",
                    })
                ftp.quit()
                return entries
            except Exception:
                # Fallback to NLST + probing
                pass

            names = []
            try:
                names = ftp.nlst()
            except Exception:
                ftp.quit()
                return []

            for name in names:
                if name in (".", ".."):
                    continue
                is_dir = False
                size = 0
                mtime = 0
                try:
                    # If we can CWD, it's a directory
                    ftp.cwd(_join_remote(path, name))
                    ftp.cwd(path)
                    is_dir = True
                except Exception:
                    is_dir = False
                try:
                    size = ftp.size(_join_remote(path, name)) or 0
                except Exception:
                    size = 0
                try:
                    mdtm = ftp.sendcmd("MDTM " + _join_remote(path, name))
                    # Example: '213 20241012XXXXXX'
                    if mdtm.startswith("213 "):
                        mtime = _parse_modify_to_epoch(mdtm.split(" ", 1)[1])
                except Exception:
                    pass
                entries.append({
                    "name": name,
                    "is_dir": is_dir,
                    "size": 0 if is_dir else int(size or 0),
                    "mtime": int(mtime or 0),
                    "type": "dir" if is_dir else "file",
                })
            ftp.quit()
            return entries

        entries = await asyncio.to_thread(_do_list)

        reverse = sort_order.lower() == "desc"

        def get_sort_key(item):
            key = (not item["is_dir"],)
            f = sort_by.lower()
            if f == "name":
                key += (item["name"].lower(),)
            elif f == "size":
                key += (item.get("size", 0),)
            elif f == "mtime":
                key += (item.get("mtime", 0),)
            else:
                key += (item["name"].lower(),)
            return key

        entries.sort(key=get_sort_key, reverse=reverse)
        total = len(entries)
        start = (page_num - 1) * page_size
        end = start + page_size
        return entries[start:end], total

    async def read_file(self, root: str, rel: str) -> bytes:
        path = _join_remote(root, rel)

        def _do_read() -> bytes:
            ftp = self._connect()
            try:
                chunks: List[bytes] = []
                ftp.retrbinary("RETR " + path, lambda b: chunks.append(b))
                return b"".join(chunks)
            except error_perm as e:
                if str(e).startswith("550"):
                    raise FileNotFoundError(rel)
                raise
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        return await asyncio.to_thread(_do_read)

    async def write_file(self, root: str, rel: str, data: bytes):
        path = _join_remote(root, rel)

        def _ensure_dirs(ftp: FTP, dir_path: str):
            parts = [p for p in dir_path.strip("/").split("/") if p]
            cur = "/"
            for p in parts:
                cur = _join_remote(cur, p)
                try:
                    ftp.mkd(cur)
                except Exception:
                    pass

        def _do_write():
            ftp = self._connect()
            try:
                parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
                _ensure_dirs(ftp, parent)
                from io import BytesIO
                bio = BytesIO(data)
                ftp.storbinary("STOR " + path, bio)
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        await asyncio.to_thread(_do_write)

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        # KISS: buffer the whole stream, then write it in one shot
        buf = bytearray()
        async for chunk in data_iter:
            if chunk:
                buf.extend(chunk)
        await self.write_file(root, rel, bytes(buf))
        return len(buf)

    async def mkdir(self, root: str, rel: str):
        path = _join_remote(root, rel)

        def _do_mkdir():
            ftp = self._connect()
            try:
                parts = [p for p in path.strip("/").split("/") if p]
                cur = "/"
                for p in parts:
                    cur = _join_remote(cur, p)
                    try:
                        ftp.mkd(cur)
                    except Exception:
                        pass
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        await asyncio.to_thread(_do_mkdir)

    async def delete(self, root: str, rel: str):
        path = _join_remote(root, rel)

        def _do_delete():
            ftp = self._connect()
            try:
                # Try file delete
                try:
                    ftp.delete(path)
                    return
                except Exception:
                    pass

                # Recursively delete dir
                def _rm_tree(dir_path: str):
                    try:
                        ftp.cwd(dir_path)
                    except Exception:
                        return
                    items = []
                    try:
                        for name, facts in ftp.mlsd():
                            if name in (".", ".."):
                                continue
                            items.append((name, facts.get("type") == "dir"))
                    except Exception:
                        try:
                            names = ftp.nlst()
                        except Exception:
                            names = []
                        for n in names:
                            if n in (".", ".."):
                                continue
                            # Best-effort dir check
                            try:
                                ftp.cwd(_join_remote(dir_path, n))
                                ftp.cwd(dir_path)
                                items.append((n, True))
                            except Exception:
                                items.append((n, False))
                    for n, is_dir in items:
                        child = _join_remote(dir_path, n)
                        if is_dir:
                            _rm_tree(child)
                        else:
                            try:
                                ftp.delete(child)
                            except Exception:
                                pass
                    try:
                        ftp.rmd(dir_path)
                    except Exception:
                        pass

                _rm_tree(path)
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        await asyncio.to_thread(_do_delete)

    async def move(self, root: str, src_rel: str, dst_rel: str):
        src = _join_remote(root, src_rel)
        dst = _join_remote(root, dst_rel)

        def _do_move():
            ftp = self._connect()
            try:
                # Ensure dst parent exists
                parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
                parts = [p for p in parent.strip("/").split("/") if p]
                cur = "/"
                for p in parts:
                    cur = _join_remote(cur, p)
                    try:
                        ftp.mkd(cur)
                    except Exception:
                        pass
                ftp.rename(src, dst)
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        await asyncio.to_thread(_do_move)

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        await self.move(root, src_rel, dst_rel)

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        src = _join_remote(root, src_rel)
        dst = _join_remote(root, dst_rel)

        # naive implementation: download then upload; recurse for dirs
        async def _is_dir(path: str) -> bool:
            def _probe() -> bool:
                ftp = self._connect()
                try:
                    try:
                        ftp.cwd(path)
                        return True
                    except Exception:
                        return False
                finally:
                    try:
                        ftp.quit()
                    except Exception:
                        pass
            return await asyncio.to_thread(_probe)

        if await _is_dir(src):
            # list children, create dst dir, copy recursively
            await self.mkdir(root, dst_rel)

            children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
            for ent in children:
                child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
                child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
                await self.copy(root, child_src, child_dst, overwrite)
            return

        # file
        data = await self.read_file(root, src_rel)
        if not overwrite:
            # best-effort existence check
            try:
                await self.stat_file(root, dst_rel)
                raise FileExistsError(dst_rel)
            except FileNotFoundError:
                pass
        await self.write_file(root, dst_rel, data)

    async def stat_file(self, root: str, rel: str):
        path = _join_remote(root, rel)

        def _do_stat():
            ftp = self._connect()
            try:
                # Try MLST
                try:
                    resp: List[str] = []
                    ftp.retrlines("MLST " + path, resp.append)
                    # The last line usually contains facts
                    facts = {}
                    if resp:
                        facts = _parse_mlst_line(resp[-1])
                    name = rel.split("/")[-1]
                    t = facts.get("type") or "file"
                    is_dir = t == "dir"
                    size = int(facts.get("size") or 0)
                    mtime = _parse_modify_to_epoch(facts.get("modify") or "")
                    return {
                        "name": name,
                        "is_dir": is_dir,
                        "size": 0 if is_dir else size,
                        "mtime": mtime,
                        "type": "dir" if is_dir else "file",
                        "path": path,
                    }
                except Exception:
                    pass

                # Probe directory
                try:
                    ftp.cwd(path)
                    return {
                        "name": rel.split("/")[-1],
                        "is_dir": True,
                        "size": 0,
                        "mtime": 0,
                        "type": "dir",
                        "path": path,
                    }
                except Exception:
                    pass

                # Treat as file
                try:
                    size = ftp.size(path) or 0
                except Exception:
                    size = 0
                try:
                    mdtm = ftp.sendcmd("MDTM " + path)
                    mtime = _parse_modify_to_epoch(mdtm.split(" ", 1)[1]) if mdtm.startswith("213 ") else 0
                except Exception:
                    mtime = 0
                return {
                    "name": rel.split("/")[-1],
                    "is_dir": False,
                    "size": int(size or 0),
                    "mtime": int(mtime or 0),
                    "type": "file",
                    "path": path,
                }
            except error_perm as e:
                if str(e).startswith("550"):
                    raise FileNotFoundError(rel)
                raise
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        return await asyncio.to_thread(_do_stat)

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        path = _join_remote(root, rel)

        # Get size (best-effort)
        def _get_size() -> Optional[int]:
            ftp = self._connect()
            try:
                try:
                    return int(ftp.size(path) or 0)
                except Exception:
                    return None
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        total_size = await asyncio.to_thread(_get_size)
        mime, _ = mimetypes.guess_type(rel)
        content_type = mime or "application/octet-stream"

        rng: Optional[_Range] = None
        status = 200
        headers = {"Accept-Ranges": "bytes", "Content-Type": content_type}
        if range_header and range_header.startswith("bytes=") and total_size is not None:
            try:
                s, e = (range_header.removeprefix("bytes=").split("-", 1))
                start = int(s) if s.strip() else 0
                end = int(e) if e.strip() else (total_size - 1)
                if start >= total_size:
                    raise HTTPException(416, detail="Requested Range Not Satisfiable")
                if end >= total_size:
                    end = total_size - 1
                rng = _Range(start, end)
                status = 206
                headers["Content-Range"] = f"bytes {start}-{end}/{total_size}"
                headers["Content-Length"] = str(end - start + 1)
            except ValueError:
                raise HTTPException(400, detail="Invalid Range header")
        elif total_size is not None:
            headers["Content-Length"] = str(total_size)
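        # Worked example (illustrative): for a 1000-byte file, "bytes=100-"
        # yields status 206, Content-Range "bytes 100-999/1000" and
        # Content-Length 900; "bytes=0-99" yields "bytes 0-99/1000" with
        # Content-Length 100; a start at or beyond 1000 raises 416.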
        queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)

        class _Stop(Exception):
            pass

        def _worker():
            ftp = self._connect()
            remaining = None
            if rng is not None:
                remaining = (rng.end - rng.start + 1) if rng.end is not None else None

            def _cb(chunk: bytes):
                nonlocal remaining
                if not chunk:
                    return
                try:
                    if remaining is not None:
                        if len(chunk) > remaining:
                            part = chunk[:remaining]
                            queue.put_nowait(part)
                            remaining = 0
                            raise _Stop()
                        else:
                            queue.put_nowait(chunk)
                            remaining -= len(chunk)
                            if remaining <= 0:
                                raise _Stop()
                    else:
                        queue.put_nowait(chunk)
                except _Stop:
                    raise
                except Exception:
                    # queue full or event loop closed
                    raise _Stop()

            try:
                if rng is not None:
                    ftp.retrbinary("RETR " + path, _cb, rest=rng.start)
                else:
                    ftp.retrbinary("RETR " + path, _cb)
                queue.put_nowait(None)
            except _Stop:
                try:
                    queue.put_nowait(None)
                except Exception:
                    pass
            except error_perm as e:
                try:
                    queue.put_nowait(None)
                except Exception:
                    pass
                if str(e).startswith("550"):
                    pass
            finally:
                try:
                    ftp.quit()
                except Exception:
                    pass

        async def agen():
            worker_fut = asyncio.to_thread(_worker)
            try:
                while True:
                    chunk = await queue.get()
                    if chunk is None:
                        break
                    yield chunk
            finally:
                try:
                    await worker_fut
                except Exception:
                    pass

        return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)


ADAPTER_TYPE = "ftp"

CONFIG_SCHEMA = [
    {"key": "host", "label": "Host", "type": "string", "required": True, "placeholder": "ftp.example.com"},
    {"key": "port", "label": "Port", "type": "number", "required": False, "default": 21},
    {"key": "username", "label": "Username", "type": "string", "required": False},
    {"key": "password", "label": "Password", "type": "password", "required": False},
    {"key": "passive", "label": "Passive mode", "type": "boolean", "required": False, "default": True},
    {"key": "timeout", "label": "Timeout (seconds)", "type": "number", "required": False, "default": 15},
    {"key": "root", "label": "Root path", "type": "string", "required": False, "default": "/"},
]


def ADAPTER_FACTORY(rec: StorageAdapter):
    return FTPAdapter(rec)
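# Illustrative only: a config matching CONFIG_SCHEMA above, with placeholder
# values and a hypothetical `record`.
#
#     record.config = {"host": "ftp.example.com", "port": 21,
#                      "username": "demo", "password": "***",
#                      "passive": True, "timeout": 15, "root": "/pub"}
#     adapter = ADAPTER_FACTORY(record)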
559  domain/adapters/providers/googledrive.py  Normal file
@@ -0,0 +1,559 @@
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Tuple, AsyncIterator
import httpx
from fastapi.responses import StreamingResponse, Response
from fastapi import HTTPException
from models import StorageAdapter

GOOGLE_OAUTH_URL = "https://oauth2.googleapis.com/token"
GOOGLE_DRIVE_API_URL = "https://www.googleapis.com/drive/v3"


class GoogleDriveAdapter:
    """Google Drive storage adapter."""

    def __init__(self, record: StorageAdapter):
        self.record = record
        cfg = record.config
        self.client_id = cfg.get("client_id")
        self.client_secret = cfg.get("client_secret")
        self.refresh_token = cfg.get("refresh_token")
        self.root_folder_id = cfg.get("root_folder_id", "root")
        self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))

        if not all([self.client_id, self.client_secret, self.refresh_token]):
            raise ValueError(
                "Google Drive adapter requires client_id, client_secret, and refresh_token")

        self._access_token: str | None = None
        self._token_expiry: datetime | None = None

    def get_effective_root(self, sub_path: str | None) -> str:
        """
        Resolve the effective root path.
        :param sub_path: Sub-path under the root.
        :return: The full effective path.
        """
        if sub_path:
            return f"{sub_path.strip('/')}".strip()
        return ""

    async def _get_access_token(self) -> str:
        """
        Get, or refresh, the OAuth access token.
        :return: A valid access token.
        """
        if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
            return self._access_token

        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "refresh_token": self.refresh_token,
            "grant_type": "refresh_token",
        }
        async with httpx.AsyncClient(timeout=20.0) as client:
            resp = await client.post(GOOGLE_OAUTH_URL, data=data)
            resp.raise_for_status()
            token_data = resp.json()
            self._access_token = token_data["access_token"]
            self._token_expiry = datetime.now(
                timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300)
            return self._access_token

    async def _request(self, method: str, endpoint: str, **kwargs):
        """
        Send a request to the Google Drive API.
        :param method: HTTP method.
        :param endpoint: API endpoint.
        :param kwargs: Extra request arguments.
        :return: The response object.
        """
        token = await self._get_access_token()
        headers = {"Authorization": f"Bearer {token}"}
        if "headers" in kwargs:
            headers.update(kwargs.pop("headers"))

        url = f"{GOOGLE_DRIVE_API_URL}{endpoint}"
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.request(method, url, headers=headers, **kwargs)
            if resp.status_code == 401:
                self._access_token = None
                token = await self._get_access_token()
                headers["Authorization"] = f"Bearer {token}"
                resp = await client.request(method, url, headers=headers, **kwargs)
            return resp

    async def _get_folder_id_by_path(self, path: str) -> str:
        """
        Resolve a folder ID from a path.
        :param path: The path to resolve.
        :return: The folder ID.
        """
        if not path or path == "/":
            return self.root_folder_id

        parts = [p for p in path.strip("/").split("/") if p]
        current_id = self.root_folder_id

        for part in parts:
            query = f"name='{part}' and '{current_id}' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
            params = {"q": query, "fields": "files(id, name)"}
            resp = await self._request("GET", "/files", params=params)
            resp.raise_for_status()
            data = resp.json()
            files = data.get("files", [])
            if not files:
                raise FileNotFoundError(f"Folder not found: {part}")
            current_id = files[0]["id"]

        return current_id
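    # Resolution sketch (illustrative): "/photos/2024" is walked segment by
    # segment, one files.list query per level; the folder IDs are hypothetical.
    #
    #     root ("root")
    #       q: name='photos' and 'root' in parents ...     -> "id_photos"
    #       q: name='2024' and 'id_photos' in parents ...  -> "id_2024"
    #
    # so _get_folder_id_by_path("/photos/2024") returns "id_2024", raising
    # FileNotFoundError at the first missing segment.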
    async def _get_file_id_by_path(self, path: str) -> str | None:
        """
        Resolve a file ID from a path.
        :param path: The file path.
        :return: The file ID, or None.
        """
        if not path or path == "/":
            return self.root_folder_id

        parts = [p for p in path.strip("/").split("/") if p]
        parent_id = self.root_folder_id

        for i, part in enumerate(parts):
            is_last = i == len(parts) - 1
            mime_filter = "" if is_last else "and mimeType='application/vnd.google-apps.folder'"
            query = f"name='{part}' and '{parent_id}' in parents {mime_filter} and trashed=false"
            params = {"q": query, "fields": "files(id, name)"}
            resp = await self._request("GET", "/files", params=params)
            resp.raise_for_status()
            data = resp.json()
            files = data.get("files", [])
            if not files:
                return None
            parent_id = files[0]["id"]

        return parent_id

    def _format_item(self, item: Dict) -> Dict:
        """
        Normalize an item returned by the Google Drive API.
        :param item: The raw item dict from the API.
        :return: The normalized dict.
        """
        is_dir = item["mimeType"] == "application/vnd.google-apps.folder"
        mtime_str = item.get("modifiedTime", item.get("createdTime", ""))
        try:
            mtime = int(datetime.fromisoformat(mtime_str.replace("Z", "+00:00")).timestamp())
        except Exception:
            mtime = 0

        return {
            "name": item["name"],
            "is_dir": is_dir,
            "size": 0 if is_dir else int(item.get("size", 0)),
            "mtime": mtime,
            "type": "dir" if is_dir else "file",
        }

    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        """
        List directory contents.
        :param root: Root path.
        :param rel: Relative path.
        :param page_num: Page number.
        :param page_size: Page size.
        :param sort_by: Sort field
        :param sort_order: Sort order
        :return: The list of files/directories and the total count.
        """
        try:
            folder_id = await self._get_folder_id_by_path(rel)
        except FileNotFoundError:
            return [], 0

        query = f"'{folder_id}' in parents and trashed=false"
        params = {
            "q": query,
            "fields": "files(id, name, mimeType, size, modifiedTime, createdTime)",
            "pageSize": 1000,
        }

        all_items = []
        page_token = None

        while True:
            if page_token:
                params["pageToken"] = page_token

            resp = await self._request("GET", "/files", params=params)
            if resp.status_code == 404:
                return [], 0
            resp.raise_for_status()

            data = resp.json()
            all_items.extend(data.get("files", []))
            page_token = data.get("nextPageToken")

            if not page_token:
                break

        formatted_items = [self._format_item(item) for item in all_items]

        # Sort
        reverse = sort_order.lower() == "desc"
        def get_sort_key(item):
            key = (not item["is_dir"],)
            sort_field = sort_by.lower()
            if sort_field == "name":
                key += (item["name"].lower(),)
            elif sort_field == "size":
                key += (item["size"],)
            elif sort_field == "mtime":
                key += (item["mtime"],)
            else:
                key += (item["name"].lower(),)
            return key
        formatted_items.sort(key=get_sort_key, reverse=reverse)

        total_count = len(formatted_items)
        start_idx = (page_num - 1) * page_size
        end_idx = start_idx + page_size

        return formatted_items[start_idx:end_idx], total_count

    async def read_file(self, root: str, rel: str) -> bytes:
        """
        Read a file's contents.
        :param root: Root path.
        :param rel: Relative path.
        :return: The file contents as bytes.
        """
        file_id = await self._get_file_id_by_path(rel)
        if not file_id:
            raise FileNotFoundError(rel)

        resp = await self._request("GET", f"/files/{file_id}", params={"alt": "media"})
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()
        return resp.content

    async def write_file(self, root: str, rel: str, data: bytes):
        """
        Write a file.
        :param root: Root path.
        :param rel: Relative path.
        :param data: The file contents as bytes.
        """
        parent_path = "/".join(rel.strip("/").split("/")[:-1])
        file_name = rel.strip("/").split("/")[-1]
        parent_id = await self._get_folder_id_by_path(parent_path)

        # Check whether the file already exists
        existing_id = await self._get_file_id_by_path(rel)

        if existing_id:
            # Update the existing file
            async with httpx.AsyncClient(timeout=60.0) as client:
                token = await self._get_access_token()
                headers = {"Authorization": f"Bearer {token}"}
                url = f"https://www.googleapis.com/upload/drive/v3/files/{existing_id}?uploadType=media"
                resp = await client.patch(url, headers=headers, content=data)
                resp.raise_for_status()
        else:
            # Create a new file
            metadata = {
                "name": file_name,
                "parents": [parent_id]
            }

            async with httpx.AsyncClient(timeout=60.0) as client:
                token = await self._get_access_token()
                headers = {"Authorization": f"Bearer {token}"}

                # Use a multipart upload
                import json
                boundary = "===============boundary==============="
                headers["Content-Type"] = f"multipart/related; boundary={boundary}"

                body = (
                    f"--{boundary}\r\n"
                    f"Content-Type: application/json; charset=UTF-8\r\n\r\n"
                    f"{json.dumps(metadata)}\r\n"
                    f"--{boundary}\r\n"
                    f"Content-Type: application/octet-stream\r\n\r\n"
                ).encode() + data + f"\r\n--{boundary}--".encode()

                url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=multipart"
                resp = await client.post(url, headers=headers, content=body)
                resp.raise_for_status()
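    # Layout sketch (illustrative) of the multipart/related body built above,
    # for metadata {"name": "a.txt", "parents": ["id_parent"]} and data b"hi"
    # (boundary shortened, CRLF line endings implied):
    #
    #     --BOUNDARY
    #     Content-Type: application/json; charset=UTF-8
    #
    #     {"name": "a.txt", "parents": ["id_parent"]}
    #     --BOUNDARY
    #     Content-Type: application/octet-stream
    #
    #     hi
    #     --BOUNDARY--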
    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        """
        Write a file from a stream.
        :param root: Root path.
        :param rel: Relative path.
        :param data_iter: Async iterator over the file contents.
        :return: The file size.
        """
        # Collect all the data first
        chunks = []
        total_size = 0
        async for chunk in data_iter:
            chunks.append(chunk)
            total_size += len(chunk)

        data = b"".join(chunks)
        await self.write_file(root, rel, data)
        return total_size

    async def mkdir(self, root: str, rel: str):
        """
        Create a directory.
        :param root: Root path.
        :param rel: Relative path.
        """
        parent_path = "/".join(rel.strip("/").split("/")[:-1])
        folder_name = rel.strip("/").split("/")[-1]
        parent_id = await self._get_folder_id_by_path(parent_path)

        metadata = {
            "name": folder_name,
            "mimeType": "application/vnd.google-apps.folder",
            "parents": [parent_id]
        }

        resp = await self._request("POST", "/files", json=metadata)
        resp.raise_for_status()

    async def delete(self, root: str, rel: str):
        """
        Delete a file or directory.
        :param root: Root path.
        :param rel: Relative path.
        """
        file_id = await self._get_file_id_by_path(rel)
        if not file_id:
            return

        resp = await self._request("DELETE", f"/files/{file_id}")
        if resp.status_code not in (204, 404):
            resp.raise_for_status()

    async def move(self, root: str, src_rel: str, dst_rel: str):
        """
        Move or rename a file/directory.
        :param root: Root path.
        :param src_rel: Source relative path.
        :param dst_rel: Destination relative path.
        """
        file_id = await self._get_file_id_by_path(src_rel)
        if not file_id:
            raise FileNotFoundError(src_rel)

        # Fetch the current parent folders
        resp = await self._request("GET", f"/files/{file_id}", params={"fields": "parents"})
        resp.raise_for_status()
        current_parents = resp.json().get("parents", [])

        # Resolve the destination parent folder and the new name
        dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
        dst_name = dst_rel.strip("/").split("/")[-1]
        dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)

        # Update the file
        params = {
            "addParents": dst_parent_id,
            "removeParents": ",".join(current_parents) if current_parents else None,
        }
        metadata = {"name": dst_name}

        resp = await self._request("PATCH", f"/files/{file_id}", params=params, json=metadata)
        resp.raise_for_status()

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        """
        Rename a file or directory.
        """
        await self.move(root, src_rel, dst_rel)

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        """
        Copy a file or directory.
        :param root: Root path.
        :param src_rel: Source relative path.
        :param dst_rel: Destination relative path.
        :param overwrite: Whether to overwrite.
        """
        file_id = await self._get_file_id_by_path(src_rel)
        if not file_id:
            raise FileNotFoundError(src_rel)

        dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
        dst_name = dst_rel.strip("/").split("/")[-1]
        dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)

        metadata = {
            "name": dst_name,
            "parents": [dst_parent_id]
        }

        resp = await self._request("POST", f"/files/{file_id}/copy", json=metadata)
        resp.raise_for_status()

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        """
        Stream a file (with range-request support).
        :param root: Root path.
        :param rel: Relative path.
        :param range_header: The HTTP Range header.
        :return: A FastAPI StreamingResponse.
        """
        file_id = await self._get_file_id_by_path(rel)
        if not file_id:
            raise FileNotFoundError(rel)

        # Fetch the file metadata
        resp = await self._request("GET", f"/files/{file_id}", params={"fields": "name, size, mimeType"})
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()
        item_data = resp.json()

        file_size = int(item_data.get("size", 0))
        content_type = item_data.get("mimeType", "application/octet-stream")

        start = 0
        end = file_size - 1
        status = 200
        headers = {
            "Accept-Ranges": "bytes",
            "Content-Type": content_type,
            "Content-Disposition": f"inline; filename=\"{item_data.get('name')}\""
        }

        if range_header and range_header.startswith("bytes="):
            try:
                part = range_header.removeprefix("bytes=")
                s, e = part.split("-", 1)
                if s.strip():
                    start = int(s)
                if e.strip():
                    end = int(e)
                if start >= file_size:
                    raise HTTPException(416, "Requested Range Not Satisfiable")
                if end >= file_size:
                    end = file_size - 1
                status = 206
            except ValueError:
                raise HTTPException(400, "Invalid Range header")

            headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
            headers["Content-Length"] = str(end - start + 1)
        else:
            headers["Content-Length"] = str(file_size)

        async def file_iterator():
            nonlocal start, end
            token = await self._get_access_token()
            async with httpx.AsyncClient(timeout=60.0) as client:
                req_headers = {
                    'Authorization': f'Bearer {token}',
                    'Range': f'bytes={start}-{end}'
                }
                url = f"{GOOGLE_DRIVE_API_URL}/files/{file_id}?alt=media"
                async with client.stream("GET", url, headers=req_headers) as stream_resp:
                    stream_resp.raise_for_status()
                    async for chunk in stream_resp.aiter_bytes():
                        yield chunk

        return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)

    async def stat_file(self, root: str, rel: str):
        """
        Get metadata for a file or directory.
        :param root: Root path.
        :param rel: Relative path.
        :return: The normalized file/directory info.
        """
        file_id = await self._get_file_id_by_path(rel)
        if not file_id:
            raise FileNotFoundError(rel)

        resp = await self._request("GET", f"/files/{file_id}", params={"fields": "id, name, mimeType, size, modifiedTime, createdTime"})
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()
        return self._format_item(resp.json())

    async def get_direct_download_response(self, root: str, rel: str):
        """
        Get a direct-download response (307 redirect).
        :param root: Root path.
        :param rel: Relative path.
        :return: A 307 redirect response, or None.
        """
        if not self.enable_redirect_307:
            return None

        file_id = await self._get_file_id_by_path(rel)
        if not file_id:
            raise FileNotFoundError(rel)

        # Fetch the file's download link
        resp = await self._request("GET", f"/files/{file_id}", params={"fields": "webContentLink"})
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()

        item_data = resp.json()
        download_url = item_data.get("webContentLink")
        if not download_url:
            return None

        return Response(status_code=307, headers={"Location": download_url})

    async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
        """
        Get a file's thumbnail.
        :param root: Root path.
        :param rel: Relative path.
        :param size: Thumbnail size (currently unused; Google Drive decides).
        :return: The thumbnail bytes, or None when unsupported.
        """
        file_id = await self._get_file_id_by_path(rel)
        if not file_id:
            return None

        try:
            resp = await self._request("GET", f"/files/{file_id}", params={"fields": "thumbnailLink"})
            if resp.status_code == 200:
                item_data = resp.json()
                thumbnail_link = item_data.get("thumbnailLink")
                if thumbnail_link:
                    async with httpx.AsyncClient(timeout=30.0) as client:
                        thumb_resp = await client.get(thumbnail_link)
                        thumb_resp.raise_for_status()
                        return thumb_resp.content
            return None
        except Exception:
            return None


ADAPTER_TYPE = "googledrive"

CONFIG_SCHEMA = [
    {"key": "client_id", "label": "Client ID", "type": "string", "required": True},
    {"key": "client_secret", "label": "Client Secret",
     "type": "password", "required": True},
    {"key": "refresh_token", "label": "Refresh Token", "type": "password",
     "required": True, "help_text": "Can be obtained via the Google OAuth 2.0 Playground"},
    {"key": "root_folder_id", "label": "Root Folder ID", "type": "string",
     "required": False, "placeholder": "Defaults to the root directory (root)", "default": "root"},
    {"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
]


def ADAPTER_FACTORY(rec): return GoogleDriveAdapter(rec)
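# Illustrative only: a config matching CONFIG_SCHEMA above, with placeholder
# credentials and a hypothetical `record`.
#
#     record.config = {
#         "client_id": "xxx.apps.googleusercontent.com",
#         "client_secret": "***",
#         "refresh_token": "***",
#         "root_folder_id": "root",
#         "enable_direct_download_307": False,
#     }
#     adapter = ADAPTER_FACTORY(record)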
@@ -1,4 +1,3 @@
from __future__ import annotations
import os
import shutil
import stat
@@ -10,7 +9,6 @@ import mimetypes
from fastapi import HTTPException
from fastapi.responses import StreamingResponse, Response
from models import StorageAdapter
from services.logging import LogService


def _safe_join(root: str, rel: str) -> Path:
@@ -46,25 +44,18 @@ class LocalAdapter:
        return str(Path(root) / sub_path)
        return root

    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        rel = rel.strip('/')
        base = _safe_join(root, rel) if rel else Path(root)
        if not base.exists():
            return [], 0
        if not base.is_dir():
            raise NotADirectoryError(rel)

        # Fetch all names and sort them
        all_names = await asyncio.to_thread(lambda: sorted(os.listdir(base), key=str.lower))
        total_count = len(all_names)

        # Compute the pagination window
        start_idx = (page_num - 1) * page_size
        end_idx = start_idx + page_size
        page_names = all_names[start_idx:end_idx]

        all_names = await asyncio.to_thread(os.listdir, base)

        entries = []
        for name in page_names:
        for name in all_names:
            fp = base / name
            try:
                st = await asyncio.to_thread(fp.stat)
@@ -79,10 +70,35 @@
                "mode": stat.S_IMODE(st.st_mode),
                "type": "dir" if is_dir else "file",
            })

        # Sort
        reverse = sort_order.lower() == "desc"

        # Directories-first ordering
        entries.sort(key=lambda x: (not x["is_dir"], x["name"].lower()))
        return entries, total_count
        def get_sort_key(item):
            # Base sort key: directories first
            key = (not item["is_dir"],)
            sort_field = sort_by.lower()

            if sort_field == "name":
                key += (item["name"].lower(),)
            elif sort_field == "size":
                key += (item["size"],)
            elif sort_field == "mtime":
                key += (item["mtime"],)
            else:  # default: by name
                key += (item["name"].lower(),)
            return key

        entries.sort(key=get_sort_key, reverse=reverse)

        total_count = len(entries)

        # Pagination
        start_idx = (page_num - 1) * page_size
        end_idx = start_idx + page_size
        page_entries = entries[start_idx:end_idx]

        return page_entries, total_count

    async def read_file(self, root: str, rel: str) -> bytes:
        fp = _safe_join(root, rel)
@@ -97,11 +113,6 @@
        await asyncio.to_thread(fp.write_bytes, data)
        if not pre_exists:
            await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
        await LogService.info(
            "adapter:local",
            f"Wrote file to {rel}",
            details={"adapter_id": self.record.id, "path": str(fp), "size": len(data)},
        )

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        fp = _safe_join(root, rel)
@@ -122,21 +133,11 @@
        await asyncio.to_thread(f.close)
        if not pre_exists:
            await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
        await LogService.info(
            "adapter:local",
            f"Wrote file stream to {rel}",
            details={"adapter_id": self.record.id, "path": str(fp), "size": size},
        )
        return size

    async def mkdir(self, root: str, rel: str):
        fp = _safe_join(root, rel)
        await asyncio.to_thread(os.makedirs, fp, mode=DEFAULT_DIR_MODE, exist_ok=True)
        await LogService.info(
            "adapter:local",
            f"Created directory {rel}",
            details={"adapter_id": self.record.id, "path": str(fp)},
        )

    async def delete(self, root: str, rel: str):
        fp = _safe_join(root, rel)
@@ -146,11 +147,6 @@
            await asyncio.to_thread(shutil.rmtree, fp)
        else:
            await asyncio.to_thread(fp.unlink)
        await LogService.info(
            "adapter:local",
            f"Deleted {rel}",
            details={"adapter_id": self.record.id, "path": str(fp)},
        )

    async def stat_path(self, root: str, rel: str):
        """New: return path-status debugging info"""
@@ -185,15 +181,6 @@
            except OSError:
                shutil.move(str(src), str(dst))
        await asyncio.to_thread(_do_move)
        await LogService.info(
            "adapter:local",
            f"Moved {src_rel} to {dst_rel}",
            details={
                "adapter_id": self.record.id,
                "src": str(src),
                "dst": str(dst),
            },
        )

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        src = _safe_join(root, src_rel)
@@ -209,15 +196,6 @@
            except OSError:
                os.replace(src, dst)
        await asyncio.to_thread(_do_rename)
        await LogService.info(
            "adapter:local",
            f"Renamed {src_rel} to {dst_rel}",
            details={
                "adapter_id": self.record.id,
                "src": str(src),
                "dst": str(dst),
            },
        )

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        src = _safe_join(root, src_rel)
@@ -240,15 +218,6 @@
        else:
            shutil.copy2(src, dst)
        await asyncio.to_thread(_do)
        await LogService.info(
            "adapter:local",
            f"Copied {src_rel} to {dst_rel}",
            details={
                "adapter_id": self.record.id,
                "src": str(src),
                "dst": str(dst),
            },
        )

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        fp = _safe_join(root, rel)
@@ -1,8 +1,7 @@
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Tuple, AsyncIterator
import httpx
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, Response
from fastapi import HTTPException
from models import StorageAdapter

@@ -20,6 +19,7 @@ class OneDriveAdapter:
        self.client_secret = cfg.get("client_secret")
        self.refresh_token = cfg.get("refresh_token")
        self.root = cfg.get("root", "/").strip("/")
        self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))

        if not all([self.client_id, self.client_secret, self.refresh_token]):
            raise ValueError(
@@ -63,7 +63,7 @@ class OneDriveAdapter:
            "refresh_token": self.refresh_token,
            "grant_type": "refresh_token",
        }
        async with httpx.AsyncClient() as client:
        async with httpx.AsyncClient(timeout=20.0) as client:
            resp = await client.post(MS_OAUTH_URL, data=data)
            resp.raise_for_status()
            token_data = resp.json()
@@ -90,11 +90,10 @@ class OneDriveAdapter:
            headers.update(kwargs.pop("headers"))

        url = full_url if full_url else f"{MS_GRAPH_URL}/me/drive/root{api_path_segment}"
        async with httpx.AsyncClient() as client:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.request(method, url, headers=headers, **kwargs)
            # If the token has expired (401), refresh it and retry once
            if resp.status_code == 401:
                self._access_token = None  # force a refresh
                self._access_token = None
                token = await self._get_access_token()
                headers["Authorization"] = f"Bearer {token}"
                resp = await client.request(method, url, headers=headers, **kwargs)
@@ -115,25 +114,23 @@ class OneDriveAdapter:
            "type": "dir" if is_dir else "file",
        }

    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        """
        List directory contents.
        Because the Graph API does not support offset-based ($skip) pagination, this method fetches every item,
        :param root: Root path (unused in this adapter; determined by the configured root).
        :param rel: Relative path.
        :param page_num: Page number.
        :param page_size: Page size.
        :param sort_by: Sort field
        :param sort_order: Sort order
        :return: The list of files/directories and the total count.
        """
        api_path = self._get_api_path(rel)
        children_path = f"{api_path}:/children" if api_path else "/children"

        # Graph API pagination is driven by @odata.nextLink tokens.
        # To support custom ordering (folders first) we must fetch every item,
        # then sort and paginate in memory. This version follows the pagination links to fetch all items robustly.
        all_items = []

        # Initial request
        resp = await self._request("GET", api_path_segment=children_path, params={"$top": 200})
        params = {"$top": 999}
        resp = await self._request("GET", api_path_segment=children_path, params=params)

        while True:
            if resp.status_code == 404 and not all_items:
@@ -151,13 +148,25 @@ class OneDriveAdapter:
            if not next_link:
                break

            # Follow-up page request
            resp = await self._request("GET", full_url=next_link)

        formatted_items = [self._format_item(item) for item in all_items]
        # Sort: folders first, then by name
        formatted_items.sort(key=lambda x: (
            not x["is_dir"], x["name"].lower()))

        # Sort
        reverse = sort_order.lower() == "desc"
        def get_sort_key(item):
            key = (not item["is_dir"],)
            sort_field = sort_by.lower()
            if sort_field == "name":
                key += (item["name"].lower(),)
            elif sort_field == "size":
                key += (item["size"],)
            elif sort_field == "mtime":
                key += (item["mtime"],)
            else:
                key += (item["name"].lower(),)
            return key
        formatted_items.sort(key=get_sort_key, reverse=reverse)

        total_count = len(formatted_items)
        start_idx = (page_num - 1) * page_size
@@ -362,7 +371,7 @@ class OneDriveAdapter:

        async def file_iterator():
            nonlocal start, end
            async with httpx.AsyncClient() as client:
            async with httpx.AsyncClient(timeout=60.0) as client:
                req_headers = {'Range': f'bytes={start}-{end}'}
                async with client.stream("GET", download_url, headers=req_headers) as stream_resp:
                    stream_resp.raise_for_status()
@@ -371,6 +380,26 @@ class OneDriveAdapter:

        return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)

    async def get_direct_download_response(self, root: str, rel: str):
        if not self.enable_redirect_307:
            return None

        api_path = self._get_api_path(rel)
        if not api_path:
            raise IsADirectoryError("Cannot issue a direct-link redirect for a directory")

        resp = await self._request("GET", api_path_segment=api_path)
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()

        item_data = resp.json()
        download_url = item_data.get("@microsoft.graph.downloadUrl")
        if not download_url:
            return None

        return Response(status_code=307, headers={"Location": download_url})

    async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
        """
        Fetch a file's thumbnail.
@@ -389,7 +418,7 @@ class OneDriveAdapter:
        resp = await self._request("GET", api_path_segment=thumb_path)
        if resp.status_code == 200:
            thumb_data = resp.json()
            async with httpx.AsyncClient() as client:
            async with httpx.AsyncClient(timeout=30.0) as client:
                thumb_resp = await client.get(thumb_data['url'])
                thumb_resp.raise_for_status()
                return thumb_resp.content
@@ -415,16 +444,17 @@ class OneDriveAdapter:
        return self._format_item(resp.json())


ADAPTER_TYPE = "OneDrive"
ADAPTER_TYPE = "onedrive"

CONFIG_SCHEMA = [
    {"key": "client_id", "label": "Client ID", "type": "string", "required": True},
    {"key": "client_secret", "label": "Client Secret",
     "type": "password", "required": True},
    {"key": "refresh_token", "label": "Refresh Token", "type": "password",
     "required": True, "help_text": "Can be obtained by running 'python -m services.adapters.onedrive'"},
     "required": True, "help_text": "Can be obtained by running 'python -m domain.adapters.providers.onedrive'"},
    {"key": "root", "label": "Root Path", "type": "string",
     "required": False, "placeholder": "Defaults to the root directory /"},
    {"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
]
758  domain/adapters/providers/quark.py  Normal file
@@ -0,0 +1,758 @@
import asyncio
import base64
import hashlib
import mimetypes
import os
import time
from typing import Dict, List, Tuple, Optional, AsyncIterator, Any

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse

from models import StorageAdapter
from .base import BaseAdapter


# Quark standard (UC) endpoints
API_BASE = "https://drive.quark.cn/1/clouddrive"
REFERER = "https://pan.quark.cn"
PR = "ucpro"


class QuarkAdapter:
    """Quark cloud drive (cookie mode).

    - Authenticates with a cookie exported from the browser
    - Implements listing, read/write, chunked upload and basic operations via the Quark/UC clouddrive API
    - The root FID is fixed to "0"; paths are resolved by walking names
    """

    def __init__(self, record: StorageAdapter):
        self.record = record
        cfg = record.config or {}
        self.cookie: str = cfg.get("cookie") or cfg.get("Cookie")
        self.root_fid: str = cfg.get("root_fid", "0")
        def _as_bool(value: Any) -> bool:
            if isinstance(value, bool):
                return value
            if isinstance(value, str):
                return value.strip().lower() in {"1", "true", "yes", "on"}
            return bool(value)

        self.use_transcoding_address: bool = _as_bool(cfg.get("use_transcoding_address", False))
        self.only_list_video_file: bool = _as_bool(cfg.get("only_list_video_file", False))

        if not self.cookie:
            raise ValueError("The Quark adapter requires a cookie")

        # Runtime caches
        self._dir_fid_cache: Dict[str, str] = {f"{self.root_fid}:": self.root_fid}
        self._children_cache: Dict[str, List[Dict[str, Any]]] = {}

        # UA and timeout
        self._ua = (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
            "(KHTML, like Gecko) quark-cloud-drive/2.5.20 Chrome/100.0.4896.160 "
            "Electron/18.3.5.4-b478491100 Safari/537.36 Channel/pckk_other_ch"
        )
        self._timeout = 30.0

    # -----------------
    # Helpers and generic requests
    # -----------------
    def get_effective_root(self, sub_path: str | None) -> str:
        return self.root_fid

async def _request(
|
||||
self,
|
||||
method: str,
|
||||
pathname: str,
|
||||
*,
|
||||
json: Any | None = None,
|
||||
params: Dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
headers = {
|
||||
"Cookie": self._safe_cookie(self.cookie),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Referer": REFERER,
|
||||
"User-Agent": self._ua,
|
||||
}
|
||||
query = {"pr": PR, "fr": "pc"}
|
||||
if params:
|
||||
query.update(params)
|
||||
url = f"{API_BASE}{pathname}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
resp = await client.request(method, url, headers=headers, params=query, json=json)
|
||||
# 更新运行期 cookie(若返回 __puus/__pus)
|
||||
try:
|
||||
for key in ("__puus", "__pus"):
|
||||
v = resp.cookies.get(key)
|
||||
if v:
|
||||
# 简单替换/追加到 self.cookie
|
||||
self._set_cookie_kv(key, v)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 解析业务状态
|
||||
data = None
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
return resp
|
||||
status = data.get("status")
|
||||
code = data.get("code")
|
||||
msg = data.get("message") or ""
|
||||
if (status is not None and status >= 400) or (code is not None and code != 0):
|
||||
raise HTTPException(502, detail=f"Quark error status={status} code={code} msg={msg}")
|
||||
return data
|
||||
|
||||
def _set_cookie_kv(self, key: str, value: str):
|
||||
# 将指定键值写入 self.cookie(粗略字符串处理)
|
||||
parts = [p.strip() for p in (self.cookie or "").replace("\r", "").replace("\n", "").split(";") if p.strip()]
|
||||
found = False
|
||||
for i, p in enumerate(parts):
|
||||
if p.startswith(key + "="):
|
||||
parts[i] = f"{key}={value}"
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
parts.append(f"{key}={value}")
|
||||
self.cookie = "; ".join(parts)
|
||||
|
||||
def _sanitize_cookie(self, cookie: str) -> str:
|
||||
if not cookie:
|
||||
return ""
|
||||
# 去除换行与前后空白
|
||||
cookie = cookie.replace("\r", "").replace("\n", "").strip()
|
||||
# 统一分号分隔并去除多余空格/空段
|
||||
parts = [p.strip() for p in cookie.split(";") if p.strip()]
|
||||
return "; ".join(parts)
|
||||
|
||||
def _safe_cookie(self, cookie: str) -> str:
|
||||
s = self._sanitize_cookie(cookie)
|
||||
# 仅保留可见 ASCII (0x20-0x7E)
|
||||
s = "".join(ch for ch in s if 32 <= ord(ch) <= 126)
|
||||
return s
|
||||
|
||||
# -----------------
|
||||
# 列表与路径解析
|
||||
# -----------------
|
||||
def _map_file_item(self, it: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Quark/UC 列表项:file=true 表示文件;false 表示目录
|
||||
is_dir = not bool(it.get("file", False))
|
||||
updated_at_ms = int(it.get("updated_at", 0) or 0)
|
||||
name = it.get("file_name") or it.get("filename") or it.get("name")
|
||||
return {
|
||||
"fid": it.get("fid"),
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(it.get("size", 0) or 0),
|
||||
"mtime": updated_at_ms // 1000 if updated_at_ms else 0,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def _list_children(self, parent_fid: str) -> List[Dict[str, Any]]:
|
||||
if parent_fid in self._children_cache:
|
||||
return self._children_cache[parent_fid]
|
||||
|
||||
files: List[Dict[str, Any]] = []
|
||||
page = 1
|
||||
size = 100
|
||||
total = None
|
||||
while True:
|
||||
qp = {"pdir_fid": parent_fid, "_size": str(size), "_page": str(page), "_fetch_total": "1"}
|
||||
data = await self._request("GET", "/file/sort", params=qp)
|
||||
d = (data or {}).get("data", {})
|
||||
meta = (data or {}).get("metadata", {})
|
||||
page_files = d.get("list", [])
|
||||
files.extend(page_files)
|
||||
if total is None:
|
||||
total = meta.get("_total") or meta.get("total") or 0
|
||||
if page * size >= int(total):
|
||||
break
|
||||
page += 1
|
||||
|
||||
mapped = [self._map_file_item(x) for x in files if (not self.only_list_video_file) or (not x.get("file")) or (x.get("category") == 1)]
|
||||
self._children_cache[parent_fid] = mapped
|
||||
return mapped
|
||||
|
||||
def _dir_cache_key(self, base_fid: str, rel: str) -> str:
|
||||
return f"{base_fid}:{rel.strip('/')}"
|
||||
|
||||
async def _resolve_dir_fid_from(self, base_fid: str, rel: str) -> str:
|
||||
key = rel.strip("/")
|
||||
cache_key = self._dir_cache_key(base_fid, key)
|
||||
if cache_key in self._dir_fid_cache:
|
||||
return self._dir_fid_cache[cache_key]
|
||||
if key == "":
|
||||
self._dir_fid_cache[cache_key] = base_fid
|
||||
return base_fid
|
||||
|
||||
parent_fid = base_fid
|
||||
path_so_far = []
|
||||
for seg in key.split("/"):
|
||||
if seg == "":
|
||||
continue
|
||||
path_so_far.append(seg)
|
||||
cache_key = self._dir_cache_key(base_fid, "/".join(path_so_far))
|
||||
cached = self._dir_fid_cache.get(cache_key)
|
||||
if cached:
|
||||
parent_fid = cached
|
||||
continue
|
||||
children = await self._list_children(parent_fid)
|
||||
found = next((c for c in children if c["is_dir"] and c["name"] == seg), None)
|
||||
if not found:
|
||||
raise FileNotFoundError(f"Directory not found: {seg}")
|
||||
parent_fid = found["fid"]
|
||||
self._dir_fid_cache[cache_key] = parent_fid
|
||||
|
||||
return parent_fid
|
||||
|
||||
async def _find_child(self, parent_fid: str, name: str) -> Optional[Dict[str, Any]]:
|
||||
children = await self._list_children(parent_fid)
|
||||
for it in children:
|
||||
if it["name"] == name:
|
||||
return it
|
||||
return None
|
||||
|
||||
def _invalidate_children_cache(self, parent_fid: str):
|
||||
if parent_fid in self._children_cache:
|
||||
try:
|
||||
del self._children_cache[parent_fid]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -----------------
|
||||
# 目录与文件列表
|
||||
# -----------------
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
base_fid = root or self.root_fid
|
||||
fid = await self._resolve_dir_fid_from(base_fid, rel)
|
||||
items = await self._list_children(fid)
|
||||
|
||||
# 排序,目录优先
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sf = sort_by.lower()
|
||||
if sf == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sf == "size":
|
||||
key += (item["size"],)
|
||||
elif sf == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
items.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(items)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return items[start:end], total
|
||||
|
||||
# -----------------
|
||||
# 下载与流式下载
|
||||
# -----------------
|
||||
async def _get_download_url(self, fid: str) -> str:
|
||||
data = await self._request("POST", "/file/download", json={"fids": [fid]})
|
||||
arr = (data or {}).get("data", [])
|
||||
if not arr:
|
||||
raise HTTPException(502, detail="No download data returned by Quark")
|
||||
url = arr[0].get("download_url") or arr[0].get("DownloadUrl")
|
||||
if not url:
|
||||
raise HTTPException(502, detail="No download_url returned by Quark")
|
||||
return url
|
||||
|
||||
async def _get_transcoding_url(self, fid: str) -> Optional[str]:
|
||||
try:
|
||||
payload = {"fid": fid, "resolutions": "low,normal,high,super,2k,4k", "supports": "fmp4_av,m3u8,dolby_vision"}
|
||||
data = await self._request("POST", "/file/v2/play/project", json=payload)
|
||||
lst = (data or {}).get("data", {}).get("video_list", [])
|
||||
for item in lst:
|
||||
vi = item.get("video_info") or {}
|
||||
url = vi.get("url")
|
||||
if url:
|
||||
return url
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def get_video_transcoding_url(self, fid: str) -> Optional[str]:
|
||||
if not self.use_transcoding_address:
|
||||
return None
|
||||
return await self._get_transcoding_url(fid)
|
||||
|
||||
def _is_video_name(self, name: str) -> bool:
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
return bool(mime and mime.startswith("video/"))
|
||||
|
||||
def _download_headers(self) -> Dict[str, str]:
|
||||
return {"Cookie": self._safe_cookie(self.cookie), "User-Agent": self._ua, "Referer": REFERER}
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
url = await self._get_download_url(it["fid"])
|
||||
headers = self._download_headers()
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
url = await self._get_download_url(it["fid"])
|
||||
headers = dict(self._download_headers())
|
||||
headers["Range"] = f"bytes={start}-" if end is None else f"bytes={start}-{end}"
|
||||
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
if resp.status_code == 416:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
url = await self._get_download_url(it["fid"])
|
||||
if self.use_transcoding_address and self._is_video_name(name):
|
||||
tr = await self._get_transcoding_url(it["fid"])
|
||||
if tr:
|
||||
url = tr
|
||||
dl_headers = self._download_headers()
|
||||
|
||||
# 预获取大小/是否支持范围
|
||||
total_size: Optional[int] = None
|
||||
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
|
||||
try:
|
||||
head_resp = await client.head(url, headers=dl_headers)
|
||||
if head_resp.status_code == 200:
|
||||
cl = head_resp.headers.get("Content-Length")
|
||||
if cl and cl.isdigit():
|
||||
total_size = int(cl)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
# 解析 Range
|
||||
start = 0
|
||||
end: Optional[int] = None
|
||||
status_code = 200
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
status_code = 206
|
||||
part = range_header.split("=", 1)[1]
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
|
||||
if total_size is not None and end is None and status_code == 206:
|
||||
end = total_size - 1
|
||||
if end is not None and total_size is not None and end >= total_size:
|
||||
end = total_size - 1
|
||||
if total_size is not None and start >= total_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
|
||||
resp_headers: Dict[str, str] = {"Accept-Ranges": "bytes", "Content-Type": content_type}
|
||||
if status_code == 206 and total_size is not None and end is not None:
|
||||
resp_headers["Content-Range"] = f"bytes {start}-{end}/{total_size}"
|
||||
resp_headers["Content-Length"] = str(end - start + 1)
|
||||
elif total_size is not None:
|
||||
resp_headers["Content-Length"] = str(total_size)
|
||||
|
||||
async def iterator():
|
||||
headers = dict(dl_headers)
|
||||
if status_code == 206 and end is not None:
|
||||
headers["Range"] = f"bytes={start}-{end}"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
async with client.stream("GET", url, headers=headers) as resp:
|
||||
if resp.status_code in (404, 416):
|
||||
await resp.aclose()
|
||||
raise HTTPException(resp.status_code, detail="Upstream not available")
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status_code, headers=resp_headers, media_type=content_type)
|
||||
|
||||
# -----------------
|
||||
# 上传(大文件分片)
|
||||
# -----------------
|
||||
@staticmethod
|
||||
def _md5_hex(b: bytes) -> str:
|
||||
return hashlib.md5(b).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _sha1_hex(b: bytes) -> str:
|
||||
return hashlib.sha1(b).hexdigest()
|
||||
|
||||
def _guess_mime(self, name: str) -> str:
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
return mime or "application/octet-stream"
|
||||
|
||||
async def _upload_pre(self, filename: str, size: int, parent_fid: str) -> Dict[str, Any]:
|
||||
now_ms = int(time.time() * 1000)
|
||||
body = {
|
||||
"ccp_hash_update": True,
|
||||
"dir_name": "",
|
||||
"file_name": filename,
|
||||
"format_type": self._guess_mime(filename),
|
||||
"l_created_at": now_ms,
|
||||
"l_updated_at": now_ms,
|
||||
"pdir_fid": parent_fid,
|
||||
"size": size,
|
||||
}
|
||||
data = await self._request("POST", "/file/upload/pre", json=body)
|
||||
return data
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
async def gen():
|
||||
yield data
|
||||
return await self.write_file_stream(root, rel, gen())
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
|
||||
# 将数据落盘到临时文件,同时计算 MD5/SHA1
|
||||
import tempfile
|
||||
|
||||
md5 = hashlib.md5()
|
||||
sha1 = hashlib.sha1()
|
||||
total = 0
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tmp_path = tf.name
|
||||
try:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
total += len(chunk)
|
||||
md5.update(chunk)
|
||||
sha1.update(chunk)
|
||||
tf.write(chunk)
|
||||
finally:
|
||||
tf.flush()
|
||||
|
||||
md5_hex = md5.hexdigest()
|
||||
sha1_hex = sha1.hexdigest()
|
||||
|
||||
# 预上传,拿到上传信息
|
||||
pre_resp = await self._upload_pre(name, total, parent_fid)
|
||||
pre_data = pre_resp.get("data", {})
|
||||
|
||||
# hash 秒传
|
||||
hash_body = {"md5": md5_hex, "sha1": sha1_hex, "task_id": pre_data.get("task_id")}
|
||||
hash_resp = await self._request("POST", "/file/update/hash", json=hash_body)
|
||||
if (hash_resp.get("data") or {}).get("finish") is True:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
# 刷新父目录缓存
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return total
|
||||
|
||||
# 分片上传
|
||||
part_size = int((pre_resp.get("metadata") or {}).get("part_size") or 0)
|
||||
if part_size <= 0:
|
||||
raise HTTPException(502, detail="Invalid part_size from Quark")
|
||||
|
||||
bucket = pre_data.get("bucket")
|
||||
obj_key = pre_data.get("obj_key")
|
||||
upload_id = pre_data.get("upload_id")
|
||||
upload_url = pre_data.get("upload_url")
|
||||
if not (bucket and obj_key and upload_id and upload_url):
|
||||
raise HTTPException(502, detail="Upload pre missing fields")
|
||||
|
||||
# 计算 host 与基础 URL
|
||||
try:
|
||||
upload_host = upload_url.split("://", 1)[1]
|
||||
except Exception:
|
||||
upload_host = upload_url
|
||||
base_url = f"https://{bucket}.{upload_host}/{obj_key}"
|
||||
|
||||
# 分片循环
|
||||
etags: List[str] = []
|
||||
oss_ua = "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
with open(tmp_path, "rb") as rf:
|
||||
part_number = 1
|
||||
left = total
|
||||
while left > 0:
|
||||
sz = min(part_size, left)
|
||||
data_bytes = rf.read(sz)
|
||||
if len(data_bytes) != sz:
|
||||
raise IOError("Failed to read part bytes")
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
# 申请签名
|
||||
auth_meta = (
|
||||
"PUT\n\n"
|
||||
f"{self._guess_mime(name)}\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?partNumber={part_number}&uploadId={upload_id}"
|
||||
)
|
||||
auth_req_body = {"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta, "task_id": pre_data.get("task_id")}
|
||||
auth_resp = await self._request("POST", "/file/upload/auth", json=auth_req_body)
|
||||
auth_key = (auth_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key:
|
||||
raise HTTPException(502, detail="upload/auth missing auth_key")
|
||||
|
||||
put_headers = {
|
||||
"Authorization": auth_key,
|
||||
"Content-Type": self._guess_mime(name),
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
put_url = f"{base_url}?partNumber={part_number}&uploadId={upload_id}"
|
||||
put_resp = await client.put(put_url, headers=put_headers, content=data_bytes)
|
||||
if put_resp.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload part failed status={put_resp.status_code} text={put_resp.text}")
|
||||
etag = put_resp.headers.get("Etag", "")
|
||||
etags.append(etag)
|
||||
left -= sz
|
||||
part_number += 1
|
||||
|
||||
# 组合 commit xml
|
||||
parts_xml = [f"<Part>\n<PartNumber>{i+1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n" for i in range(len(etags))]
|
||||
body_xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n" + "".join(parts_xml) + "</CompleteMultipartUpload>"
|
||||
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
|
||||
callback = pre_data.get("callback") or {}
|
||||
try:
|
||||
import json as _json
|
||||
callback_b64 = base64.b64encode(_json.dumps(callback).encode("utf-8")).decode("ascii")
|
||||
except Exception:
|
||||
callback_b64 = ""
|
||||
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta_commit = (
|
||||
"POST\n"
|
||||
f"{content_md5}\n"
|
||||
"application/xml\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-callback:{callback_b64}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?uploadId={upload_id}"
|
||||
)
|
||||
auth_commit_resp = await self._request("POST", "/file/upload/auth", json={"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta_commit, "task_id": pre_data.get("task_id")})
|
||||
auth_key_commit = (auth_commit_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key_commit:
|
||||
raise HTTPException(502, detail="upload/auth(commit) missing auth_key")
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
commit_headers = {
|
||||
"Authorization": auth_key_commit,
|
||||
"Content-MD5": content_md5,
|
||||
"Content-Type": "application/xml",
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-callback": callback_b64,
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
commit_url = f"{base_url}?uploadId={upload_id}"
|
||||
r = await client.post(commit_url, headers=commit_headers, content=body_xml.encode("utf-8"))
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload commit failed status={r.status_code} text={r.text}")
|
||||
|
||||
# finish
|
||||
await self._request("POST", "/file/upload/finish", json={"obj_key": obj_key, "task_id": pre_data.get("task_id")})
|
||||
# 端合并存在轻微延迟,等待再刷新缓存
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
# 失效父目录缓存,确保后续列表可见
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return total
|
||||
|
||||
# -----------------
|
||||
# 基本文件操作
|
||||
# -----------------
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
if not rel or rel == "/":
|
||||
raise HTTPException(400, detail="Cannot create root")
|
||||
parent = rel.rstrip("/")
|
||||
parent_rel, name = (parent.rsplit("/", 1) if "/" in parent else ("", parent))
|
||||
if not name:
|
||||
raise HTTPException(400, detail="Invalid directory name")
|
||||
pdir = await self._resolve_dir_fid_from(root or self.root_fid, parent_rel)
|
||||
await self._request("POST", "/file", json={"dir_init_lock": False, "dir_path": "", "file_name": name, "pdir_fid": pdir})
|
||||
self._invalidate_children_cache(pdir)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
# 解析对象 fid + 父目录,用于失效缓存
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
parent_rel = rel.rstrip("/")
|
||||
target_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
parent_of_target = await self._resolve_dir_fid_from(base_fid, (parent_rel.rsplit("/", 1)[0] if "/" in parent_rel else ""))
|
||||
else:
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_of_target = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_of_target, name)
|
||||
if not it:
|
||||
return
|
||||
target_fid = it["fid"]
|
||||
await self._request("POST", "/file/delete", json={"action_type": 1, "exclude_fids": [], "filelist": [target_fid]})
|
||||
self._invalidate_children_cache(parent_of_target)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
# 支持跨目录与重命名:先移动到父目录,后重命名(若需要)
|
||||
src_parent_rel, src_name = (src_rel.rsplit("/", 1) if "/" in src_rel else ("", src_rel))
|
||||
dst_parent_rel, dst_name = (dst_rel.rsplit("/", 1) if "/" in dst_rel else ("", dst_rel))
|
||||
|
||||
base_fid = root or self.root_fid
|
||||
src_parent_fid = await self._resolve_dir_fid_from(base_fid, src_parent_rel)
|
||||
obj = await self._find_child(src_parent_fid, src_name)
|
||||
if not obj:
|
||||
raise FileNotFoundError(src_rel)
|
||||
dst_parent_fid = await self._resolve_dir_fid_from(base_fid, dst_parent_rel)
|
||||
|
||||
if src_parent_fid != dst_parent_fid:
|
||||
await self._request("POST", "/file/move", json={"action_type": 1, "exclude_fids": [], "filelist": [obj["fid"]], "to_pdir_fid": dst_parent_fid})
|
||||
self._invalidate_children_cache(src_parent_fid)
|
||||
self._invalidate_children_cache(dst_parent_fid)
|
||||
|
||||
if obj["name"] != dst_name:
|
||||
await self._request("POST", "/file/rename", json={"fid": obj["fid"], "file_name": dst_name})
|
||||
self._invalidate_children_cache(dst_parent_fid)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_parent_rel, src_name = (src_rel.rsplit("/", 1) if "/" in src_rel else ("", src_rel))
|
||||
base_fid = root or self.root_fid
|
||||
src_parent_fid = await self._resolve_dir_fid_from(base_fid, src_parent_rel)
|
||||
obj = await self._find_child(src_parent_fid, src_name)
|
||||
if not obj:
|
||||
raise FileNotFoundError(src_rel)
|
||||
dst_name = dst_rel.rsplit("/", 1)[-1]
|
||||
await self._request("POST", "/file/rename", json={"fid": obj["fid"], "file_name": dst_name})
|
||||
self._invalidate_children_cache(src_parent_fid)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
raise NotImplementedError("QuarkOpen does not support copy via open API")
|
||||
|
||||
# -----------------
|
||||
# STAT / EXISTS / 辅助
|
||||
# -----------------
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
# 通过父目录列表获取元数据
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
# 目录
|
||||
fid = await self._resolve_dir_fid_from(base_fid, rel.rstrip("/"))
|
||||
return {"name": rel.rstrip("/").split("/")[-1] if rel else "", "is_dir": True, "size": 0, "mtime": 0, "type": "dir", "fid": fid}
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it:
|
||||
raise FileNotFoundError(rel)
|
||||
return it
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
await self._resolve_dir_fid_from(base_fid, rel.rstrip("/"))
|
||||
return True
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
return it is not None
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
# 用于 move/copy 前的预检查调试
|
||||
try:
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
fid = await self._resolve_dir_fid_from(base_fid, rel.rstrip("/"))
|
||||
return {"exists": True, "is_dir": True, "path": rel, "fid": fid}
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if it:
|
||||
return {"exists": True, "is_dir": it["is_dir"], "path": rel, "fid": it["fid"]}
|
||||
return {"exists": False, "is_dir": None, "path": rel}
|
||||
except FileNotFoundError:
|
||||
return {"exists": False, "is_dir": None, "path": rel}
|
||||
|
||||
async def _resolve_target_fid(self, rel: str, *, base_fid: Optional[str] = None) -> str:
|
||||
base = base_fid or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
return await self._resolve_dir_fid_from(base, rel.rstrip("/"))
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it:
|
||||
raise FileNotFoundError(rel)
|
||||
return it["fid"]
|
||||
|
||||
|
||||
ADAPTER_TYPE = "quark"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "cookie", "label": "Cookie", "type": "password", "required": True, "placeholder": "从 pan.quark.cn 复制"},
|
||||
{"key": "root_fid", "label": "根 FID", "type": "string", "required": False, "default": "0"},
|
||||
{"key": "use_transcoding_address", "label": "视频转码直链", "type": "boolean", "required": False, "default": False},
|
||||
{"key": "only_list_video_file", "label": "仅列出视频文件", "type": "boolean", "required": False, "default": False},
|
||||
]
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter) -> BaseAdapter:
|
||||
return QuarkAdapter(rec)
|
||||
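The commit step in `write_file_stream` signs a `CompleteMultipartUpload` body whose `Content-MD5` must match the XML bytes exactly, or OSS rejects the request. A standalone sketch of just that body construction, with made-up ETag values, which can be handy when debugging signature mismatches:

```python
import base64
import hashlib

def build_commit_body(etags: list[str]) -> tuple[bytes, str]:
    # Same layout as the adapter: one <Part> per uploaded chunk, 1-based numbers.
    parts = "".join(
        f"<Part>\n<PartNumber>{i + 1}</PartNumber>\n<ETag>{etag}</ETag>\n</Part>\n"
        for i, etag in enumerate(etags)
    )
    xml = ('<?xml version="1.0" encoding="UTF-8"?>\n'
           "<CompleteMultipartUpload>\n" + parts + "</CompleteMultipartUpload>")
    body = xml.encode("utf-8")
    # OSS validates Content-MD5 against the exact body bytes, base64-encoded.
    content_md5 = base64.b64encode(hashlib.md5(body).digest()).decode("ascii")
    return body, content_md5

body, md5 = build_commit_body(['"etag-1"', '"etag-2"'])  # made-up ETags
print(md5)
```

Any whitespace or encoding drift between the signed `auth_meta` string and the posted body invalidates the signature, which is why the adapter builds the XML once and reuses the same bytes for both.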
@@ -1,4 +1,3 @@
from __future__ import annotations
import asyncio
import mimetypes
from datetime import datetime
@@ -10,7 +9,6 @@ from botocore.exceptions import ClientError
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from models import StorageAdapter
from services.logging import LogService


class S3Adapter:
@@ -52,7 +50,7 @@ class S3Adapter:
    def _get_client(self):
        return self.session.client("s3", endpoint_url=self.endpoint_url)

-    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
+    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        prefix = self._get_s3_key(rel)
        if prefix and not prefix.endswith("/"):
            prefix += "/"
@@ -91,7 +89,21 @@ class S3Adapter:
            })

        # Sort and paginate in memory
-        all_items.sort(key=lambda x: (not x["is_dir"], x["name"].lower()))
+        reverse = sort_order.lower() == "desc"
+        def get_sort_key(item):
+            key = (not item["is_dir"],)
+            sort_field = sort_by.lower()
+            if sort_field == "name":
+                key += (item["name"].lower(),)
+            elif sort_field == "size":
+                key += (item["size"],)
+            elif sort_field == "mtime":
+                key += (item["mtime"],)
+            else:
+                key += (item["name"].lower(),)
+            return key
+        all_items.sort(key=get_sort_key, reverse=reverse)

        total_count = len(all_items)
        start_idx = (page_num - 1) * page_size
        end_idx = start_idx + page_size
@@ -113,11 +125,6 @@ class S3Adapter:
        key = self._get_s3_key(rel)
        async with self._get_client() as s3:
            await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)
-        await LogService.info(
-            "adapter:s3", f"Wrote file to {rel}",
-            details={"adapter_id": self.record.id,
-                     "bucket": self.bucket_name, "key": key, "size": len(data)}
-        )

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        key = self._get_s3_key(rel)
@@ -179,10 +186,6 @@ class S3Adapter:
            )
            raise IOError(f"S3 stream upload failed: {e}") from e

-        await LogService.info(
-            "adapter:s3", f"Wrote file stream to {rel}",
-            details={"adapter_id": self.record.id, "bucket": self.bucket_name, "key": key, "size": total_size}
-        )
        return total_size

    async def mkdir(self, root: str, rel: str):
@@ -191,11 +194,6 @@ class S3Adapter:
            key += "/"
        async with self._get_client() as s3:
            await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"")
-        await LogService.info(
-            "adapter:s3", f"Created directory {rel}",
-            details={"adapter_id": self.record.id,
-                     "bucket": self.bucket_name, "key": key}
-        )

    async def delete(self, root: str, rel: str):
        key = self._get_s3_key(rel)
@@ -223,20 +221,9 @@ class S3Adapter:
            else:
                await s3.delete_object(Bucket=self.bucket_name, Key=key)

-        await LogService.info(
-            "adapter:s3", f"Deleted {rel}",
-            details={"adapter_id": self.record.id,
-                     "bucket": self.bucket_name, "key": key}
-        )

    async def move(self, root: str, src_rel: str, dst_rel: str):
        await self.copy(root, src_rel, dst_rel, overwrite=True)
        await self.delete(root, src_rel)
-        await LogService.info(
-            "adapter:s3", f"Moved {src_rel} to {dst_rel}",
-            details={"adapter_id": self.record.id, "bucket": self.bucket_name,
-                     "src_key": self._get_s3_key(src_rel), "dst_key": self._get_s3_key(dst_rel)}
-        )

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        await self.move(root, src_rel, dst_rel)
@@ -256,11 +243,6 @@ class S3Adapter:

            copy_source = {"Bucket": self.bucket_name, "Key": src_key}
            await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key)
-            await LogService.info(
-                "adapter:s3", f"Copied {src_rel} to {dst_rel}",
-                details={"adapter_id": self.record.id, "bucket": self.bucket_name,
-                         "src_key": src_key, "dst_key": dst_key}
-            )

    async def stat_file(self, root: str, rel: str):
        key = self._get_s3_key(rel)
@@ -339,13 +321,12 @@ class S3Adapter:
                    while chunk := await body.read(65536):
                        yield chunk
            except Exception as e:
                LogService.error(
                    "adapter:s3", f"Error streaming file {key}: {e}")
                raise

        return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)


-ADAPTER_TYPE = "S3"
+ADAPTER_TYPE = "s3"

CONFIG_SCHEMA = [
    {"key": "bucket_name", "label": "Bucket 名称",
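The same `get_sort_key` pattern now appears in the Quark, S3, SFTP, Telegram and WebDAV adapters: the first tuple element pins directories before files, the second carries the requested field. One behavior worth knowing is that `reverse=True` negates the whole tuple comparison, so the directories-first grouping flips along with the field order. A tiny self-contained illustration:

```python
items = [
    {"name": "b.txt", "is_dir": False, "size": 10, "mtime": 100},
    {"name": "docs", "is_dir": True, "size": 0, "mtime": 50},
    {"name": "a.txt", "is_dir": False, "size": 5, "mtime": 200},
]

def get_sort_key(item, sort_by="name"):
    key = (not item["is_dir"],)  # False (0) for dirs -> dirs sort first
    key += (item[sort_by] if sort_by in ("size", "mtime") else item["name"].lower(),)
    return key

items.sort(key=lambda it: get_sort_key(it, "name"))
print([it["name"] for it in items])  # ['docs', 'a.txt', 'b.txt']

items.sort(key=lambda it: get_sort_key(it, "name"), reverse=True)
print([it["name"] for it in items])  # ['b.txt', 'a.txt', 'docs'] - dirs now last
```

Whether directories should stay first under descending sorts is a product decision; keeping dirs pinned would require `reverse` to apply only to the second tuple element.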
438  domain/adapters/providers/sftp.py  Normal file
@@ -0,0 +1,438 @@
import asyncio
import mimetypes
import stat as statmod
from typing import List, Dict, Tuple, AsyncIterator, Optional

from fastapi import HTTPException
from fastapi.responses import StreamingResponse
import paramiko

from models import StorageAdapter


def _join_remote(root: str, rel: str) -> str:
    root = (root or "/").rstrip("/") or "/"
    rel = (rel or "").lstrip("/")
    if not rel:
        return root
    return f"{root}/{rel}"


class SFTPAdapter:
    def __init__(self, record: StorageAdapter):
        self.record = record
        cfg = record.config
        self.host: str = cfg.get("host")
        self.port: int = int(cfg.get("port", 22))
        self.username: str | None = cfg.get("username")
        self.password: str | None = cfg.get("password")
        self.timeout: int = int(cfg.get("timeout", 15))
        self.root_path: str = cfg.get("root")  # required
        self.allow_unknown_host: bool = bool(cfg.get("allow_unknown_host", True))

        if not self.host:
            raise ValueError("SFTP adapter requires 'host'")
        if not self.username or not self.password:
            raise ValueError("SFTP adapter requires 'username' and 'password'")
        if not self.root_path:
            raise ValueError("SFTP adapter requires 'root'")

    def get_effective_root(self, sub_path: str | None) -> str:
        base = self.root_path.rstrip("/") or "/"
        if sub_path:
            return _join_remote(base, sub_path)
        return base

    def _connect(self) -> paramiko.SFTPClient:
        ssh = paramiko.SSHClient()
        if self.allow_unknown_host:
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(
            hostname=self.host,
            port=self.port,
            username=self.username,
            password=self.password,
            timeout=self.timeout,
            allow_agent=False,
            look_for_keys=False,
        )
        return ssh.open_sftp()

    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        path = _join_remote(root, rel)

        def _do_list() -> List[Dict]:
            sftp = self._connect()
            try:
                attrs = sftp.listdir_attr(path)
                entries: List[Dict] = []
                for a in attrs:
                    name = a.filename
                    is_dir = statmod.S_ISDIR(a.st_mode)
                    entries.append({
                        "name": name,
                        "is_dir": is_dir,
                        "size": 0 if is_dir else int(a.st_size or 0),
                        "mtime": int(a.st_mtime or 0),
                        "type": "dir" if is_dir else "file",
                    })
                return entries
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        entries = await asyncio.to_thread(_do_list)

        reverse = sort_order.lower() == "desc"

        def get_sort_key(item):
            key = (not item["is_dir"],)
            f = sort_by.lower()
            if f == "name":
                key += (item["name"].lower(),)
            elif f == "size":
                key += (item.get("size", 0),)
            elif f == "mtime":
                key += (item.get("mtime", 0),)
            else:
                key += (item["name"].lower(),)
            return key

        entries.sort(key=get_sort_key, reverse=reverse)
        total = len(entries)
        start = (page_num - 1) * page_size
        end = start + page_size
        return entries[start:end], total

    async def read_file(self, root: str, rel: str) -> bytes:
        path = _join_remote(root, rel)

        def _do_read() -> bytes:
            sftp = self._connect()
            try:
                with sftp.open(path, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                raise
            except IOError as e:
                if getattr(e, "errno", None) == 2:
                    raise FileNotFoundError(rel)
                raise
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        return await asyncio.to_thread(_do_read)

    async def write_file(self, root: str, rel: str, data: bytes):
        path = _join_remote(root, rel)

        def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
            parts = [p for p in dir_path.strip("/").split("/") if p]
            cur = "/"
            for p in parts:
                cur = _join_remote(cur, p)
                try:
                    sftp.mkdir(cur)
                except IOError:
                    # likely exists
                    pass

        def _do_write():
            sftp = self._connect()
            try:
                parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
                _ensure_dirs(sftp, parent)
                with sftp.open(path, "wb") as f:
                    f.write(data)
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        await asyncio.to_thread(_do_write)

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        buf = bytearray()
        async for chunk in data_iter:
            if chunk:
                buf.extend(chunk)
        await self.write_file(root, rel, bytes(buf))
        return len(buf)

    async def mkdir(self, root: str, rel: str):
        path = _join_remote(root, rel)

        def _do_mkdir():
            sftp = self._connect()
            try:
                parts = [p for p in path.strip("/").split("/") if p]
                cur = "/"
                for p in parts:
                    cur = _join_remote(cur, p)
                    try:
                        sftp.mkdir(cur)
                    except IOError:
                        pass
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        await asyncio.to_thread(_do_mkdir)

    async def delete(self, root: str, rel: str):
        path = _join_remote(root, rel)

        def _do_delete():
            sftp = self._connect()
            try:
                # Try a plain file remove first
                try:
                    sftp.remove(path)
                    return
                except IOError:
                    pass

                def _rm_tree(dp: str):
                    try:
                        for a in sftp.listdir_attr(dp):
                            child = _join_remote(dp, a.filename)
                            if statmod.S_ISDIR(a.st_mode):
                                _rm_tree(child)
                            else:
                                try:
                                    sftp.remove(child)
                                except Exception:
                                    pass
                        sftp.rmdir(dp)
                    except IOError:
                        pass

                _rm_tree(path)
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        await asyncio.to_thread(_do_delete)

    async def move(self, root: str, src_rel: str, dst_rel: str):
        src = _join_remote(root, src_rel)
        dst = _join_remote(root, dst_rel)

        def _do_move():
            sftp = self._connect()
            try:
                # ensure dst parent exists
                parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
                parts = [p for p in parent.strip("/").split("/") if p]
                cur = "/"
                for p in parts:
                    cur = _join_remote(cur, p)
                    try:
                        sftp.mkdir(cur)
                    except IOError:
                        pass
                sftp.rename(src, dst)
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        await asyncio.to_thread(_do_move)

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        await self.move(root, src_rel, dst_rel)

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        src = _join_remote(root, src_rel)
        dst = _join_remote(root, dst_rel)

        def _is_dir() -> bool:
            sftp = self._connect()
            try:
                st = sftp.stat(src)
                return statmod.S_ISDIR(st.st_mode)
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        if await asyncio.to_thread(_is_dir):
            await self.mkdir(root, dst_rel)

            children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
            for ent in children:
                child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
                child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
                await self.copy(root, child_src, child_dst, overwrite)
            return

        # file copy
        data = await self.read_file(root, src_rel)
        if not overwrite:
            try:
                await self.stat_file(root, dst_rel)
                raise FileExistsError(dst_rel)
            except FileNotFoundError:
                pass
        await self.write_file(root, dst_rel, data)

    async def stat_file(self, root: str, rel: str):
        path = _join_remote(root, rel)

        def _do_stat():
            sftp = self._connect()
            try:
                st = sftp.stat(path)
                is_dir = statmod.S_ISDIR(st.st_mode)
                info = {
                    "name": rel.split("/")[-1],
                    "is_dir": is_dir,
                    "size": 0 if is_dir else int(st.st_size or 0),
                    "mtime": int(st.st_mtime or 0),
                    "type": "dir" if is_dir else "file",
                    "path": path,
                }
                return info
            except FileNotFoundError:
                raise
            except IOError as e:
                if getattr(e, "errno", None) == 2:
                    raise FileNotFoundError(rel)
                raise
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        return await asyncio.to_thread(_do_stat)

    async def exists(self, root: str, rel: str) -> bool:
        try:
            await self.stat_file(root, rel)
            return True
        except FileNotFoundError:
            return False
        except Exception:
            return False

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        path = _join_remote(root, rel)

        def _get_stat():
            sftp = self._connect()
            try:
                st = sftp.stat(path)
                return int(st.st_size or 0)
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        file_size = await asyncio.to_thread(_get_stat)
        if file_size is None:
            raise HTTPException(404, detail="File not found")

        mime, _ = mimetypes.guess_type(rel)
        content_type = mime or "application/octet-stream"

        start = 0
        end = file_size - 1
        status = 200
        headers = {
            "Accept-Ranges": "bytes",
            "Content-Type": content_type,
            "Content-Length": str(file_size),
        }

        if range_header and range_header.startswith("bytes="):
            try:
                s, e = (range_header.removeprefix("bytes=").split("-", 1))
                if s.strip():
                    start = int(s)
                if e.strip():
                    end = int(e)
                if start >= file_size:
                    raise HTTPException(416, detail="Requested Range Not Satisfiable")
                if end >= file_size:
                    end = file_size - 1
                status = 206
                headers["Content-Length"] = str(end - start + 1)
                headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
            except ValueError:
                raise HTTPException(400, detail="Invalid Range header")

        queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)
        loop = asyncio.get_running_loop()

        def _worker():
            sftp = self._connect()
            try:
                with sftp.open(path, "rb") as f:
                    f.seek(start)
                    remaining = end - start + 1
                    chunk_size = 64 * 1024
                    while remaining > 0:
                        to_read = chunk_size if remaining > chunk_size else remaining
                        data = f.read(to_read)
                        if not data:
                            break
                        # asyncio.Queue is not thread-safe: hand the chunk to the
                        # event loop and block here until it has been enqueued
                        try:
                            asyncio.run_coroutine_threadsafe(queue.put(data), loop).result(timeout=60)
                        except Exception:
                            break
                        remaining -= len(data)
                try:
                    asyncio.run_coroutine_threadsafe(queue.put(None), loop).result(timeout=60)
                except Exception:
                    pass
            finally:
                try:
                    sftp.close()
                except Exception:
                    pass

        async def agen():
            # create_task schedules the worker thread immediately; a bare
            # asyncio.to_thread(...) coroutine would not run until awaited
            worker_task = asyncio.create_task(asyncio.to_thread(_worker))
            try:
                while True:
                    chunk = await queue.get()
                    if chunk is None:
                        break
                    yield chunk
            finally:
                try:
                    await worker_task
                except Exception:
                    pass

        return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)


ADAPTER_TYPE = "sftp"

CONFIG_SCHEMA = [
    {"key": "host", "label": "主机", "type": "string", "required": True, "placeholder": "sftp.example.com"},
    {"key": "port", "label": "端口", "type": "number", "required": False, "default": 22},
    {"key": "username", "label": "用户名", "type": "string", "required": True},
    {"key": "password", "label": "密码", "type": "password", "required": True},
    {"key": "root", "label": "根路径", "type": "string", "required": True, "placeholder": "/data"},
    {"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 15},
    {"key": "allow_unknown_host", "label": "允许未知主机指纹", "type": "boolean", "required": False, "default": True},
]


def ADAPTER_FACTORY(rec: StorageAdapter):
    return SFTPAdapter(rec)
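The SFTP `stream_file` above bridges a blocking reader thread into an async generator through a queue. The reconstruction uses `asyncio.run_coroutine_threadsafe` for the hand-off and `asyncio.create_task` to actually start the worker, because `asyncio.Queue` is not thread-safe on its own and a bare `asyncio.to_thread(...)` coroutine does nothing until awaited. A minimal standalone sketch of the pattern with a fake reader, independent of paramiko:

```python
import asyncio
from typing import AsyncIterator, Optional

async def stream_blocking_source() -> AsyncIterator[bytes]:
    loop = asyncio.get_running_loop()
    queue: "asyncio.Queue[Optional[bytes]]" = asyncio.Queue(maxsize=8)

    def worker() -> None:
        # Stand-in for a blocking read loop (e.g. paramiko file reads).
        for i in range(5):
            chunk = f"chunk-{i}".encode()
            # Hand each chunk to the event loop thread-safely; .result()
            # blocks the thread when the queue is full (backpressure).
            asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result()
        asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()  # sentinel

    # create_task() starts the thread now; awaiting later just joins it.
    task = asyncio.create_task(asyncio.to_thread(worker))
    try:
        while (chunk := await queue.get()) is not None:
            yield chunk
    finally:
        await task

async def main() -> None:
    async for c in stream_blocking_source():
        print(c)

asyncio.run(main())
```

The bounded queue caps how far the reader can run ahead of the HTTP response, which keeps memory flat on large files.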
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import io
|
||||
import os
|
||||
from models import StorageAdapter
|
||||
from telethon import TelegramClient
|
||||
from telethon.sessions import StringSession
|
||||
import socks
|
||||
|
||||
# 适配器类型标识
|
||||
ADAPTER_TYPE = "Telegram"
|
||||
ADAPTER_TYPE = "telegram"
|
||||
|
||||
# 适配器配置项定义
|
||||
CONFIG_SCHEMA = [
|
||||
@@ -20,7 +21,7 @@ CONFIG_SCHEMA = [
|
||||
]
|
||||
|
||||
class TelegramAdapter:
|
||||
"""Telegram 存储适配器 (只读, 使用用户 Session)"""
|
||||
"""Telegram 存储适配器 (使用用户 Session)"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
@@ -60,7 +61,7 @@ class TelegramAdapter:
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
return ""
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
if rel:
|
||||
return [], 0
|
||||
|
||||
@@ -68,37 +69,72 @@ class TelegramAdapter:
|
||||
entries = []
|
||||
try:
|
||||
await client.connect()
|
||||
messages = await client.get_messages(self.chat_id, limit=50)
|
||||
messages = await client.get_messages(self.chat_id, limit=200)
|
||||
for message in messages:
|
||||
if message and (message.document or message.video):
|
||||
media = message.document or message.video
|
||||
filename = None
|
||||
if hasattr(media, 'attributes'):
|
||||
for attr in media.attributes:
|
||||
if hasattr(attr, 'file_name') and attr.file_name:
|
||||
filename = attr.file_name
|
||||
break
|
||||
|
||||
if not filename:
|
||||
if message.text and '.' in message.text:
|
||||
if len(message.text) < 256 and '\n' not in message.text:
|
||||
filename = message.text
|
||||
if not message:
|
||||
continue
|
||||
|
||||
if not filename:
|
||||
filename = "Unknown"
|
||||
|
||||
entries.append({
|
||||
"name": f"{message.id}_{filename}",
|
||||
"is_dir": False,
|
||||
"size": media.size,
|
||||
"mtime": int(message.date.timestamp()),
|
||||
"type": "file",
|
||||
})
|
||||
media = message.document or message.video or message.photo
|
||||
if not media:
|
||||
continue
|
||||
|
||||
file_meta = message.file
|
||||
if not file_meta:
|
||||
continue
|
||||
|
||||
filename = file_meta.name
|
||||
if not filename:
|
||||
if message.text and '.' in message.text and len(message.text) < 256 and '\n' not in message.text:
|
||||
filename = message.text
|
||||
else:
|
||||
filename = f"unknown_{message.id}"
|
||||
|
||||
size = file_meta.size
|
||||
if size is None:
|
||||
# 兼容缺失 size 的情况
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
size = 0
|
||||
|
||||
entries.append({
|
||||
"name": f"{message.id}_{filename}",
|
||||
"is_dir": False,
|
||||
"size": size,
|
||||
"mtime": int(message.date.timestamp()),
|
||||
"type": "file",
|
||||
})
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
return entries, len(entries)
|
||||
# 排序
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(entries)
|
||||
|
||||
# 分页
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
page_entries = entries[start_idx:end_idx]
|
||||
|
||||
return page_entries, total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
try:
|
||||
@@ -111,7 +147,7 @@ class TelegramAdapter:
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
if not message or not (message.document or message.video):
|
||||
if not message or not (message.document or message.video or message.photo):
|
||||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||||
|
||||
file_bytes = await client.download_media(message, file=bytes)
|
||||
@@ -121,25 +157,73 @@ class TelegramAdapter:
|
||||
await client.disconnect()
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
raise NotImplementedError("Telegram 适配器是只读的,不支持写入文件。")
|
||||
"""将字节数据作为文件上传"""
|
||||
client = self._get_client()
|
||||
file_like = io.BytesIO(data)
|
||||
file_like.name = os.path.basename(rel) or "file"
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
raise NotImplementedError("Telegram 适配器是只读的,不支持流式写入文件。")
|
||||
"""以流式方式上传文件"""
|
||||
client = self._get_client()
|
||||
filename = os.path.basename(rel) or "file"
|
||||
import tempfile
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_path = os.path.join(temp_dir, filename)
|
||||
|
||||
total_size = 0
|
||||
try:
|
||||
with open(temp_path, "wb") as f:
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
await client.connect()
|
||||
await client.send_file(self.chat_id, temp_path, caption=filename)
|
||||
|
||||
finally:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
return total_size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
raise NotImplementedError("Telegram 适配器是只读的,不支持创建目录。")
|
||||
raise NotImplementedError("Telegram 适配器不支持创建目录。")
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
raise NotImplementedError("Telegram 适配器是只读的,不支持删除。")
|
||||
"""删除一个文件 (即一条消息)"""
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
            message_id = int(message_id_str)
        except (ValueError, IndexError):
            raise FileNotFoundError(f"Invalid file path format; unable to parse message ID: {rel}")

        client = self._get_client()
        try:
            await client.connect()
            result = await client.delete_messages(self.chat_id, [message_id])
            if not result or not result[0].pts:
                raise FileNotFoundError(f"Failed to delete message {message_id} in {self.chat_id}; the message may not exist or permission may be missing")
        finally:
            if client.is_connected():
                await client.disconnect()

    async def move(self, root: str, src_rel: str, dst_rel: str):
-        raise NotImplementedError("The Telegram adapter is read-only and does not support moving.")
+        raise NotImplementedError("The Telegram adapter does not support moving.")

    async def rename(self, root: str, src_rel: str, dst_rel: str):
-        raise NotImplementedError("The Telegram adapter is read-only and does not support renaming.")
+        raise NotImplementedError("The Telegram adapter does not support renaming.")

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
-        raise NotImplementedError("The Telegram adapter is read-only and does not support copying.")
+        raise NotImplementedError("The Telegram adapter does not support copying.")

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        from fastapi.responses import StreamingResponse
@@ -156,19 +240,39 @@ class TelegramAdapter:
        try:
            await client.connect()
            message = await client.get_messages(self.chat_id, ids=message_id)
-            if not message or not (message.document or message.video):
+            media = message.document or message.video or message.photo
+            if not message or not media:
                raise FileNotFoundError(f"No file with message ID {message_id} found in channel {self.chat_id}")

-            media = message.document or message.video
-            file_size = media.size
+            file_meta = message.file
+            file_size = file_meta.size if file_meta and file_meta.size is not None else None
+            if file_size is None:
+                if hasattr(media, "size") and media.size is not None:
+                    file_size = media.size
+                elif message.photo and getattr(message.photo, "sizes", None):
+                    photo_size = message.photo.sizes[-1]
+                    file_size = getattr(photo_size, "size", 0) or 0
+                else:
+                    file_size = 0
+
+            mime_type = None
+            if file_meta and getattr(file_meta, "mime_type", None):
+                mime_type = file_meta.mime_type
+            if not mime_type:
+                if hasattr(media, "mime_type") and media.mime_type:
+                    mime_type = media.mime_type
+                elif message.photo:
+                    mime_type = "image/jpeg"
+                else:
+                    mime_type = "application/octet-stream"

            start = 0
            end = file_size - 1
            status = 200

            headers = {
                "Accept-Ranges": "bytes",
-                "Content-Type": media.mime_type or "application/octet-stream",
+                "Content-Type": mime_type,
                "Content-Length": str(file_size),
            }
@@ -225,14 +329,25 @@ class TelegramAdapter:
        try:
            await client.connect()
            message = await client.get_messages(self.chat_id, ids=message_id)
-            if not message or not (message.document or message.video):
+            media = message.document or message.video or message.photo
+            if not message or not media:
                raise FileNotFoundError(f"No file with message ID {message_id} found in channel {self.chat_id}")

-            media = message.document or message.video
+            file_meta = message.file
+            size = file_meta.size if file_meta and file_meta.size is not None else None
+            if size is None:
+                if hasattr(media, "size") and media.size is not None:
+                    size = media.size
+                elif message.photo and getattr(message.photo, "sizes", None):
+                    photo_size = message.photo.sizes[-1]
+                    size = getattr(photo_size, "size", 0) or 0
+                else:
+                    size = 0

            return {
                "name": rel,
                "is_dir": False,
-                "size": media.size,
+                "size": size,
                "mtime": int(message.date.timestamp()),
                "type": "file",
            }
@@ -241,4 +356,4 @@ class TelegramAdapter:
            await client.disconnect()

def ADAPTER_FACTORY(rec: StorageAdapter) -> TelegramAdapter:
-    return TelegramAdapter(rec)
+    return TelegramAdapter(rec)
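The delete/stream/stat paths above all rely on the same convention: the Telegram message ID is embedded in the relative path. A minimal standalone sketch of that rule, noting the exact filename format Foxel uses is not visible in this diff, so the trailing "-<id>" token is an assumption; only the "last token of the stem parses as the message ID" behaviour is illustrated:

# Hypothetical helper; the "-<id>" suffix convention is assumed.
def parse_message_id(rel: str) -> int:
    stem = rel.rsplit("/", 1)[-1].rsplit(".", 1)[0]   # drop directories and extension
    message_id_str = stem.rsplit("-", 1)[-1]          # assumed trailing "-<id>" token
    try:
        return int(message_id_str)
    except (ValueError, IndexError):
        raise FileNotFoundError(f"Invalid file path format; unable to parse message ID: {rel}")

assert parse_message_id("photos/report-12345.pdf") == 12345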
@@ -1,4 +1,3 @@
from __future__ import annotations
from typing import List, Dict, Optional, Tuple, AsyncIterator
import httpx
from urllib.parse import urljoin, quote
@@ -9,7 +8,6 @@ import mimetypes
import logging
from fastapi import HTTPException
from fastapi.responses import StreamingResponse, Response
-from services.logging import LogService

NS = {"d": "DAV:"}

@@ -39,7 +37,7 @@ class WebDAVAdapter:
        rel = rel.strip('/')
        return self.base_url if not rel else urljoin(self.base_url, quote(rel) + ('/' if rel.endswith('/') else ''))

-    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
+    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        raw_url = self._build_url(rel)
        url = raw_url if raw_url.endswith('/') else raw_url + '/'
        depth = "1"
@@ -92,16 +90,39 @@
                "d:collection", NS) is not None if rt_el is not None else href_path.endswith('/')
            size = int(
                size_el.text) if size_el is not None and size_el.text and size_el.text.isdigit() else 0

+            from email.utils import parsedate_to_datetime
+            mtime = 0
+            if lm_el is not None and lm_el.text:
+                try:
+                    mtime = int(parsedate_to_datetime(lm_el.text).timestamp())
+                except Exception:
+                    mtime = 0
+
            all_entries.append({
                "name": name,
                "is_dir": is_dir,
                "size": 0 if is_dir else size,
-                "mtime": 0,
+                "mtime": mtime,
                "type": "dir" if is_dir else "file",
            })

        # Sort all entries
-        all_entries.sort(key=lambda x: (not x["is_dir"], x["name"].lower()))
+        reverse = sort_order.lower() == "desc"
+
+        def get_sort_key(item):
+            key = (not item["is_dir"],)
+            sort_field = sort_by.lower()
+            if sort_field == "name":
+                key += (item["name"].lower(),)
+            elif sort_field == "size":
+                key += (item["size"],)
+            elif sort_field == "mtime":
+                key += (item["mtime"],)
+            else:
+                key += (item["name"].lower(),)
+            return key
+
+        all_entries.sort(key=get_sort_key, reverse=reverse)

        total_count = len(all_entries)

        # Apply pagination
@@ -125,15 +146,6 @@ class WebDAVAdapter:
        async with self._client() as client:
            resp = await client.put(url, content=data)
            resp.raise_for_status()
-            await LogService.info(
-                "adapter:webdav",
-                f"Wrote file to {rel}",
-                details={
-                    "adapter_id": self.record.id,
-                    "url": url,
-                    "size": len(data),
-                },
-            )

    async def mkdir(self, root: str, rel: str):
        url = self._build_url(rel.rstrip('/') + '/')
@@ -141,11 +153,6 @@ class WebDAVAdapter:
            resp = await client.request("MKCOL", url)
            if resp.status_code not in (201, 405):
                resp.raise_for_status()
-            await LogService.info(
-                "adapter:webdav",
-                f"Created directory {rel}",
-                details={"adapter_id": self.record.id, "url": url},
-            )

    async def delete(self, root: str, rel: str):
        url = self._build_url(rel)
@@ -153,11 +160,6 @@ class WebDAVAdapter:
            resp = await client.delete(url)
            if resp.status_code not in (204, 200, 404):
                resp.raise_for_status()
-            await LogService.info(
-                "adapter:webdav",
-                f"Deleted {rel}",
-                details={"adapter_id": self.record.id, "url": url},
-            )

    async def move(self, root: str, src_rel: str, dst_rel: str):
        src_url = self._build_url(src_rel)
@@ -165,15 +167,6 @@ class WebDAVAdapter:
        async with self._client() as client:
            resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
            resp.raise_for_status()
-            await LogService.info(
-                "adapter:webdav",
-                f"Moved {src_rel} to {dst_rel}",
-                details={
-                    "adapter_id": self.record.id,
-                    "src_url": src_url,
-                    "dst_url": dst_url,
-                },
-            )

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        src_url = self._build_url(src_rel)
@@ -181,15 +174,6 @@ class WebDAVAdapter:
        async with self._client() as client:
            resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
            resp.raise_for_status()
-            await LogService.info(
-                "adapter:webdav",
-                f"Renamed {src_rel} to {dst_rel}",
-                details={
-                    "adapter_id": self.record.id,
-                    "src_url": src_url,
-                    "dst_url": dst_url,
-                },
-            )

    async def get_file_size(self, root: str, rel: str) -> int:
        """Get the size of a file."""
@@ -432,8 +416,16 @@ class WebDAVAdapter:
            info["type"] = "dir" if is_dir else "file"
            if size_el is not None and size_el.text and size_el.text.isdigit():
                info["size"] = int(size_el.text)
            elif info["size"] is None:
                info["size"] = 0
            if lm_el is not None and lm_el.text:
-                info["mtime"] = lm_el.text
+                from email.utils import parsedate_to_datetime
+                try:
+                    info["mtime"] = int(parsedate_to_datetime(lm_el.text).timestamp())
+                except Exception:
+                    info["mtime"] = 0
            elif info["mtime"] is None:
                info["mtime"] = 0
            # EXIF info
            exif = None
            if not info["is_dir"]:
@@ -487,15 +479,6 @@ class WebDAVAdapter:
            if resp.status_code == 404:
                raise FileNotFoundError(src_rel)
            resp.raise_for_status()
-            await LogService.info(
-                "adapter:webdav",
-                f"Copied {src_rel} to {dst_rel}",
-                details={
-                    "adapter_id": self.record.id,
-                    "src_url": src_url,
-                    "dst_url": dst_url,
-                },
-            )

ADAPTER_TYPE = "webdav"
CONFIG_SCHEMA = [
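A standalone sketch of the new list_dir ordering: directories always group before files, then entries order by the requested field (name/size/mtime, with name as the fallback). The sample entries below are illustrative:

entries = [
    {"name": "b.txt", "is_dir": False, "size": 10, "mtime": 300},
    {"name": "docs", "is_dir": True, "size": 0, "mtime": 100},
    {"name": "a.txt", "is_dir": False, "size": 20, "mtime": 200},
]

def get_sort_key(item, sort_by="size"):
    key = (not item["is_dir"],)          # False (dirs) sorts before True (files)
    field = sort_by.lower()
    if field in ("size", "mtime"):
        key += (item[field],)
    else:
        key += (item["name"].lower(),)   # "name" and unknown fields fall back to name
    return key

entries.sort(key=lambda e: get_sort_key(e, "size"), reverse=False)
assert [e["name"] for e in entries] == ["docs", "b.txt", "a.txt"]

Note that because reverse is applied to the whole composite key, a descending sort also flips the directories-first grouping, since the is_dir flag is part of the same key.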
domain/adapters/registry.py  (new file, 157 lines)
@@ -0,0 +1,157 @@
import inspect
import pkgutil
from importlib import import_module
from typing import Callable, Dict

from models import StorageAdapter
from domain.adapters.providers.base import BaseAdapter

AdapterFactory = Callable[[StorageAdapter], BaseAdapter]

TYPE_MAP: Dict[str, AdapterFactory] = {}
CONFIG_SCHEMAS: Dict[str, list] = {}


def normalize_adapter_type(value: str | None) -> str | None:
    if value is None:
        return None
    normalized = str(value).strip().lower()
    return normalized or None


def discover_adapters():
    """Scan the domain.adapters.providers package and auto-register adapter types, factories, and config schemas."""
    from domain.adapters import providers as adapters_pkg

    TYPE_MAP.clear()
    CONFIG_SCHEMAS.clear()
    for modinfo in pkgutil.iter_modules(adapters_pkg.__path__):
        if modinfo.name.startswith("_"):
            continue
        full_name = f"{adapters_pkg.__name__}.{modinfo.name}"
        try:
            module = import_module(full_name)
        except Exception:
            continue

        adapter_types = getattr(module, "ADAPTER_TYPES", None)
        if isinstance(adapter_types, dict):
            default_schema = getattr(module, "CONFIG_SCHEMA", None)
            schema_map = getattr(module, "CONFIG_SCHEMA_MAP", None)
            if not isinstance(schema_map, dict):
                schema_map = None

            for adapter_type, factory in adapter_types.items():
                normalized_type = normalize_adapter_type(adapter_type)
                if not normalized_type:
                    continue
                if not callable(factory):
                    continue
                TYPE_MAP[normalized_type] = factory

                schema = schema_map.get(normalized_type) if schema_map else default_schema
                if isinstance(schema, list):
                    CONFIG_SCHEMAS[normalized_type] = schema
            continue

        adapter_type = normalize_adapter_type(getattr(module, "ADAPTER_TYPE", None))
        schema = getattr(module, "CONFIG_SCHEMA", None)
        factory = getattr(module, "ADAPTER_FACTORY", None)

        if not adapter_type:
            continue

        if factory is None:
            for attr in module.__dict__.values():
                if inspect.isclass(attr) and attr.__name__.endswith("Adapter"):
                    def _mk(cls=attr):
                        return lambda rec: cls(rec)
                    factory = _mk()
                    break
        if not callable(factory):
            continue

        TYPE_MAP[adapter_type] = factory
        if isinstance(schema, list):
            CONFIG_SCHEMAS[adapter_type] = schema


def get_config_schemas() -> Dict[str, list]:
    return CONFIG_SCHEMAS


def get_config_schema(adapter_type: str):
    return CONFIG_SCHEMAS.get(adapter_type)


class RuntimeRegistry:
    def __init__(self):
        self._instances: Dict[int, BaseAdapter] = {}

    async def refresh(self):
        discover_adapters()
        self._instances.clear()
        adapters = await StorageAdapter.filter(enabled=True)
        for rec in adapters:
            normalized_type = normalize_adapter_type(rec.type)
            if not normalized_type:
                continue
            if normalized_type != rec.type:
                rec.type = normalized_type
                try:
                    await rec.save(update_fields=["type"])
                except Exception:
                    continue
            factory = TYPE_MAP.get(normalized_type)
            if not factory:
                continue
            try:
                self._instances[rec.id] = factory(rec)
            except Exception:
                continue

    def get(self, adapter_id: int) -> BaseAdapter | None:
        return self._instances.get(adapter_id)

    def snapshot(self) -> Dict[int, BaseAdapter]:
        return dict(self._instances)

    def remove(self, adapter_id: int):
        """Remove an adapter instance from the cache."""
        if adapter_id in self._instances:
            del self._instances[adapter_id]

    async def upsert(self, rec: StorageAdapter):
        """Insert or update an adapter instance."""
        if not rec.enabled:
            self.remove(rec.id)
            return

        normalized_type = normalize_adapter_type(rec.type)
        if not normalized_type:
            self.remove(rec.id)
            return
        if normalized_type != rec.type:
            rec.type = normalized_type
            try:
                await rec.save(update_fields=["type"])
            except Exception:
                pass

        factory = TYPE_MAP.get(normalized_type)
        if not factory:
            discover_adapters()
            factory = TYPE_MAP.get(normalized_type)
            if not factory:
                return

        try:
            instance = factory(rec)
            self._instances[rec.id] = instance
        except Exception:
            self.remove(rec.id)


runtime_registry = RuntimeRegistry()
discover_adapters()
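For reference, a hypothetical provider module that discover_adapters() would pick up. The names (MemoryAdapter, the "prefix" field) are illustrative only; the module-level ADAPTER_TYPE / CONFIG_SCHEMA / ADAPTER_FACTORY constants are the contract the scanner looks for, and whether BaseAdapter imposes further abstract methods depends on providers/base.py, which is outside this diff:

# Hypothetical module: domain/adapters/providers/memory.py
from models import StorageAdapter
from domain.adapters.providers.base import BaseAdapter

class MemoryAdapter(BaseAdapter):
    def __init__(self, record: StorageAdapter):
        self.record = record
        self.files: dict[str, bytes] = {}

ADAPTER_TYPE = "memory"            # normalized to lowercase at registration
CONFIG_SCHEMA = [
    {"key": "prefix", "required": False, "default": ""},
]

def ADAPTER_FACTORY(rec: StorageAdapter) -> MemoryAdapter:
    return MemoryAdapter(rec)

If ADAPTER_FACTORY were omitted, the fallback in discover_adapters() would register the first class whose name ends in "Adapter" instead.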
domain/adapters/service.py  (new file, 111 lines)
@@ -0,0 +1,111 @@
from typing import Optional

from fastapi import HTTPException

from domain.adapters.registry import (
    get_config_schemas,
    normalize_adapter_type,
    runtime_registry,
)
from domain.adapters.types import AdapterCreate, AdapterOut
from domain.auth.types import User
from models import StorageAdapter


class AdapterService:
    @classmethod
    def _validate_and_normalize_config(cls, adapter_type: str, cfg):
        schemas = get_config_schemas()
        adapter_type = normalize_adapter_type(adapter_type)
        if not adapter_type:
            raise HTTPException(400, detail="Unsupported adapter type")
        if not isinstance(cfg, dict):
            raise HTTPException(400, detail="config must be an object")
        schema = schemas.get(adapter_type)
        if not schema:
            raise HTTPException(400, detail=f"Unsupported adapter type: {adapter_type}")
        out = {}
        missing = []
        for f in schema:
            k = f["key"]
            if k in cfg and cfg[k] not in (None, ""):
                out[k] = cfg[k]
            elif "default" in f:
                out[k] = f["default"]
            elif f.get("required"):
                missing.append(k)
        if missing:
            raise HTTPException(400, detail="Missing required config fields: " + ", ".join(missing))
        return out

    @classmethod
    async def create_adapter(cls, data: AdapterCreate, current_user: Optional[User]):
        norm_path = AdapterCreate.normalize_mount_path(data.path)
        exists = await StorageAdapter.get_or_none(path=norm_path)
        if exists:
            raise HTTPException(400, detail="Mount path already exists")

        adapter_fields = {
            "name": data.name,
            "type": data.type,
            "config": cls._validate_and_normalize_config(data.type, data.config or {}),
            "enabled": data.enabled,
            "path": norm_path,
            "sub_path": data.sub_path,
        }

        rec = await StorageAdapter.create(**adapter_fields)
        await runtime_registry.upsert(rec)
        return AdapterOut.model_validate(rec)

    @classmethod
    async def list_adapters(cls):
        adapters = await StorageAdapter.all()
        return [AdapterOut.model_validate(a) for a in adapters]

    @classmethod
    async def available_adapter_types(cls):
        data = []
        for adapter_type, fields in get_config_schemas().items():
            data.append({
                "type": adapter_type,
                "config_schema": fields,
            })
        return data

    @classmethod
    async def get_adapter(cls, adapter_id: int):
        rec = await StorageAdapter.get_or_none(id=adapter_id)
        if not rec:
            raise HTTPException(404, detail="Not found")
        return AdapterOut.model_validate(rec)

    @classmethod
    async def update_adapter(cls, adapter_id: int, data: AdapterCreate, current_user: Optional[User]):
        rec = await StorageAdapter.get_or_none(id=adapter_id)
        if not rec:
            raise HTTPException(404, detail="Not found")

        norm_path = AdapterCreate.normalize_mount_path(data.path)
        existing = await StorageAdapter.get_or_none(path=norm_path)
        if existing and existing.id != adapter_id:
            raise HTTPException(400, detail="Mount path already exists")

        rec.name = data.name
        rec.type = data.type
        rec.config = cls._validate_and_normalize_config(data.type, data.config or {})
        rec.enabled = data.enabled
        rec.path = norm_path
        rec.sub_path = data.sub_path
        await rec.save()

        await runtime_registry.upsert(rec)
        return AdapterOut.model_validate(rec)

    @classmethod
    async def delete_adapter(cls, adapter_id: int, current_user: Optional[User]):
        deleted = await StorageAdapter.filter(id=adapter_id).delete()
        if not deleted:
            raise HTTPException(404, detail="Not found")
        runtime_registry.remove(adapter_id)
        return {"deleted": True}
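The schema rules _validate_and_normalize_config applies reduce to: explicit non-empty values win, then schema defaults, and remaining required keys are reported together. A standalone sketch with an illustrative schema (note that an empty string counts as missing):

schema = [
    {"key": "base_url", "required": True},
    {"key": "timeout", "default": 30},
    {"key": "token", "required": True},
]
cfg = {"base_url": "https://dav.example.com", "token": ""}  # "" treated as absent

out, missing = {}, []
for f in schema:
    k = f["key"]
    if k in cfg and cfg[k] not in (None, ""):
        out[k] = cfg[k]                # explicit value
    elif "default" in f:
        out[k] = f["default"]          # schema default
    elif f.get("required"):
        missing.append(k)              # collected, reported in one error

assert out == {"base_url": "https://dav.example.com", "timeout": 30}
assert missing == ["token"]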
@@ -1,15 +1,29 @@
+import re
from typing import Dict, Optional

from pydantic import BaseModel, Field, field_validator


class AdapterBase(BaseModel):
    name: str
-    type: str = Field(pattern=r"^[a-zA-Z0-9_]+$")
+    type: str = Field(pattern=r"^[a-z0-9_]+$")
    config: Dict = Field(default_factory=dict)
    enabled: bool = True
    path: str = None
    sub_path: Optional[str] = None

+    @field_validator("type", mode="before")
+    @classmethod
+    def _normalize_type(cls, v: str):
+        if not isinstance(v, str):
+            raise ValueError("type required")
+        normalized = v.strip().lower()
+        if not normalized:
+            raise ValueError("type required")
+        if not re.fullmatch(r"[a-z0-9_]+", normalized):
+            raise ValueError("type must be lowercase alphanumeric or underscore")
+        return normalized


class AdapterCreate(AdapterBase):
    @staticmethod
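Because the validator runs in "before" mode, mixed-case input is accepted and stored lowercase before the Field pattern is checked, while characters outside [a-z0-9_] are rejected outright. A usage sketch, assuming the module path domain/adapters/types.py shown above:

from pydantic import ValidationError
from domain.adapters.types import AdapterBase

assert AdapterBase(name="docs", type="  WebDAV ", path="/docs").type == "webdav"
try:
    AdapterBase(name="bad", type="web-dav", path="/x")  # hyphen is outside the charset
except ValidationError:
    pass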
domain/ai/__init__.py  (new file, 34 lines)
@@ -0,0 +1,34 @@
from .api import router_ai, router_vector_db
from .service import (
    AIProviderService,
    VectorDBConfigManager,
    VectorDBService,
    DEFAULT_VECTOR_DIMENSION,
    ABILITIES,
    normalize_capabilities,
)
from .types import (
    AIDefaultsUpdate,
    AIModelCreate,
    AIModelUpdate,
    AIProviderCreate,
    AIProviderUpdate,
    VectorDBConfigPayload,
)

__all__ = [
    "router_ai",
    "router_vector_db",
    "AIProviderService",
    "VectorDBService",
    "VectorDBConfigManager",
    "DEFAULT_VECTOR_DIMENSION",
    "ABILITIES",
    "normalize_capabilities",
    "AIDefaultsUpdate",
    "AIModelCreate",
    "AIModelUpdate",
    "AIProviderCreate",
    "AIProviderUpdate",
    "VectorDBConfigPayload",
]
domain/ai/api.py  (new file, 305 lines)
@@ -0,0 +1,305 @@
from typing import Annotated, Dict, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException, Path, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.ai.service import AIProviderService, VectorDBConfigManager, VectorDBService
from domain.ai.types import (
    AIDefaultsUpdate,
    AIModelCreate,
    AIModelUpdate,
    AIProviderCreate,
    AIProviderUpdate,
    VectorDBConfigPayload,
)
from domain.ai.vector_providers import get_provider_class, get_provider_entry, list_providers
from domain.auth.service import get_current_active_user
from domain.auth.types import User

router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])


@audit(action=AuditAction.READ, description="List AI providers")
@router_ai.get("/providers")
async def list_providers_endpoint(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    providers = await AIProviderService.list_providers()
    return success({"providers": providers})


@audit(
    action=AuditAction.CREATE,
    description="Create AI provider",
    body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
    redact_fields=["api_key"],
)
@router_ai.post("/providers")
async def create_provider(
    request: Request,
    payload: AIProviderCreate,
    current_user: Annotated[User, Depends(get_current_active_user)]
):
    provider = await AIProviderService.create_provider(payload.dict())
    return success(provider)


@audit(action=AuditAction.READ, description="Get AI provider details")
@router_ai.get("/providers/{provider_id}")
async def get_provider(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    provider = await AIProviderService.get_provider(provider_id, with_models=True)
    return success(provider)


@audit(
    action=AuditAction.UPDATE,
    description="Update AI provider",
    body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
    redact_fields=["api_key"],
)
@router_ai.put("/providers/{provider_id}")
async def update_provider(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    payload: AIProviderUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = {k: v for k, v in payload.dict().items() if v is not None}
    if not data:
        raise HTTPException(status_code=400, detail="No fields to update")
    provider = await AIProviderService.update_provider(provider_id, data)
    return success(provider)


@audit(action=AuditAction.DELETE, description="Delete AI provider")
@router_ai.delete("/providers/{provider_id}")
async def delete_provider(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    await AIProviderService.delete_provider(provider_id)
    return success({"id": provider_id})


@audit(action=AuditAction.UPDATE, description="Sync model list")
@router_ai.post("/providers/{provider_id}/sync-models")
async def sync_models(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    try:
        result = await AIProviderService.sync_models(provider_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return success(result)


@audit(action=AuditAction.READ, description="Fetch remote model list")
@router_ai.get("/providers/{provider_id}/remote-models")
async def fetch_remote_models(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    try:
        models = await AIProviderService.fetch_remote_models(provider_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return success({"models": models})


@audit(action=AuditAction.READ, description="List models")
@router_ai.get("/providers/{provider_id}/models")
async def list_models(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    models = await AIProviderService.list_models(provider_id)
    return success({"models": models})


@audit(
    action=AuditAction.CREATE,
    description="Create model",
    body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
)
@router_ai.post("/providers/{provider_id}/models")
async def create_model(
    request: Request,
    provider_id: Annotated[int, Path(..., gt=0)],
    payload: AIModelCreate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    model = await AIProviderService.create_model(provider_id, payload.dict())
    return success(model)


@audit(
    action=AuditAction.UPDATE,
    description="Update model",
    body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
)
@router_ai.put("/models/{model_id}")
async def update_model(
    request: Request,
    model_id: Annotated[int, Path(..., gt=0)],
    payload: AIModelUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = {k: v for k, v in payload.dict().items() if v is not None}
    if not data:
        raise HTTPException(status_code=400, detail="No fields to update")
    model = await AIProviderService.update_model(model_id, data)
    return success(model)


@audit(action=AuditAction.DELETE, description="Delete model")
@router_ai.delete("/models/{model_id}")
async def delete_model(
    request: Request,
    model_id: Annotated[int, Path(..., gt=0)],
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    await AIProviderService.delete_model(model_id)
    return success({"id": model_id})


def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
    if not entry:
        return None
    value = entry.get("embedding_dimensions")
    return int(value) if value is not None else None


@audit(action=AuditAction.READ, description="Get default models")
@router_ai.get("/defaults")
async def get_defaults(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    defaults = await AIProviderService.get_default_models()
    return success(defaults)


@audit(
    action=AuditAction.UPDATE,
    description="Update default models",
    body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
)
@router_ai.put("/defaults")
async def update_defaults(
    request: Request,
    payload: AIDefaultsUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    previous = await AIProviderService.get_default_models()
    try:
        updated = await AIProviderService.set_default_models(payload.as_mapping())
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    prev_dim = _get_embedding_dimension(previous.get("embedding"))
    next_dim = _get_embedding_dimension(updated.get("embedding"))

    if prev_dim and next_dim and prev_dim != next_dim:
        try:
            await VectorDBService().clear_all_data()
        except Exception as exc:  # noqa: BLE001
            raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc

    return success(updated)


@audit(action=AuditAction.UPDATE, description="Clear vector database")
@router_vector_db.post("/clear-all", summary="Clear the vector database")
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
    try:
        service = VectorDBService()
        await service.clear_all_data()
        return success(msg="Vector database cleared")
    except Exception as e:  # noqa: BLE001
        raise HTTPException(status_code=500, detail=str(e))


@audit(action=AuditAction.READ, description="Get vector database stats")
@router_vector_db.get("/stats", summary="Get vector database statistics")
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
    try:
        service = VectorDBService()
        data = await service.get_all_stats()
        return success(data=data)
    except Exception as e:  # noqa: BLE001
        raise HTTPException(status_code=500, detail=str(e))


@audit(action=AuditAction.READ, description="List vector database providers")
@router_vector_db.get("/providers", summary="List available vector database providers")
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
    return success(list_providers())


@audit(action=AuditAction.READ, description="Get vector database config")
@router_vector_db.get("/config", summary="Get the current vector database config")
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
    service = VectorDBService()
    data = await service.current_provider()
    return success(data)


@audit(action=AuditAction.UPDATE, description="Update vector database config", body_fields=["type"])
@router_vector_db.post("/config", summary="Update the vector database config")
async def update_vector_db_config(
    request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
):
    entry = get_provider_entry(payload.type)
    if not entry:
        raise HTTPException(
            status_code=400, detail=f"Unknown vector database type: {payload.type}")
    if not entry.get("enabled", True):
        raise HTTPException(status_code=400, detail="This vector database type is currently unavailable")

    provider_cls = get_provider_class(payload.type)
    if not provider_cls:
        raise HTTPException(
            status_code=400, detail=f"No implementation found for type {payload.type}")

    test_provider = provider_cls(payload.config)
    try:
        await test_provider.initialize()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    finally:
        client = getattr(test_provider, "client", None)
        close_fn = getattr(client, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass

    await VectorDBConfigManager.save_config(payload.type, payload.config)
    service = VectorDBService()
    await service.reload()
    config_data = await service.current_provider()
    stats = await service.get_all_stats()
    return success({"config": config_data, "stats": stats})


__all__ = ["router_ai", "router_vector_db"]
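The guard in update_defaults only wipes the vector store when both the previous and the new default embedding model declare dimensions and the two values differ. A standalone sketch of that comparison, with illustrative entries shaped like the dicts get_default_models returns:

from typing import Dict, Optional

def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
    if not entry:
        return None
    value = entry.get("embedding_dimensions")
    return int(value) if value is not None else None

prev = {"embedding_dimensions": 1536}   # e.g. text-embedding-3-small
nxt = {"embedding_dimensions": 3072}    # e.g. text-embedding-3-large
prev_dim, next_dim = _get_embedding_dimension(prev), _get_embedding_dimension(nxt)
assert prev_dim and next_dim and prev_dim != next_dim   # -> clear_all_data() fires

assert _get_embedding_dimension(None) is None           # unknown side -> no wipe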
domain/ai/inference.py  (new file, 245 lines)
@@ -0,0 +1,245 @@
import httpx
from typing import List, Sequence, Tuple

from models.database import AIModel, AIProvider
from domain.ai.service import AIProviderService


provider_service = AIProviderService


class MissingModelError(RuntimeError):
    pass


async def describe_image_base64(base64_image: str, detail: str = "high") -> str:
    """
    Take a base64-encoded image and return a text description. Returns an error message when no model is configured.
    """
    try:
        model, provider = await _require_model("vision")
        if provider.api_format == "openai":
            return await _describe_with_openai(provider, model, base64_image, detail)
        return await _describe_with_gemini(provider, model, base64_image, detail)
    except MissingModelError as exc:
        return str(exc)
    except httpx.ReadTimeout:
        return "Request timed out; please retry later."
    except Exception as exc:  # noqa: BLE001
        return f"Request failed: {exc}"


async def get_text_embedding(text: str) -> List[float]:
    """
    Take a piece of text and return its embedding vector. Raises if no model is configured.
    """
    model, provider = await _require_model("embedding")
    if provider.api_format == "openai":
        return await _embedding_with_openai(provider, model, text)
    return await _embedding_with_gemini(provider, model, text)


async def rerank_texts(query: str, documents: Sequence[str]) -> List[float]:
    """Call the rerank model and return a score per document. Returns an empty list when not configured."""
    if not documents:
        return []
    try:
        model, provider = await _require_model("rerank")
    except MissingModelError:
        return []

    try:
        if provider.api_format == "openai":
            return await _rerank_with_openai(provider, model, query, documents)
        return await _rerank_with_gemini(provider, model, query, documents)
    except Exception:  # noqa: BLE001
        return []


async def _require_model(ability: str) -> Tuple[AIModel, AIProvider]:
    model = await provider_service.get_default_model(ability)
    if not model:
        raise MissingModelError(f"No default {ability} model is configured; please set one in system settings.")
    provider = getattr(model, "provider", None)
    if provider is None:
        await model.fetch_related("provider")
        provider = model.provider
    if provider is None:
        raise MissingModelError("The model has no associated provider configuration.")
    if not provider.base_url:
        raise MissingModelError("This provider has no API base URL configured.")
    return model, provider


def _openai_endpoint(provider: AIProvider, path: str) -> str:
    base = (provider.base_url or "").rstrip("/")
    if not base:
        raise MissingModelError("Provider API base URL is not configured.")
    return f"{base}/{path.lstrip('/')}"


def _openai_headers(provider: AIProvider) -> dict:
    headers = {"Content-Type": "application/json"}
    if provider.api_key:
        headers["Authorization"] = f"Bearer {provider.api_key}"
    return headers


def _gemini_endpoint(provider: AIProvider, path: str) -> str:
    base = (provider.base_url or "").rstrip("/")
    if not base:
        raise MissingModelError("Provider API base URL is not configured.")
    url = f"{base}/{path.lstrip('/')}"
    if provider.api_key:
        connector = "&" if "?" in url else "?"
        url = f"{url}{connector}key={provider.api_key}"
    return url


async def _describe_with_openai(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
    url = _openai_endpoint(provider, "/chat/completions")
    payload = {
        "model": model.name,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                            "detail": detail,
                        },
                    },
                    {"type": "text", "text": "Describe this image"},
                ],
            }
        ],
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(url, headers=_openai_headers(provider), json=payload)
        response.raise_for_status()
        body = response.json()
        return body["choices"][0]["message"]["content"]


async def _describe_with_gemini(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
    detail_text = f"Describe this image; detail level: {detail}"
    model_name = model.name if model.name.startswith("models/") else f"models/{model.name}"
    url = _gemini_endpoint(provider, f"{model_name}:generateContent")
    payload = {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "inline_data": {
                            "mime_type": "image/jpeg",
                            "data": base64_image,
                        }
                    },
                    {"text": detail_text},
                ],
            }
        ]
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(url, json=payload)
        response.raise_for_status()
        body = response.json()
        candidates = body.get("candidates") or []
        if not candidates:
            return ""
        parts = candidates[0].get("content", {}).get("parts", [])
        text_parts = [part.get("text") for part in parts if isinstance(part, dict) and part.get("text")]
        return "\n".join(text_parts)


async def _embedding_with_openai(provider: AIProvider, model: AIModel, text: str) -> List[float]:
    url = _openai_endpoint(provider, "/embeddings")
    payload = {
        "model": model.name,
        "input": text,
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, headers=_openai_headers(provider), json=payload)
        response.raise_for_status()
        body = response.json()
        return body["data"][0]["embedding"]


async def _embedding_with_gemini(provider: AIProvider, model: AIModel, text: str) -> List[float]:
    model_name = model.name if model.name.startswith("models/") else f"models/{model.name}"
    url = _gemini_endpoint(provider, f"{model_name}:embedContent")
    payload = {
        "model": model_name,
        "content": {
            "parts": [{"text": text}],
        },
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, json=payload)
        response.raise_for_status()
        body = response.json()
        embedding = body.get("embedding") or {}
        return embedding.get("values") or []


async def _rerank_with_openai(
    provider: AIProvider,
    model: AIModel,
    query: str,
    documents: Sequence[str],
) -> List[float]:
    url = _openai_endpoint(provider, "/rerank")
    payload = {
        "model": model.name,
        "query": query,
        "documents": [
            {"id": str(idx), "text": content}
            for idx, content in enumerate(documents)
        ],
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, headers=_openai_headers(provider), json=payload)
        response.raise_for_status()
        body = response.json()
        results = body.get("results") or body.get("data") or []
        scores: List[float] = []
        for item in results:
            try:
                scores.append(float(item.get("score", 0.0)))
            except (TypeError, ValueError):
                scores.append(0.0)
        return scores


async def _rerank_with_gemini(
    provider: AIProvider,
    model: AIModel,
    query: str,
    documents: Sequence[str],
) -> List[float]:
    model_name = model.name if model.name.startswith("models/") else f"models/{model.name}"
    url = _gemini_endpoint(provider, f"{model_name}:rankContent")
    payload = {
        "query": {"text": query},
        "documents": [
            {"id": str(idx), "content": {"parts": [{"text": content}]}}
            for idx, content in enumerate(documents)
        ],
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(url, json=payload)
        response.raise_for_status()
        body = response.json()

    scores: List[float] = []
    ranked = body.get("rankedDocuments") or body.get("results") or []
    for item in ranked:
        raw_score = item.get("relevanceScore") or item.get("score") or item.get("confidenceScore")
        try:
            scores.append(float(raw_score))
        except (TypeError, ValueError):
            scores.append(0.0)
    return scores
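A hypothetical caller for the helpers above. Note the asymmetric error handling visible in the code: describe_image_base64 swallows MissingModelError and returns its message as text, while get_text_embedding lets it propagate. The file name below is illustrative:

import asyncio, base64

from domain.ai.inference import MissingModelError, describe_image_base64, get_text_embedding

async def main():
    with open("photo.jpg", "rb") as fh:           # any local JPEG
        b64 = base64.b64encode(fh.read()).decode()
    print(await describe_image_base64(b64, detail="low"))  # error text if unconfigured
    try:
        vector = await get_text_embedding("hello foxel")
        print(len(vector))
    except MissingModelError as exc:              # raised when no default embedding model
        print(exc)

asyncio.run(main())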
499
domain/ai/service.py
Normal file
499
domain/ai/service.py
Normal file
@@ -0,0 +1,499 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config.service import ConfigService
|
||||
from models.database import AIDefaultModel, AIModel, AIProvider
|
||||
|
||||
from .types import ABILITIES, normalize_capabilities
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
DEFAULT_VECTOR_DIMENSION = 4096
|
||||
|
||||
OPENAI_EMBEDDING_DIMS = {
|
||||
"text-embedding-3-large": 3072,
|
||||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-ada-002": 1536,
|
||||
}
|
||||
|
||||
|
||||
class VectorDBConfigManager:
|
||||
TYPE_KEY = "VECTOR_DB_TYPE"
|
||||
CONFIG_KEY = "VECTOR_DB_CONFIG"
|
||||
DEFAULT_TYPE = "milvus_lite"
|
||||
|
||||
@classmethod
|
||||
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
|
||||
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
|
||||
provider_type = str(raw_type or cls.DEFAULT_TYPE)
|
||||
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
config_dict: Dict[str, Any] = {}
|
||||
if isinstance(raw_config, str) and raw_config:
|
||||
try:
|
||||
config_dict = json.loads(raw_config)
|
||||
except json.JSONDecodeError:
|
||||
config_dict = {}
|
||||
elif isinstance(raw_config, dict):
|
||||
config_dict = raw_config
|
||||
return provider_type, config_dict
|
||||
|
||||
@classmethod
|
||||
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
|
||||
await ConfigService.set(cls.TYPE_KEY, provider_type)
|
||||
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
|
||||
|
||||
@classmethod
|
||||
async def get_type(cls) -> str:
|
||||
provider_type, _ = await cls.load_config()
|
||||
return provider_type
|
||||
|
||||
@classmethod
|
||||
async def get_config(cls) -> Dict[str, Any]:
|
||||
_, config = await cls.load_config()
|
||||
return config
|
||||
|
||||
|
||||
def _normalize_embedding_dim(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
casted = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return casted if casted > 0 else None
|
||||
|
||||
|
||||
def _apply_embedding_dim_to_metadata(
|
||||
data: Dict[str, Any],
|
||||
embedding_dim: Optional[int],
|
||||
base_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
source = base_metadata if isinstance(base_metadata, dict) else {}
|
||||
metadata: Dict[str, Any] = dict(source)
|
||||
override = data.get("metadata")
|
||||
if isinstance(override, dict) and override:
|
||||
metadata.update(override)
|
||||
if embedding_dim is None:
|
||||
metadata.pop("embedding_dimensions", None)
|
||||
else:
|
||||
metadata["embedding_dimensions"] = embedding_dim
|
||||
data["metadata"] = metadata or None
|
||||
return data
|
||||
|
||||
|
||||
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
|
||||
lower = model_id.lower()
|
||||
caps = set()
|
||||
|
||||
if any(keyword in lower for keyword in ["gpt", "chat", "turbo", "o1", "sonnet", "haiku", "thinking"]):
|
||||
caps.update({"chat", "tools"})
|
||||
|
||||
if any(keyword in lower for keyword in ["vision", "gpt-4o", "gpt-4.1", "o1", "vision-preview", "omni"]):
|
||||
caps.add("vision")
|
||||
|
||||
if any(keyword in lower for keyword in ["embed", "embedding"]):
|
||||
caps.add("embedding")
|
||||
|
||||
if "rerank" in lower or "re-rank" in lower:
|
||||
caps.add("rerank")
|
||||
|
||||
if any(keyword in lower for keyword in ["tts", "speech", "audio"]):
|
||||
caps.add("voice")
|
||||
|
||||
embedding_dim = OPENAI_EMBEDDING_DIMS.get(model_id)
|
||||
return normalize_capabilities(caps), embedding_dim
|
||||
|
||||
|
||||
def infer_gemini_capabilities(methods: Iterable[str]) -> List[str]:
|
||||
caps = set()
|
||||
for method in methods:
|
||||
m = method.lower()
|
||||
if m in {"generatecontent", "counttokens"}:
|
||||
caps.update({"chat", "tools", "vision"})
|
||||
if m == "embedcontent":
|
||||
caps.add("embedding")
|
||||
if m in {"generatespeech", "audiogeneration"}:
|
||||
caps.add("voice")
|
||||
if m == "rerank":
|
||||
caps.add("rerank")
|
||||
return normalize_capabilities(caps)
|
||||
|
||||
|
||||
def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"identifier": provider.identifier,
|
||||
"provider_type": provider.provider_type,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"api_key": provider.api_key,
|
||||
"logo_url": provider.logo_url,
|
||||
"extra_config": provider.extra_config or {},
|
||||
"created_at": provider.created_at,
|
||||
"updated_at": provider.updated_at,
|
||||
}
|
||||
|
||||
|
||||
def model_to_dict(model: AIModel, provider: Optional[AIProvider] = None) -> Dict[str, Any]:
|
||||
provider_obj = provider or getattr(model, "provider", None)
|
||||
provider_data = serialize_provider(provider_obj) if provider_obj else None
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"name": model.name,
|
||||
"display_name": model.display_name,
|
||||
"description": model.description,
|
||||
"capabilities": normalize_capabilities(model.capabilities),
|
||||
"context_window": model.context_window,
|
||||
"embedding_dimensions": model.embedding_dimensions,
|
||||
"metadata": model.metadata or {},
|
||||
"created_at": model.created_at,
|
||||
"updated_at": model.updated_at,
|
||||
"provider": provider_data,
|
||||
}
|
||||
|
||||
|
||||
def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = None) -> Dict[str, Any]:
|
||||
data = serialize_provider(provider)
|
||||
if models is not None:
|
||||
data["models"] = [model_to_dict(m, provider=provider) for m in models]
|
||||
return data
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
@classmethod
|
||||
async def list_providers(cls) -> List[Dict[str, Any]]:
|
||||
providers = await AIProvider.all().order_by("id").prefetch_related("models")
|
||||
return [provider_to_dict(p, models=list(p.models)) for p in providers]
|
||||
|
||||
@classmethod
|
||||
async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
|
||||
if with_models:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
models = await provider.models.all()
|
||||
return provider_to_dict(provider, models=models)
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data.setdefault("extra_config", {})
|
||||
provider = await AIProvider.create(**data)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
for field, value in payload.items():
|
||||
setattr(provider, field, value)
|
||||
await provider.save()
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def delete_provider(cls, provider_id: int) -> None:
|
||||
await AIProvider.filter(id=provider_id).delete()
|
||||
|
||||
@classmethod
|
||||
async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
|
||||
return [model_to_dict(m) for m in models]
|
||||
|
||||
@classmethod
|
||||
async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data["provider_id"] = provider_id
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
data = _apply_embedding_dim_to_metadata(data, embedding_dim)
|
||||
model = await AIModel.create(**data)
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
@classmethod
|
||||
async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model = await AIModel.get(id=model_id)
|
||||
data = payload.copy()
|
||||
if "capabilities" in data:
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = None
|
||||
if "embedding_dimensions" in data:
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
_apply_embedding_dim_to_metadata(data, embedding_dim, base_metadata=model.metadata)
|
||||
for field, value in data.items():
|
||||
setattr(model, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in payload and embedding_dim is None):
|
||||
model.embedding_dimensions = embedding_dim
|
||||
await model.save()
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
@classmethod
|
||||
async def delete_model(cls, model_id: int) -> None:
|
||||
await AIModel.filter(id=model_id).delete()
|
||||
|
||||
@classmethod
|
||||
async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return await cls._get_remote_models(provider)
|
||||
|
||||
@classmethod
|
||||
async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
if not provider.base_url:
|
||||
raise ValueError("Provider base_url is required for syncing models")
|
||||
|
||||
fmt = (provider.api_format or "").lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
|
||||
|
||||
if fmt == "openai":
|
||||
return await cls._fetch_openai_models(provider)
|
||||
return await cls._fetch_gemini_models(provider)
|
||||
|
||||
@classmethod
|
||||
async def sync_models(cls, provider_id: int) -> Dict[str, int]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
remote_models = await cls._get_remote_models(provider)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
for entry in remote_models:
|
||||
defaults = entry.copy()
|
||||
model_id = defaults.pop("name")
|
||||
defaults["capabilities"] = normalize_capabilities(defaults.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(defaults.pop("embedding_dimensions", None))
|
||||
defaults = _apply_embedding_dim_to_metadata(defaults, embedding_dim)
|
||||
obj, is_created = await AIModel.get_or_create(
|
||||
provider_id=provider.id,
|
||||
name=model_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
if is_created:
|
||||
created += 1
|
||||
continue
|
||||
for field, value in defaults.items():
|
||||
setattr(obj, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in entry and embedding_dim is None):
|
||||
obj.embedding_dimensions = embedding_dim
|
||||
await obj.save()
|
||||
updated += 1
|
||||
|
||||
return {"created": created, "updated": updated}
|
||||
|
||||
@classmethod
|
||||
async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
|
||||
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
|
||||
for item in defaults:
|
||||
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
|
||||
async with in_transaction() as connection:
|
||||
for ability, model_id in normalized.items():
|
||||
record = await AIDefaultModel.get_or_none(ability=ability)
|
||||
if model_id:
|
||||
try:
|
||||
model = await AIModel.get(id=model_id)
|
||||
except DoesNotExist:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
if record:
|
||||
record.model_id = model_id
|
||||
await record.save(using_db=connection)
|
||||
else:
|
||||
await AIDefaultModel.create(ability=ability, model_id=model_id)
|
||||
elif record:
|
||||
await record.delete(using_db=connection)
|
||||
return await cls.get_default_models()
|
||||
|
||||
@classmethod
|
||||
async def get_default_model(cls, ability: str) -> Optional[AIModel]:
|
||||
ability_key = ability.lower()
|
||||
if ability_key not in ABILITIES:
|
||||
return None
|
||||
record = await AIDefaultModel.get_or_none(ability=ability_key)
|
||||
if not record:
|
||||
return None
|
||||
model = await AIModel.get_or_none(id=record.model_id)
|
||||
if model:
|
||||
await model.fetch_related("provider")
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
url = f"{base_url}/models"
|
||||
headers = {}
|
||||
if provider.api_key:
|
||||
headers["Authorization"] = f"Bearer {provider.api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("data", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
capabilities, embedding_dim = infer_openai_capabilities(model_id)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("display_name"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("context_window"),
|
||||
"embedding_dimensions": embedding_dim,
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
@classmethod
|
||||
async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
suffix = "/models"
|
||||
if provider.api_key:
|
||||
suffix += f"?key={provider.api_key}"
|
||||
url = f"{base_url}{suffix}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("models", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("name")
|
||||
if not model_id:
|
||||
continue
|
||||
methods = item.get("supportedGenerationMethods") or []
|
||||
capabilities = infer_gemini_capabilities(methods)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("displayName"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("inputTokenLimit"),
|
||||
"embedding_dimensions": item.get("embeddingDimensions"),
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
class VectorDBService:
|
||||
_instance: Optional["VectorDBService"] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_provider"):
|
||||
self._provider: Optional[BaseVectorProvider] = None
|
||||
self._provider_type: Optional[str] = None
|
||||
self._provider_config: Dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_provider(self) -> BaseVectorProvider:
|
||||
if self._provider is None:
|
||||
await self.reload()
|
||||
assert self._provider is not None
|
||||
return self._provider
|
||||
|
||||
async def reload(self) -> BaseVectorProvider:
|
||||
async with self._lock:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
normalized_config = dict(provider_config or {})
|
||||
if (
|
||||
self._provider
|
||||
and self._provider_type == provider_type
|
||||
and self._provider_config == normalized_config
|
||||
):
|
||||
return self._provider
|
||||
|
||||
entry = get_provider_entry(provider_type)
|
||||
if not entry:
|
||||
raise RuntimeError(f"Unknown vector database provider: {provider_type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")
|
||||
|
||||
provider_cls = get_provider_class(provider_type)
|
||||
if not provider_cls:
|
||||
raise RuntimeError(f"Provider class not found for '{provider_type}'")
|
||||
|
||||
provider = provider_cls(provider_config)
|
||||
await provider.initialize()
|
||||
|
||||
self._provider = provider
|
||||
self._provider_type = provider_type
|
||||
self._provider_config = normalized_config
|
||||
return provider
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.ensure_collection(collection_name, vector, dim)
|
||||
|
||||
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.upsert_vector(collection_name, data)
|
||||
|
||||
async def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.delete_vector(collection_name, path)
|
||||
|
||||
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_vectors(collection_name, query_embedding, top_k)
|
||||
|
||||
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_by_path(collection_name, query_path, top_k)
|
||||
|
||||
async def get_all_stats(self) -> Dict[str, Any]:
|
||||
provider = await self._ensure_provider()
|
||||
return provider.get_all_stats()
|
||||
|
||||
async def clear_all_data(self) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.clear_all_data()
|
||||
|
||||
async def current_provider(self) -> Dict[str, Any]:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
entry = get_provider_entry(provider_type) or {}
|
||||
return {
|
||||
"type": provider_type,
|
||||
"config": provider_config,
|
||||
"label": entry.get("label"),
|
||||
"enabled": entry.get("enabled", True),
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"normalize_capabilities",
|
||||
"ABILITIES",
|
||||
]
|
||||
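Since `__new__` caches a single instance, every caller of `VectorDBService()` shares one provider and one reload lock. A minimal usage sketch follows; the import path `domain.ai.services` is an assumption, since this hunk does not show the module's own path:

# Hypothetical usage sketch; the import path is assumed, not shown in this diff.
import asyncio

from domain.ai.services import DEFAULT_VECTOR_DIMENSION, VectorDBService


async def main() -> None:
    service = VectorDBService()  # singleton: the same object on every call
    await service.ensure_collection("files", vector=True, dim=DEFAULT_VECTOR_DIMENSION)
    await service.upsert_vector("files", {"path": "/docs/a.txt", "embedding": [0.1] * DEFAULT_VECTOR_DIMENSION})
    hits = await service.search_vectors("files", [0.1] * DEFAULT_VECTOR_DIMENSION, top_k=3)
    print(hits)


asyncio.run(main())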
domain/ai/types.py (new file, 121 lines)
@@ -0,0 +1,121 @@
from typing import Any, Dict, Iterable, List, Optional

from pydantic import BaseModel, Field, field_validator

ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]


def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
    if not items:
        return []
    normalized: List[str] = []
    for cap in items:
        key = str(cap).strip().lower()
        if key in ABILITIES and key not in normalized:
            normalized.append(key)
    return normalized


class AIProviderBase(BaseModel):
    name: str
    identifier: str = Field(..., pattern=r"^[a-z0-9_\-\.]+$")
    provider_type: Optional[str] = None
    api_format: str
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    logo_url: Optional[str] = None
    extra_config: Optional[dict] = None

    @field_validator("api_format")
    @classmethod
    def normalize_format(cls, value: str) -> str:
        fmt = value.lower()
        if fmt not in {"openai", "gemini"}:
            raise ValueError("api_format must be 'openai' or 'gemini'")
        return fmt


class AIProviderCreate(AIProviderBase):
    pass


class AIProviderUpdate(BaseModel):
    name: Optional[str] = None
    provider_type: Optional[str] = None
    api_format: Optional[str] = None
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    logo_url: Optional[str] = None
    extra_config: Optional[dict] = None

    @field_validator("api_format")
    @classmethod
    def normalize_format(cls, value: Optional[str]) -> Optional[str]:
        if value is None:
            return value
        fmt = value.lower()
        if fmt not in {"openai", "gemini"}:
            raise ValueError("api_format must be 'openai' or 'gemini'")
        return fmt


class AIModelBase(BaseModel):
    name: str
    display_name: Optional[str] = None
    description: Optional[str] = None
    capabilities: Optional[List[str]] = None
    context_window: Optional[int] = None
    embedding_dimensions: Optional[int] = None
    metadata: Optional[dict] = None

    @field_validator("capabilities")
    @classmethod
    def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
        if items is None:
            return None
        normalized = normalize_capabilities(items)
        # Compare case-insensitively so inputs like "Chat" normalize instead of erroring.
        invalid = {str(cap).strip().lower() for cap in items} - set(normalized)
        if invalid:
            raise ValueError(f"Unsupported capabilities: {', '.join(invalid)}")
        return normalized


class AIModelCreate(AIModelBase):
    pass


class AIModelUpdate(BaseModel):
    display_name: Optional[str] = None
    description: Optional[str] = None
    capabilities: Optional[List[str]] = None
    context_window: Optional[int] = None
    embedding_dimensions: Optional[int] = None
    metadata: Optional[dict] = None

    @field_validator("capabilities")
    @classmethod
    def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
        if items is None:
            return None
        normalized = normalize_capabilities(items)
        # Compare case-insensitively so inputs like "Chat" normalize instead of erroring.
        invalid = {str(cap).strip().lower() for cap in items} - set(normalized)
        if invalid:
            raise ValueError(f"Unsupported capabilities: {', '.join(invalid)}")
        return normalized


class AIDefaultsUpdate(BaseModel):
    chat: Optional[int] = None
    vision: Optional[int] = None
    embedding: Optional[int] = None
    rerank: Optional[int] = None
    voice: Optional[int] = None
    tools: Optional[int] = None

    def as_mapping(self) -> Dict[str, Optional[int]]:
        return {ability: getattr(self, ability) for ability in ABILITIES}


class VectorDBConfigPayload(BaseModel):
    type: str = Field(..., description="Vector database provider type")
    config: Dict[str, Any] = Field(default_factory=dict, description="Provider configuration parameters")
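`normalize_capabilities` lowercases, trims, and de-duplicates while preserving order, and the model validators reject anything outside `ABILITIES`. A small behaviour sketch, assuming the module is importable as `domain.ai.types` and pydantic v2 (where `ValidationError` subclasses `ValueError`):

# Behaviour sketch; import path assumed from the diff layout.
from domain.ai.types import AIModelCreate, normalize_capabilities

print(normalize_capabilities(["Chat", "chat", "VISION"]))  # ['chat', 'vision']

model = AIModelCreate(name="gpt-4o", capabilities=["Chat", "tools"])
print(model.capabilities)  # ['chat', 'tools']

try:
    AIModelCreate(name="bad", capabilities=["telepathy"])
except ValueError as exc:  # pydantic's ValidationError is a ValueError subclass
    print("rejected:", exc)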
domain/ai/vector_providers/__init__.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from typing import Dict, List, Type

from .base import BaseVectorProvider
from .milvus_lite import MilvusLiteProvider
from .milvus_server import MilvusServerProvider
from .qdrant import QdrantProvider

_PROVIDER_REGISTRY: Dict[str, Dict[str, object]] = {
    MilvusLiteProvider.type: {
        "class": MilvusLiteProvider,
        "label": MilvusLiteProvider.label,
        "description": MilvusLiteProvider.description,
        "enabled": MilvusLiteProvider.enabled,
        "config_schema": MilvusLiteProvider.config_schema,
    },
    MilvusServerProvider.type: {
        "class": MilvusServerProvider,
        "label": MilvusServerProvider.label,
        "description": MilvusServerProvider.description,
        "enabled": MilvusServerProvider.enabled,
        "config_schema": MilvusServerProvider.config_schema,
    },
    QdrantProvider.type: {
        "class": QdrantProvider,
        "label": QdrantProvider.label,
        "description": QdrantProvider.description,
        "enabled": QdrantProvider.enabled,
        "config_schema": QdrantProvider.config_schema,
    },
}


def list_providers() -> List[Dict[str, object]]:
    return [
        {
            "type": type_key,
            "label": meta["label"],
            "description": meta.get("description"),
            "enabled": meta.get("enabled", True),
            "config_schema": meta.get("config_schema", []),
        }
        for type_key, meta in _PROVIDER_REGISTRY.items()
    ]


def get_provider_entry(provider_type: str) -> Dict[str, object] | None:
    return _PROVIDER_REGISTRY.get(provider_type)


def get_provider_class(provider_type: str) -> Type[BaseVectorProvider] | None:
    entry = get_provider_entry(provider_type)
    if not entry:
        return None
    return entry.get("class")  # type: ignore[return-value]


__all__ = [
    "BaseVectorProvider",
    "MilvusLiteProvider",
    "MilvusServerProvider",
    "QdrantProvider",
    "list_providers",
    "get_provider_entry",
    "get_provider_class",
]
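The registry keys providers by their class-level `type` strings, so lookups and UI listings stay in sync with the classes themselves. A lookup sketch, assuming the package path shown in this diff:

# Registry lookup sketch; package path assumed from the diff.
from domain.ai.vector_providers import get_provider_class, list_providers

for entry in list_providers():
    print(entry["type"], entry["label"], entry["enabled"])

provider_cls = get_provider_class("qdrant")
if provider_cls is not None:
    provider = provider_cls({"url": "http://localhost:6333"})  # not yet initialized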
domain/ai/vector_providers/base.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from typing import Any, Dict, List


class BaseVectorProvider:
    """Base class for vector database providers; all concrete implementations must inherit from it."""

    type: str = ""
    label: str = ""
    description: str | None = None
    enabled: bool = True
    config_schema: List[Dict[str, Any]] = []

    def __init__(self, config: Dict[str, Any] | None = None):
        self.config = config or {}

    async def initialize(self) -> None:
        """Run initialization logic, e.g. establishing connections."""
        raise NotImplementedError

    def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
        raise NotImplementedError

    def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        raise NotImplementedError

    def delete_vector(self, collection_name: str, path: str) -> None:
        raise NotImplementedError

    def search_vectors(self, collection_name: str, query_embedding, top_k: int):
        raise NotImplementedError

    def search_by_path(self, collection_name: str, query_path: str, top_k: int):
        raise NotImplementedError

    def get_all_stats(self) -> Dict[str, Any]:
        raise NotImplementedError

    def clear_all_data(self) -> None:
        raise NotImplementedError
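Note that `initialize` is the only async member; the data methods are synchronous and are called directly by `VectorDBService` after `await provider.initialize()`. A minimal, hypothetical in-memory subclass (for tests only, not part of the diff):

# Hypothetical in-memory provider sketch, illustrating the subclass contract.
from typing import Any, Dict

from domain.ai.vector_providers.base import BaseVectorProvider


class InMemoryProvider(BaseVectorProvider):
    type = "in_memory"
    label = "In-memory (testing)"

    async def initialize(self) -> None:
        # One dict per collection, keyed by the record's "path".
        self._collections: Dict[str, Dict[str, Dict[str, Any]]] = {}

    def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
        self._collections.setdefault(collection_name, {})

    def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        self._collections[collection_name][data["path"]] = data

    def delete_vector(self, collection_name: str, path: str) -> None:
        self._collections.get(collection_name, {}).pop(path, None)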
domain/ai/vector_providers/milvus_lite.py (new file, 276 lines)
@@ -0,0 +1,276 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient

from .base import BaseVectorProvider


class MilvusLiteProvider(BaseVectorProvider):
    type = "milvus_lite"
    label = "Milvus Lite"
    description = "Embedded Milvus Lite (local file storage)."
    enabled = True
    config_schema: List[Dict[str, Any]] = [
        {
            "key": "db_path",
            "label": "Database file path",
            "type": "text",
            "default": "data/db/milvus.db",
            "required": False,
        }
    ]

    def __init__(self, config: Dict[str, Any] | None = None):
        super().__init__(config)
        self.db_path = Path(self.config.get("db_path") or "data/db/milvus.db")
        self.client: MilvusClient | None = None

    async def initialize(self) -> None:
        try:
            self.client = MilvusClient(str(self.db_path))
        except Exception as exc:  # pragma: no cover - depends on local environment
            raise RuntimeError(f"Failed to open Milvus Lite at {self.db_path}: {exc}") from exc

    def _get_client(self) -> MilvusClient:
        if not self.client:
            raise RuntimeError("Milvus Lite client is not initialized")
        return self.client

    @staticmethod
    def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
        hit_id = getattr(hit, "id", None)
        distance = getattr(hit, "distance", None)
        payload: Dict[str, Any] = {}

        raw: Dict[str, Any] | None = None
        if hasattr(hit, "entity"):
            raw_entity = getattr(hit, "entity")
            if hasattr(raw_entity, "to_dict"):
                raw = dict(raw_entity.to_dict())
            else:
                raw = dict(raw_entity)
        elif isinstance(hit, dict):
            raw = dict(hit)

        if raw:
            hit_id = hit_id or raw.get("id")
            distance = distance if distance is not None else raw.get("distance")
            inner = raw.get("entity")
            if isinstance(inner, dict):
                payload = dict(inner)
            else:
                payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}

        payload.setdefault("path", payload.get("source_path"))
        payload.setdefault("source_path", payload.get("path"))
        return hit_id, distance, payload

    @staticmethod
    def _to_int(value: Any) -> int:
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
        client = self._get_client()
        if client.has_collection(collection_name):
            return
        common_fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
            FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
        ]

        if vector:
            vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
            if vector_dim <= 0:
                vector_dim = 4096
            fields = [
                *common_fields,
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
            ]
            schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
            client.create_collection(collection_name, schema=schema)
            index_params = MilvusClient.prepare_index_params()
            index_params.add_index(
                field_name="embedding",
                index_type="IVF_FLAT",
                index_name="vector_index",
                metric_type="COSINE",
                params={"nlist": 64},
            )
            client.create_index(collection_name, index_params=index_params)
        else:
            schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
            client.create_collection(collection_name, schema=schema)

    def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        payload = dict(data)
        payload.setdefault("source_path", payload.get("path"))
        payload.setdefault("vector_id", payload.get("path"))
        self._get_client().upsert(collection_name, data=[payload])

    def delete_vector(self, collection_name: str, path: str) -> None:
        client = self._get_client()
        escaped = path.replace('"', '\\"')
        client.delete(collection_name, filter=f'source_path == "{escaped}"')

    def search_vectors(self, collection_name: str, query_embedding, top_k: int):
        search_params = {"metric_type": "COSINE"}
        output_fields = [
            "path",
            "source_path",
            "chunk_id",
            "mime",
            "text",
            "start_offset",
            "end_offset",
            "type",
            "name",
        ]
        raw_results = self._get_client().search(
            collection_name,
            data=[query_embedding],
            anns_field="embedding",
            search_params=search_params,
            limit=top_k,
            output_fields=output_fields,
        )
        formatted: List[List[Dict[str, Any]]] = []
        for hits in raw_results:
            bucket: List[Dict[str, Any]] = []
            for hit in hits:
                hit_id, distance, entity = self._extract_hit_payload(hit)
                bucket.append({
                    "id": hit_id,
                    "distance": distance,
                    "entity": entity,
                })
            formatted.append(bucket)
        return formatted

    def search_by_path(self, collection_name: str, query_path: str, top_k: int):
        if query_path:
            escaped = query_path.replace('"', '\\"')
            filter_expr = f'source_path like "%{escaped}%"'
        else:
            filter_expr = "source_path like '%%'"
        results = self._get_client().query(
            collection_name,
            filter=filter_expr,
            limit=top_k,
            output_fields=[
                "path",
                "source_path",
                "chunk_id",
                "mime",
                "text",
                "start_offset",
                "end_offset",
                "type",
                "name",
            ],
        )
        formatted = []
        for row in results:
            entity = dict(row)
            entity.setdefault("path", entity.get("source_path"))
            formatted.append({
                "id": entity.get("path"),
                "distance": 1.0,
                "entity": entity,
            })
        return [formatted]

    def get_all_stats(self) -> Dict[str, Any]:
        client = self._get_client()
        try:
            collection_names = client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"Failed to list collections: {exc}") from exc

        collections: List[Dict[str, Any]] = []
        total_vectors = 0
        total_estimated_memory = 0

        for name in collection_names:
            try:
                stats = client.get_collection_stats(name) or {}
            except Exception:
                stats = {}
            row_count = self._to_int(stats.get("row_count"))
            total_vectors += row_count

            dimension: Optional[int] = None
            is_vector_collection = False
            try:
                description = client.describe_collection(name)
            except Exception:
                description = None

            if description:
                for field in description.get("fields", []):
                    if field.get("type") == DataType.FLOAT_VECTOR:
                        params = field.get("params") or {}
                        dimension = self._to_int(params.get("dim")) or 4096
                        is_vector_collection = True
                        break

            estimated_memory = 0
            if is_vector_collection and dimension:
                estimated_memory = row_count * dimension * 4
                total_estimated_memory += estimated_memory

            indexes: List[Dict[str, Any]] = []
            try:
                index_names = client.list_indexes(name) or []
            except Exception:
                index_names = []

            for index_name in index_names:
                try:
                    # MilvusClient.describe_index requires the index name as well.
                    detail = client.describe_index(name, index_name) or {}
                except Exception:
                    detail = {}
                indexes.append(
                    {
                        "index_name": index_name,
                        "index_type": detail.get("index_type"),
                        "metric_type": detail.get("metric_type"),
                        "indexed_rows": self._to_int(detail.get("indexed_rows")),
                        "pending_index_rows": self._to_int(detail.get("pending_index_rows")),
                        "state": detail.get("state"),
                    }
                )

            collections.append(
                {
                    "name": name,
                    "row_count": row_count,
                    "dimension": dimension if is_vector_collection else None,
                    "estimated_memory_bytes": estimated_memory,
                    "is_vector_collection": is_vector_collection,
                    "indexes": indexes,
                }
            )

        db_file_size = None
        try:
            if self.db_path.exists():
                db_file_size = self.db_path.stat().st_size
        except OSError:
            db_file_size = None

        return {
            "collections": collections,
            "collection_count": len(collections),
            "total_vectors": total_vectors,
            "estimated_total_memory_bytes": total_estimated_memory,
            "db_file_size_bytes": db_file_size,
        }

    def clear_all_data(self) -> None:
        client = self._get_client()
        for collection_name in client.list_collections():
            client.drop_collection(collection_name)
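The memory figure in `get_all_stats` assumes 4 bytes (float32) per vector component, i.e. row_count x dimension x 4. A worked example of that arithmetic:

# Worked example of the estimate computed in get_all_stats (float32 = 4 bytes).
row_count, dimension = 10_000, 1_536
print(row_count * dimension * 4)  # 61440000 bytes, roughly 61.4 MB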
domain/ai/vector_providers/milvus_server.py (new file, 276 lines)
@@ -0,0 +1,276 @@
from typing import Any, Dict, List, Optional

from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient

from .base import BaseVectorProvider


class MilvusServerProvider(BaseVectorProvider):
    type = "milvus_server"
    label = "Milvus Server"
    description = "Remote Milvus instance accessed via URI."
    enabled = True
    config_schema: List[Dict[str, Any]] = [
        {
            "key": "uri",
            "label": "Server URI",
            "type": "text",
            "required": True,
            "placeholder": "http://localhost:19530",
        },
        {
            "key": "token",
            "label": "Token",
            "type": "password",
            "required": False,
            "placeholder": "user:password",
        },
    ]

    def __init__(self, config: Dict[str, Any] | None = None):
        super().__init__(config)
        self.client: MilvusClient | None = None

    async def initialize(self) -> None:
        uri = self.config.get("uri")
        if not uri:
            raise RuntimeError("Milvus Server URI is required")
        try:
            self.client = MilvusClient(uri=uri, token=self.config.get("token"))
        except Exception as exc:  # pragma: no cover - depends on remote availability
            raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc

    def _get_client(self) -> MilvusClient:
        if not self.client:
            raise RuntimeError("Milvus Server client is not initialized")
        return self.client

    @staticmethod
    def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
        hit_id = getattr(hit, "id", None)
        distance = getattr(hit, "distance", None)
        payload: Dict[str, Any] = {}

        raw: Dict[str, Any] | None = None
        if hasattr(hit, "entity"):
            raw_entity = getattr(hit, "entity")
            if hasattr(raw_entity, "to_dict"):
                raw = dict(raw_entity.to_dict())
            else:
                raw = dict(raw_entity)
        elif isinstance(hit, dict):
            raw = dict(hit)

        if raw:
            hit_id = hit_id or raw.get("id")
            distance = distance if distance is not None else raw.get("distance")
            inner = raw.get("entity")
            if isinstance(inner, dict):
                payload = dict(inner)
            else:
                payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}

        payload.setdefault("path", payload.get("source_path"))
        payload.setdefault("source_path", payload.get("path"))
        return hit_id, distance, payload

    @staticmethod
    def _to_int(value: Any) -> int:
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
        client = self._get_client()
        if client.has_collection(collection_name):
            return
        common_fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
            FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
        ]
        if vector:
            vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
            if vector_dim <= 0:
                vector_dim = 4096
            fields = [
                *common_fields,
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
            ]
            schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
            client.create_collection(collection_name, schema=schema)
            index_params = MilvusClient.prepare_index_params()
            index_params.add_index(
                field_name="embedding",
                index_type="IVF_FLAT",
                index_name="vector_index",
                metric_type="COSINE",
                params={"nlist": 64},
            )
            client.create_index(collection_name, index_params=index_params)
        else:
            schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
            client.create_collection(collection_name, schema=schema)

    def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        payload = dict(data)
        payload.setdefault("source_path", payload.get("path"))
        payload.setdefault("vector_id", payload.get("path"))
        self._get_client().upsert(collection_name, data=[payload])

    def delete_vector(self, collection_name: str, path: str) -> None:
        client = self._get_client()
        escaped = path.replace('"', '\\"')
        client.delete(collection_name, filter=f'source_path == "{escaped}"')

    def search_vectors(self, collection_name: str, query_embedding, top_k: int):
        search_params = {"metric_type": "COSINE"}
        output_fields = [
            "path",
            "source_path",
            "chunk_id",
            "mime",
            "text",
            "start_offset",
            "end_offset",
            "type",
            "name",
        ]
        raw_results = self._get_client().search(
            collection_name,
            data=[query_embedding],
            anns_field="embedding",
            search_params=search_params,
            limit=top_k,
            output_fields=output_fields,
        )
        formatted: List[List[Dict[str, Any]]] = []
        for hits in raw_results:
            bucket: List[Dict[str, Any]] = []
            for hit in hits:
                hit_id, distance, entity = self._extract_hit_payload(hit)
                bucket.append({
                    "id": hit_id,
                    "distance": distance,
                    "entity": entity,
                })
            formatted.append(bucket)
        return formatted

    def search_by_path(self, collection_name: str, query_path: str, top_k: int):
        if query_path:
            escaped = query_path.replace('"', '\\"')
            filter_expr = f'source_path like "%{escaped}%"'
        else:
            filter_expr = "source_path like '%%'"
        results = self._get_client().query(
            collection_name,
            filter=filter_expr,
            limit=top_k,
            output_fields=[
                "path",
                "source_path",
                "chunk_id",
                "mime",
                "text",
                "start_offset",
                "end_offset",
                "type",
                "name",
            ],
        )
        formatted = []
        for row in results:
            entity = dict(row)
            entity.setdefault("path", entity.get("source_path"))
            formatted.append({
                "id": entity.get("path"),
                "distance": 1.0,
                "entity": entity,
            })
        return [formatted]

    def get_all_stats(self) -> Dict[str, Any]:
        client = self._get_client()
        try:
            collection_names = client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"Failed to list collections: {exc}") from exc

        collections: List[Dict[str, Any]] = []
        total_vectors = 0
        total_estimated_memory = 0

        for name in collection_names:
            try:
                stats = client.get_collection_stats(name) or {}
            except Exception:
                stats = {}
            row_count = self._to_int(stats.get("row_count"))
            total_vectors += row_count

            dimension: Optional[int] = None
            is_vector_collection = False
            try:
                description = client.describe_collection(name)
            except Exception:
                description = None

            if description:
                for field in description.get("fields", []):
                    if field.get("type") == DataType.FLOAT_VECTOR:
                        params = field.get("params") or {}
                        dimension = self._to_int(params.get("dim")) or 4096
                        is_vector_collection = True
                        break

            estimated_memory = 0
            if is_vector_collection and dimension:
                estimated_memory = row_count * dimension * 4
                total_estimated_memory += estimated_memory

            indexes: List[Dict[str, Any]] = []
            try:
                index_names = client.list_indexes(name) or []
            except Exception:
                index_names = []

            for index_name in index_names:
                try:
                    # MilvusClient.describe_index requires the index name as well.
                    detail = client.describe_index(name, index_name) or {}
                except Exception:
                    detail = {}
                indexes.append(
                    {
                        "index_name": index_name,
                        "index_type": detail.get("index_type"),
                        "metric_type": detail.get("metric_type"),
                        "indexed_rows": self._to_int(detail.get("indexed_rows")),
                        "pending_index_rows": self._to_int(detail.get("pending_index_rows")),
                        "state": detail.get("state"),
                    }
                )

            collections.append(
                {
                    "name": name,
                    "row_count": row_count,
                    "dimension": dimension if is_vector_collection else None,
                    "estimated_memory_bytes": estimated_memory,
                    "is_vector_collection": is_vector_collection,
                    "indexes": indexes,
                }
            )

        return {
            "collections": collections,
            "collection_count": len(collections),
            "total_vectors": total_vectors,
            "estimated_total_memory_bytes": total_estimated_memory,
            "db_file_size_bytes": None,
        }

    def clear_all_data(self) -> None:
        client = self._get_client()
        for collection_name in client.list_collections():
            client.drop_collection(collection_name)
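A config dict matching this provider's `config_schema` looks like the sketch below (values are placeholders). Note that `initialize()` is async and must be awaited inside an event loop before any data method is used:

# Placeholder config sketch; import path assumed from the diff layout.
from domain.ai.vector_providers.milvus_server import MilvusServerProvider

provider = MilvusServerProvider({
    "uri": "http://localhost:19530",
    "token": "user:password",  # optional; omit for unauthenticated servers
})
# await provider.initialize() must run before ensure_collection/upsert/search.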
domain/ai/vector_providers/qdrant.py (new file, 271 lines)
@@ -0,0 +1,271 @@
from typing import Any, Dict, List, Optional, Sequence
from uuid import NAMESPACE_URL, uuid5

from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels

from .base import BaseVectorProvider


class QdrantProvider(BaseVectorProvider):
    type = "qdrant"
    label = "Qdrant"
    description = "Qdrant vector database (HTTP API)."
    enabled = True
    config_schema: List[Dict[str, Any]] = [
        {
            "key": "url",
            "label": "Server URL",
            "type": "text",
            "required": True,
            "placeholder": "http://localhost:6333",
        },
        {
            "key": "api_key",
            "label": "API Key",
            "type": "password",
            "required": False,
        },
    ]

    def __init__(self, config: Dict[str, Any] | None = None):
        super().__init__(config)
        self.client: Optional[QdrantClient] = None

    async def initialize(self) -> None:
        url = (self.config.get("url") or "").strip()
        if not url:
            raise RuntimeError("Qdrant URL is required")

        api_key = self.config.get("api_key") or None
        try:
            client = QdrantClient(url=url, api_key=api_key)
            client.get_collections()
            self.client = client
        except Exception as exc:  # pragma: no cover - depends on external service
            raise RuntimeError(f"Failed to connect to Qdrant at {url}: {exc}") from exc

    def _get_client(self) -> QdrantClient:
        if not self.client:
            raise RuntimeError("Qdrant client is not initialized")
        return self.client

    @staticmethod
    def _vector_params(vector: bool, dim: int) -> qmodels.VectorParams:
        size = dim if vector and isinstance(dim, int) and dim > 0 else 1
        return qmodels.VectorParams(size=size, distance=qmodels.Distance.COSINE)

    def _ensure_payload_indexes(self, client: QdrantClient, collection_name: str) -> None:
        for field in ("path", "source_path"):
            try:
                client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field,
                    field_schema="keyword",
                )
            except Exception as exc:  # pragma: no cover - depends on external service
                message = str(exc).lower()
                if "already exists" in message or "index exists" in message:
                    continue
                raise

    def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
        client = self._get_client()
        try:
            exists = client.collection_exists(collection_name)
        except Exception as exc:  # pragma: no cover - depends on external service
            raise RuntimeError(f"Failed to check Qdrant collection '{collection_name}': {exc}") from exc

        if exists:
            try:
                self._ensure_payload_indexes(client, collection_name)
            except Exception:
                pass
            return

        vectors_config = self._vector_params(vector, dim)
        try:
            client.create_collection(collection_name=collection_name, vectors_config=vectors_config)
        except Exception as exc:  # pragma: no cover
            if "already exists" in str(exc).lower():
                try:
                    self._ensure_payload_indexes(client, collection_name)
                except Exception:
                    pass
                return
            raise RuntimeError(f"Failed to create Qdrant collection '{collection_name}': {exc}") from exc

        try:
            self._ensure_payload_indexes(client, collection_name)
        except Exception:
            pass

    @staticmethod
    def _point_id(uid: str) -> str:
        return str(uuid5(NAMESPACE_URL, uid))

    def _prepare_point(self, data: Dict[str, Any]) -> qmodels.PointStruct:
        uid = data.get("path")
        if not uid:
            raise ValueError("Qdrant upsert requires 'path' in data")

        embedding = data.get("embedding")
        if embedding is None:
            vector = [0.0]
        else:
            vector = [float(x) for x in embedding]

        payload = {k: v for k, v in data.items() if k != "embedding"}
        payload.setdefault("vector_id", uid)
        source_path = payload.get("source_path") or payload.get("path")
        payload["path"] = source_path
        return qmodels.PointStruct(id=self._point_id(str(uid)), vector=vector, payload=payload)

    def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        client = self._get_client()
        point = self._prepare_point(data)
        client.upsert(collection_name=collection_name, wait=True, points=[point])

    def delete_vector(self, collection_name: str, path: str) -> None:
        client = self._get_client()
        condition = qmodels.FieldCondition(
            key="path",
            match=qmodels.MatchValue(value=path),
        )
        flt = qmodels.Filter(must=[condition])
        selector = qmodels.FilterSelector(filter=flt)
        client.delete(collection_name=collection_name, points_selector=selector, wait=True)

    def _format_search_results(self, points: Sequence[qmodels.ScoredPoint]):
        return [
            {
                "id": point.id,
                "distance": point.score,
                "entity": point.payload or {},
            }
            for point in points
        ]

    def search_vectors(self, collection_name: str, query_embedding, top_k: int):
        client = self._get_client()
        vector = [float(x) for x in query_embedding]
        points = client.search(
            collection_name=collection_name,
            query_vector=vector,
            limit=top_k,
            with_payload=True,
        )
        return [self._format_search_results(points)]

    def search_by_path(self, collection_name: str, query_path: str, top_k: int):
        client = self._get_client()
        results: List[Dict[str, Any]] = []
        offset: Optional[str | int] = None
        remaining = max(top_k, 1)

        while len(results) < top_k:
            batch_size = min(max(remaining * 2, 10), 200)
            records, next_offset = client.scroll(
                collection_name=collection_name,
                limit=batch_size,
                offset=offset,
                with_payload=True,
            )
            if not records:
                break

            for record in records:
                payload = record.payload or {}
                path = payload.get("path")
                if query_path and path and query_path not in path:
                    continue
                results.append({"id": record.id, "distance": 1.0, "entity": payload})
                if len(results) >= top_k:
                    break

            if next_offset is None or len(results) >= top_k:
                break
            offset = next_offset
            remaining = top_k - len(results)

        return [results]

    def _extract_vector_config(self, vectors) -> Optional[qmodels.VectorParams]:
        if isinstance(vectors, qmodels.VectorParams):
            return vectors
        if isinstance(vectors, dict):
            for value in vectors.values():
                if isinstance(value, qmodels.VectorParams):
                    return value
        return None

    def get_all_stats(self) -> Dict[str, Any]:
        client = self._get_client()
        try:
            response = client.get_collections()
        except Exception as exc:  # pragma: no cover
            raise RuntimeError(f"Failed to list Qdrant collections: {exc}") from exc

        collections: List[Dict[str, Any]] = []
        total_vectors = 0
        total_estimated_memory = 0

        for description in response.collections or []:
            name = description.name
            try:
                info = client.get_collection(name)
            except Exception:
                continue

            row_count = int(info.points_count or 0)
            total_vectors += row_count

            vector_params = self._extract_vector_config(info.config.params.vectors if info.config and info.config.params else None)
            dimension = int(vector_params.size) if vector_params and vector_params.size else None
            estimated_memory = row_count * dimension * 4 if dimension else 0
            total_estimated_memory += estimated_memory
            distance = str(vector_params.distance) if vector_params and vector_params.distance else None

            indexed_rows = int(info.indexed_vectors_count or 0)
            pending_rows = max(row_count - indexed_rows, 0)

            collections.append(
                {
                    "name": name,
                    "row_count": row_count,
                    "dimension": dimension,
                    "estimated_memory_bytes": estimated_memory,
                    "is_vector_collection": dimension is not None and dimension > 1,
                    "indexes": [
                        {
                            "index_name": "hnsw",
                            "index_type": "HNSW",
                            "metric_type": distance,
                            "indexed_rows": indexed_rows,
                            "pending_index_rows": pending_rows,
                            "state": info.status,
                        }
                    ],
                }
            )

        return {
            "collections": collections,
            "collection_count": len(collections),
            "total_vectors": total_vectors,
            "estimated_total_memory_bytes": total_estimated_memory,
            "db_file_size_bytes": None,
        }

    def clear_all_data(self) -> None:
        client = self._get_client()
        try:
            response = client.get_collections()
        except Exception as exc:  # pragma: no cover
            raise RuntimeError(f"Failed to list Qdrant collections: {exc}") from exc

        for description in response.collections or []:
            try:
                client.delete_collection(description.name)
            except Exception:
                continue
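Qdrant point IDs must be UUIDs or unsigned integers, so `_point_id` maps each path to a stable UUIDv5; upserting the same path twice therefore updates one point rather than creating a duplicate. The determinism is easy to check with the standard library alone:

# Demonstrates the deterministic ID scheme used by _point_id.
from uuid import NAMESPACE_URL, uuid5

a = str(uuid5(NAMESPACE_URL, "/docs/a.txt"))
b = str(uuid5(NAMESPACE_URL, "/docs/a.txt"))
print(a == b)  # True: the same path always maps to the same point ID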
domain/audit/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from domain.audit.decorator import audit
from domain.audit.types import AuditAction
from domain.audit.api import router

__all__ = ["audit", "AuditAction", "router"]
domain/audit/api.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from datetime import datetime, timezone
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, HTTPException, Query

from api import response
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import get_current_active_user
from domain.auth.types import User

CurrentUser = Annotated[User, Depends(get_current_active_user)]

router = APIRouter(prefix="/api/audit", tags=["Audit"])


def _parse_iso(value: Optional[str], field: str):
    if not value:
        return None
    try:
        normalized = value.replace("Z", "+00:00")
        dt = datetime.fromisoformat(normalized)
        if dt.tzinfo:
            dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
        return dt
    except ValueError as exc:  # noqa: BLE001
        raise HTTPException(status_code=400, detail=f"invalid {field}") from exc


@router.get("/logs")
async def list_audit_logs(
    current_user: CurrentUser,
    page_num: int = Query(1, ge=1, alias="page", description="Page number"),
    page_size: int = Query(20, ge=1, le=200, description="Items per page"),
    action: AuditAction | None = Query(None, description="Action type"),
    success: bool | None = Query(None, description="Filter by success"),
    username: str | None = Query(None, description="Fuzzy match on username"),
    path: str | None = Query(None, description="Fuzzy match on path"),
    start_time: str | None = Query(None, description="Start time (ISO 8601)"),
    end_time: str | None = Query(None, description="End time (ISO 8601)"),
):
    start_dt = _parse_iso(start_time, "start_time")
    end_dt = _parse_iso(end_time, "end_time")
    items, total = await AuditService.list_logs(
        page=page_num,
        page_size=page_size,
        action=str(action) if action else None,
        success=success,
        username=username,
        path=path,
        start_time=start_dt,
        end_time=end_dt,
    )
    return response.success(response.page(items, total, page_num, page_size))


@router.delete("/logs")
async def clear_audit_logs(
    current_user: CurrentUser,
    start_time: str | None = Query(None, description="Start time (ISO 8601)"),
    end_time: str | None = Query(None, description="End time (ISO 8601)"),
):
    start_dt = _parse_iso(start_time, "start_time")
    end_dt = _parse_iso(end_time, "end_time")
    if start_dt is None and end_dt is None:
        raise HTTPException(status_code=400, detail="Provide at least one of start_time or end_time")
    deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
    return response.success({"deleted_count": deleted_count})
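`_parse_iso` accepts a trailing Z, then strips the timezone after converting to UTC so that timestamps compare cleanly against naive UTC values in the database. A standalone sketch of the same normalization:

# Standalone sketch mirroring _parse_iso's normalization.
from datetime import datetime, timezone

raw = "2024-05-01T12:00:00Z"
dt = datetime.fromisoformat(raw.replace("Z", "+00:00"))
if dt.tzinfo:
    dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
print(dt)  # 2024-05-01 12:00:00 (naive UTC)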
domain/audit/decorator.py (new file, 199 lines)
@@ -0,0 +1,199 @@
import inspect
import time
from functools import wraps
from typing import Any, Dict, Mapping, Optional

import jwt
from fastapi import Request

from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import ALGORITHM
from domain.config.service import ConfigService
from models.database import UserAccount


def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
    for value in bound_args.values():
        if isinstance(value, Request):
            return value
    return None


async def _resolve_user(request: Request | None, user_obj: Any | None) -> tuple[Optional[int], Optional[str]]:
    user_id: int | None = None
    username: str | None = None

    if request:
        auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1]
            try:
                payload = jwt.decode(token, await ConfigService.get_secret_key("SECRET_KEY"), algorithms=[ALGORITHM])
                username = payload.get("sub") or payload.get("username")
                if username:
                    user = await UserAccount.get_or_none(username=username)
                    user_id = user.id if user else None
            except Exception:  # covers InvalidTokenError and any lookup failure
                pass

    if user_id is None and username is None and user_obj is not None:
        user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
        username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
        if isinstance(user_obj, dict):
            user_id = user_obj.get("id", user_obj.get("user_id", user_id))
            username = user_obj.get("username", user_obj.get("name", username))

    return user_id, username


def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
    if not body_fields:
        return None
    body: Dict[str, Any] = {}
    redacts = set(redact_fields or [])
    for value in bound_args.values():
        data: Optional[Dict[str, Any]] = None
        if hasattr(value, "model_dump"):
            try:
                data = value.model_dump()
            except Exception:
                data = None
        elif hasattr(value, "dict"):
            try:
                data = value.dict()
            except Exception:
                data = None
        elif isinstance(value, dict):
            data = value
        elif hasattr(value, "__dict__"):
            data = dict(value.__dict__)
        if not isinstance(data, dict):
            continue
        for field in body_fields:
            if field in data and field not in body:
                body[field] = data[field]
    if not body:
        return None
    for field in redacts:
        if field in body:
            body[field] = "<redacted>"
    return body


def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
    if not request:
        return None
    params: Dict[str, Any] = {}
    query = dict(request.query_params)
    if query:
        params["query"] = query
    path_params = dict(request.path_params or {})
    if path_params:
        params["path"] = path_params
    return params or None


def _get_client_ip(request: Request | None) -> str | None:
    if not request:
        return None
    x_real_ip = request.headers.get("x-real-ip") or request.headers.get("X-Real-IP")
    if x_real_ip:
        ip = x_real_ip.strip()
        if ip:
            return ip
    x_forwarded_for = request.headers.get("x-forwarded-for") or request.headers.get("X-Forwarded-For")
    if x_forwarded_for:
        for part in x_forwarded_for.split(","):
            ip = part.strip()
            if ip and ip.lower() != "unknown":
                return ip
    return request.client.host if request.client else None


def _status_code_from_response(response: Any) -> int:
    if hasattr(response, "status_code"):
        try:
            return int(getattr(response, "status_code"))
        except Exception:
            pass
    return 200


def audit(
    *,
    action: AuditAction,
    description: str | None = None,
    body_fields: list[str] | None = None,
    redact_fields: list[str] | None = None,
    user_kw: str = "current_user",
):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            bound = inspect.signature(func).bind_partial(*args, **kwargs)
            bound.apply_defaults()
            request = _extract_request(bound.arguments)
            start = time.perf_counter()
            user_info = bound.arguments.get(user_kw)
            user_id, username = await _resolve_user(request, user_info)
            request_params = _build_request_params(request)
            request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)

            try:
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
                status_code = _status_code_from_response(result)
                success = True
                error = None
            except Exception as exc:  # noqa: BLE001
                status_code = getattr(exc, "status_code", 500)
                success = False
                error = str(exc)
                duration_ms = round((time.perf_counter() - start) * 1000, 2)
                try:
                    await AuditService.log(
                        action=action,
                        description=description,
                        user_id=user_id,
                        username=username,
                        client_ip=_get_client_ip(request),
                        method=request.method if request else "",
                        path=request.url.path if request else func.__name__,
                        status_code=status_code,
                        duration_ms=duration_ms,
                        success=success,
                        request_params=request_params,
                        request_body=request_body,
                        error=error,
                    )
                except Exception:
                    pass
                raise

            duration_ms = round((time.perf_counter() - start) * 1000, 2)
            try:
                await AuditService.log(
                    action=action,
                    description=description,
                    user_id=user_id,
                    username=username,
                    client_ip=_get_client_ip(request),
                    method=request.method if request else "",
                    path=request.url.path if request else func.__name__,
                    status_code=status_code,
                    duration_ms=duration_ms,
                    success=success,
                    request_params=request_params,
                    request_body=request_body,
                    error=error,
                )
            except Exception:
                pass
            return result

        return wrapper

    return decorator
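The decorator inspects the wrapped endpoint's bound arguments, so it works for any handler that takes a `Request` parameter and (optionally) a `current_user` keyword. A hypothetical endpoint showing the intended stacking order (route decorator outermost, audit inside, as in domain/auth/api.py below):

# Hypothetical endpoint; names are illustrative, not from the diff.
from fastapi import APIRouter, Request

from domain.audit import AuditAction, audit

router = APIRouter()


@router.post("/api/widgets")
@audit(
    action=AuditAction.CREATE,
    description="create widget",
    body_fields=["name", "token"],
    redact_fields=["token"],  # logged as "<redacted>"
)
async def create_widget(request: Request, payload: dict):
    return {"ok": True}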
domain/audit/service.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from typing import Any, Dict, Optional

from models.database import AuditLog

from domain.audit.types import AuditAction


class AuditService:
    @classmethod
    async def log(
        cls,
        *,
        action: AuditAction | str,
        description: Optional[str],
        user_id: Optional[int],
        username: Optional[str],
        client_ip: Optional[str],
        method: str,
        path: str,
        status_code: int,
        duration_ms: Optional[float],
        success: bool,
        request_params: Optional[Dict[str, Any]] = None,
        request_body: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        await AuditLog.create(
            action=str(action),
            description=description,
            user_id=user_id,
            username=username,
            client_ip=client_ip,
            method=method,
            path=path,
            status_code=status_code,
            duration_ms=duration_ms,
            success=success,
            request_params=request_params,
            request_body=request_body,
            error=error,
        )

    @classmethod
    def _serialize(cls, log: AuditLog) -> Dict[str, Any]:
        return {
            "id": log.id,
            "created_at": log.created_at.isoformat() if log.created_at else None,
            "action": log.action,
            "description": log.description,
            "user_id": log.user_id,
            "username": log.username,
            "client_ip": log.client_ip,
            "method": log.method,
            "path": log.path,
            "status_code": log.status_code,
            "duration_ms": log.duration_ms,
            "success": log.success,
            "request_params": log.request_params,
            "request_body": log.request_body,
            "error": log.error,
        }

    @classmethod
    def _apply_filters(
        cls,
        *,
        action: str | None = None,
        success: bool | None = None,
        username: str | None = None,
        path: str | None = None,
        start_time=None,
        end_time=None,
    ):
        qs = AuditLog.all()
        if action:
            qs = qs.filter(action=action)
        if success is not None:
            qs = qs.filter(success=success)
        if username:
            qs = qs.filter(username__icontains=username)
        if path:
            qs = qs.filter(path__icontains=path)
        if start_time:
            qs = qs.filter(created_at__gte=start_time)
        if end_time:
            qs = qs.filter(created_at__lte=end_time)
        return qs

    @classmethod
    async def list_logs(
        cls,
        *,
        page: int,
        page_size: int,
        action: str | None = None,
        success: bool | None = None,
        username: str | None = None,
        path: str | None = None,
        start_time=None,
        end_time=None,
    ):
        qs = cls._apply_filters(
            action=action,
            success=success,
            username=username,
            path=path,
            start_time=start_time,
            end_time=end_time,
        )
        total = await qs.count()
        offset = (page - 1) * page_size
        items = await qs.order_by("-created_at").offset(offset).limit(page_size)
        return [cls._serialize(log) for log in items], total

    @classmethod
    async def clear_logs(
        cls,
        *,
        start_time=None,
        end_time=None,
    ) -> int:
        qs = cls._apply_filters(start_time=start_time, end_time=end_time)
        deleted_count = await qs.delete()
        return deleted_count
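`list_logs` pages with a classic offset/limit scheme over the newest-first ordering, where `offset = (page - 1) * page_size`:

# Offset arithmetic used by list_logs.
page, page_size = 3, 20
print((page - 1) * page_size)  # 40: page 3 starts after the first 40 rows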
domain/audit/types.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from enum import StrEnum


class AuditAction(StrEnum):
    LOGIN = "login"
    LOGOUT = "logout"
    REGISTER = "register"
    READ = "read"
    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"
    RESET_PASSWORD = "reset_password"
    SHARE = "share"
    DOWNLOAD = "download"
    UPLOAD = "upload"
    OTHER = "other"
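Because `AuditAction` is a `StrEnum` (Python 3.11+), its members behave as plain strings, which is why `str(action)` in `AuditService.log` yields the bare value:

# StrEnum members compare and serialize as their string values.
from domain.audit.types import AuditAction

print(str(AuditAction.LOGIN))        # login
print(AuditAction.LOGIN == "login")  # True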
domain/auth/api.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordRequestForm

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import AuthService, get_current_active_user
from domain.auth.types import (
    PasswordResetConfirm,
    PasswordResetRequest,
    RegisterRequest,
    Token,
    UpdateMeRequest,
    User,
)

router = APIRouter(prefix="/api/auth", tags=["auth"])


@router.post("/register", summary="Register the first admin user")
@audit(
    action=AuditAction.REGISTER,
    description="Register admin",
    body_fields=["username", "email", "full_name"],
    redact_fields=["password"],
)
async def register(request: Request, data: RegisterRequest):
    user = await AuthService.register_user(data)
    return success({"username": user.username}, msg="Initial user registered successfully")


@router.post("/login")
@audit(action=AuditAction.LOGIN, description="User login", body_fields=["username"], redact_fields=["password"])
async def login_for_access_token(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
    return await AuthService.login(form_data)


@router.get("/me", summary="Get the current user's profile")
@audit(action=AuditAction.READ, description="Get current user info")
async def get_me(
    request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
):
    profile = AuthService.get_profile(current_user)
    return success(profile)


@router.put("/me", summary="Update the current user's profile")
@audit(
    action=AuditAction.UPDATE,
    description="Update current user info",
    body_fields=["email", "full_name"],
    redact_fields=["old_password", "new_password"],
)
async def update_me(
    request: Request,
    payload: UpdateMeRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    profile = await AuthService.update_me(payload, current_user)
    return success(profile)


@router.post("/password-reset/request", summary="Request a password reset email")
@audit(action=AuditAction.RESET_PASSWORD, description="Request password reset email", body_fields=["email"])
async def password_reset_request_endpoint(request: Request, payload: PasswordResetRequest):
    await AuthService.request_password_reset(payload)
    return success(msg="If the email exists, a reset email will be sent")


@router.get("/password-reset/verify", summary="Verify a password reset token")
@audit(action=AuditAction.RESET_PASSWORD, description="Verify password reset token", redact_fields=["token"])
async def password_reset_verify(request: Request, token: str):
    user = await AuthService.verify_password_reset_token(token)
    return success({"username": user.username, "email": user.email})


@router.post("/password-reset/confirm", summary="Reset password with a token")
@audit(
    action=AuditAction.RESET_PASSWORD,
    description="Reset password",
    body_fields=["token"],
    redact_fields=["token", "password"],
)
async def password_reset_confirm(request: Request, payload: PasswordResetConfirm):
    await AuthService.reset_password_with_token(payload)
    return success(msg="Password has been reset")
domain/auth/service.py (new file, 365 lines)
@@ -0,0 +1,365 @@
import asyncio
import hashlib
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Annotated

import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError

from domain.auth.types import (
    PasswordResetConfirm,
    PasswordResetRequest,
    RegisterRequest,
    Token,
    TokenData,
    UpdateMeRequest,
    User,
    UserInDB,
)
from models.database import UserAccount
from domain.config.service import ConfigService

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10


def _now() -> datetime:
    return datetime.now(timezone.utc)


@dataclass
class PasswordResetEntry:
    user_id: int
    email: str
    username: str
    expires_at: datetime
    used: bool = False


class PasswordResetStore:
    _tokens: dict[str, PasswordResetEntry] = {}
    _lock = asyncio.Lock()

    @classmethod
    def _cleanup(cls):
        now = _now()
        for token, record in list(cls._tokens.items()):
            if record.used or record.expires_at < now:
                cls._tokens.pop(token, None)

    @classmethod
    async def create(cls, user: UserAccount) -> str:
        async with cls._lock:
            cls._cleanup()
            for key, record in list(cls._tokens.items()):
                if record.user_id == user.id:
                    cls._tokens.pop(key, None)
            token = secrets.token_urlsafe(32)
            expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
            cls._tokens[token] = PasswordResetEntry(
                user_id=user.id,
                email=user.email or "",
                username=user.username,
                expires_at=expires_at,
            )
            return token

    @classmethod
    async def get(cls, token: str) -> PasswordResetEntry | None:
        async with cls._lock:
            cls._cleanup()
            record = cls._tokens.get(token)
            if not record or record.used:
                return None
            return record

    @classmethod
    async def mark_used(cls, token: str) -> None:
        async with cls._lock:
            record = cls._tokens.get(token)
            if record:
                record.used = True
            cls._cleanup()

    @classmethod
    async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
        async with cls._lock:
            for key, record in list(cls._tokens.items()):
                if record.user_id == user_id and key != except_token:
                    cls._tokens.pop(key, None)
            cls._cleanup()


class AuthService:
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
    algorithm = ALGORITHM
    access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
    password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES

    @staticmethod
    def _to_bytes(value: str) -> bytes:
        return value.encode("utf-8")

    @classmethod
    async def get_secret_key(cls) -> str:
        return await ConfigService.get_secret_key("SECRET_KEY", None)

    @classmethod
    def _normalize_email(cls, email: str | None) -> str:
        return (email or "").strip().lower()

    @classmethod
    def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
        try:
            return bcrypt.checkpw(cls._to_bytes(plain_password), hashed_password.encode("utf-8"))
        except (ValueError, TypeError):
            return False

    @classmethod
    def get_password_hash(cls, password: str) -> str:
        encoded = cls._to_bytes(password)
        if len(encoded) > 72:
            raise HTTPException(status_code=400, detail="Password is too long")
        return bcrypt.hashpw(encoded, bcrypt.gensalt()).decode("utf-8")

    @classmethod
    async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
        user = await UserAccount.get_or_none(username=username_or_email)
        if not user:
            user = await UserAccount.get_or_none(email=username_or_email)
        if user:
            return UserInDB(
                id=user.id,
                username=user.username,
                email=user.email,
                full_name=user.full_name,
                disabled=user.disabled,
                hashed_password=user.hashed_password,
            )
        return None

    @classmethod
    async def authenticate_user_db(cls, username_or_email: str, password: str) -> UserInDB | None:
        user = await cls.get_user_db(username_or_email)
        if not user:
            return None
        if not cls.verify_password(password, user.hashed_password):
            return None
        return user

    @classmethod
    async def has_users(cls) -> bool:
        user_count = await UserAccount.all().count()
        return user_count > 0

    @classmethod
    async def register_user(cls, payload: RegisterRequest):
        if await cls.has_users():
            raise HTTPException(status_code=403, detail="System already initialized; new user registration is not allowed")
        exists = await UserAccount.get_or_none(username=payload.username)
        if exists:
            raise HTTPException(status_code=400, detail="Username already exists")
        hashed = cls.get_password_hash(payload.password)
        user = await UserAccount.create(
            username=payload.username,
            email=payload.email,
            full_name=payload.full_name,
            hashed_password=hashed,
            disabled=False,
        )
        return user

    @classmethod
    async def create_access_token(cls, data: dict, expires_delta: timedelta | None = None):
        to_encode = data.copy()
        if "sub" not in to_encode and "username" in to_encode:
            to_encode["sub"] = to_encode["username"]
        expire = _now() + (expires_delta or timedelta(minutes=15))
        to_encode.update({"exp": expire})
        secret_key = await cls.get_secret_key()
        encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=cls.algorithm)
        return encoded_jwt
|
||||
|
||||
@classmethod
|
||||
async def login(cls, form: OAuth2PasswordRequestForm) -> Token:
|
||||
user = await cls.authenticate_user_db(form.username, form.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
|
||||
access_token = await cls.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@classmethod
|
||||
def _build_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
email = cls._normalize_email(getattr(user, "email", None))
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": getattr(user, "email", None),
|
||||
"full_name": getattr(user, "full_name", None),
|
||||
"gravatar_url": gravatar_url,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
return cls._build_profile(user)
|
||||
|
||||
@classmethod
|
||||
async def update_me(cls, payload: UpdateMeRequest, current_user: User) -> dict:
|
||||
db_user = await UserAccount.get_or_none(id=current_user.id)
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if payload.email is not None:
|
||||
exists = (
|
||||
await UserAccount.filter(email=payload.email)
|
||||
.exclude(id=db_user.id)
|
||||
.exists()
|
||||
)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被占用")
|
||||
db_user.email = payload.email
|
||||
|
||||
if payload.full_name is not None:
|
||||
db_user.full_name = payload.full_name
|
||||
|
||||
if payload.new_password:
|
||||
if not payload.old_password:
|
||||
raise HTTPException(status_code=400, detail="请提供原密码")
|
||||
if not cls.verify_password(payload.old_password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="原密码错误")
|
||||
db_user.hashed_password = cls.get_password_hash(payload.new_password)
|
||||
|
||||
await db_user.save()
|
||||
return cls._build_profile(db_user)
|
||||
|
||||
@classmethod
|
||||
async def request_password_reset(cls, payload: PasswordResetRequest) -> bool:
|
||||
normalized = cls._normalize_email(payload.email)
|
||||
if not normalized:
|
||||
return False
|
||||
user = await UserAccount.get_or_none(email=normalized)
|
||||
if not user or not user.email:
|
||||
return False
|
||||
|
||||
token = await PasswordResetStore.create(user)
|
||||
try:
|
||||
await cls._send_password_reset_email(user, token)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await PasswordResetStore.mark_used(token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
raise HTTPException(status_code=500, detail="邮件发送失败") from exc
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def verify_password_reset_token(cls, token: str) -> UserAccount:
|
||||
record = await PasswordResetStore.get(token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def reset_password_with_token(cls, payload: PasswordResetConfirm) -> None:
|
||||
record = await PasswordResetStore.get(payload.token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user.hashed_password = cls.get_password_hash(payload.password)
|
||||
await user.save(update_fields=["hashed_password"])
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
|
||||
@classmethod
|
||||
async def get_current_user(cls, token: str):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
secret_key = await cls.get_secret_key()
|
||||
payload = jwt.decode(token, secret_key, algorithms=[cls.algorithm])
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
user = await cls.get_user_db(token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def get_current_active_user(cls, current_user: User):
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
@classmethod
|
||||
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
|
||||
from domain.email.service import EmailService
|
||||
|
||||
app_domain = await ConfigService.get("APP_DOMAIN", None)
|
||||
base_url = (app_domain or "http://localhost:5173").rstrip("/")
|
||||
reset_link = f"{base_url}/reset-password?token={token}"
|
||||
await EmailService.enqueue_email(
|
||||
recipients=[user.email],
|
||||
subject="Foxel 密码重置",
|
||||
template="password_reset",
|
||||
context={
|
||||
"username": user.username,
|
||||
"reset_link": reset_link,
|
||||
"expire_minutes": cls.password_reset_token_expire_minutes,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _current_user_dep(token: Annotated[str, Depends(AuthService.oauth2_scheme)]):
|
||||
return await AuthService.get_current_user(token)
|
||||
|
||||
|
||||
async def _current_active_user_dep(
|
||||
current_user: Annotated[User, Depends(_current_user_dep)],
|
||||
):
|
||||
return await AuthService.get_current_active_user(current_user)
|
||||
|
||||
|
||||
# 方便依赖注入与外部使用
|
||||
get_current_user = _current_user_dep
|
||||
get_current_active_user = _current_active_user_dep
|
||||
authenticate_user_db = AuthService.authenticate_user_db
|
||||
create_access_token = AuthService.create_access_token
|
||||
register_user = AuthService.register_user
|
||||
request_password_reset = AuthService.request_password_reset
|
||||
verify_password_reset_token = AuthService.verify_password_reset_token
|
||||
reset_password_with_token = AuthService.reset_password_with_token
|
||||
has_users = AuthService.has_users
|
||||
verify_password = AuthService.verify_password
|
||||
get_password_hash = AuthService.get_password_hash
|
||||
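A minimal sketch of the in-memory token lifecycle implemented by PasswordResetStore: create invalidates older tokens for the same user, get misses once a token is marked used, and _cleanup evicts expired entries. It assumes `domain.auth.service` is importable and uses a SimpleNamespace stand-in for a UserAccount row (only id/email/username are read).

    import asyncio
    from types import SimpleNamespace

    from domain.auth.service import PasswordResetStore  # assumes the module above is on the path

    async def demo():
        # Stand-in for a persisted UserAccount row (hypothetical values).
        user = SimpleNamespace(id=1, email="alice@example.com", username="alice")
        token = await PasswordResetStore.create(user)   # issues a fresh token, dropping older ones
        assert await PasswordResetStore.get(token) is not None
        await PasswordResetStore.mark_used(token)       # single-use: a second lookup now misses
        assert await PasswordResetStore.get(token) is None

    asyncio.run(demo())

Being a class-level dict, the store is per-process and non-persistent; a restart or a second worker would not see outstanding tokens, which the short 10-minute expiry makes an acceptable trade-off.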
45
domain/auth/types.py
Normal file
@@ -0,0 +1,45 @@
from pydantic import BaseModel


class Token(BaseModel):
    access_token: str
    token_type: str


class TokenData(BaseModel):
    username: str | None = None


class User(BaseModel):
    id: int
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None


class UserInDB(User):
    hashed_password: str


class RegisterRequest(BaseModel):
    username: str
    password: str
    email: str | None = None
    full_name: str | None = None


class UpdateMeRequest(BaseModel):
    email: str | None = None
    full_name: str | None = None
    old_password: str | None = None
    new_password: str | None = None


class PasswordResetRequest(BaseModel):
    email: str


class PasswordResetConfirm(BaseModel):
    token: str
    password: str
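A small sketch of how these Pydantic models behave at the API boundary; all optional fields default to None, so a partial profile update is a valid UpdateMeRequest.

    from domain.auth.types import PasswordResetConfirm, UpdateMeRequest

    body = PasswordResetConfirm.model_validate({"token": "abc123", "password": "new-secret"})
    print(body.token)  # "abc123"

    # Omitted fields stay None and can be excluded when serializing.
    update = UpdateMeRequest(full_name="Alice")
    print(update.model_dump(exclude_none=True))  # {'full_name': 'Alice'}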
1
domain/backup/__init__.py
Normal file
@@ -0,0 +1 @@

30
domain/backup/api.py
Normal file
@@ -0,0 +1,30 @@
import datetime

from fastapi import APIRouter, Depends, File, Request, UploadFile
from fastapi.responses import JSONResponse

from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.backup.service import BackupService

router = APIRouter(
    prefix="/api/backup",
    tags=["Backup & Restore"],
    dependencies=[Depends(get_current_active_user)],
)


@router.get("/export", summary="导出全站数据")
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
async def export_backup(request: Request):
    data = await BackupService.export_data()
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
    return JSONResponse(content=data.model_dump(), headers=headers)


@router.post("/import", summary="导入数据")
@audit(action=AuditAction.UPLOAD, description="导入备份")
async def import_backup(request: Request, file: UploadFile = File(...)):
    await BackupService.import_from_bytes(file.filename, await file.read())
    return {"message": "数据导入成功。"}
203
domain/backup/service.py
Normal file
@@ -0,0 +1,203 @@
import json
from datetime import datetime

from fastapi import HTTPException
from tortoise.transactions import in_transaction

from domain.backup.types import BackupData
from domain.config.service import VERSION
from models.database import (
    AIDefaultModel,
    AIModel,
    AIProvider,
    AutomationTask,
    Configuration,
    Plugin,
    ShareLink,
    StorageAdapter,
    UserAccount,
)


class BackupService:
    @classmethod
    async def export_data(cls) -> BackupData:
        async with in_transaction():
            adapters = await StorageAdapter.all().values()
            users = await UserAccount.all().values()
            tasks = await AutomationTask.all().values()
            shares = await ShareLink.all().values()
            configs = await Configuration.all().values()
            providers = await AIProvider.all().values()
            models = await AIModel.all().values()
            default_models = await AIDefaultModel.all().values()
            plugins = await Plugin.all().values()

        share_links = cls._serialize_datetime_fields(
            shares, ["created_at", "expires_at"]
        )
        ai_providers = cls._serialize_datetime_fields(
            providers, ["created_at", "updated_at"]
        )
        ai_models = cls._serialize_datetime_fields(
            models, ["created_at", "updated_at"]
        )
        ai_default_models = cls._serialize_datetime_fields(
            default_models, ["created_at", "updated_at"]
        )
        plugin_items = cls._serialize_datetime_fields(
            plugins, ["created_at", "updated_at"]
        )

        return BackupData(
            version=VERSION,
            storage_adapters=list(adapters),
            user_accounts=list(users),
            automation_tasks=list(tasks),
            share_links=share_links,
            configurations=list(configs),
            ai_providers=ai_providers,
            ai_models=ai_models,
            ai_default_models=ai_default_models,
            plugins=plugin_items,
        )

    @classmethod
    async def import_from_bytes(cls, filename: str, content: bytes) -> None:
        if not filename.endswith(".json"):
            raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
        try:
            raw_data = json.loads(content)
        except Exception:
            raise HTTPException(status_code=400, detail="无法解析JSON文件")
        await cls.import_data(BackupData(**raw_data))

    @classmethod
    async def import_data(cls, payload: BackupData) -> None:
        async with in_transaction() as conn:
            await ShareLink.all().using_db(conn).delete()
            await AutomationTask.all().using_db(conn).delete()
            await StorageAdapter.all().using_db(conn).delete()
            await UserAccount.all().using_db(conn).delete()
            await Configuration.all().using_db(conn).delete()
            await AIDefaultModel.all().using_db(conn).delete()
            await AIModel.all().using_db(conn).delete()
            await AIProvider.all().using_db(conn).delete()
            await Plugin.all().using_db(conn).delete()

            if payload.configurations:
                await Configuration.bulk_create(
                    [Configuration(**config) for config in payload.configurations],
                    using_db=conn,
                )

            if payload.user_accounts:
                await UserAccount.bulk_create(
                    [UserAccount(**user) for user in payload.user_accounts],
                    using_db=conn,
                )

            if payload.storage_adapters:
                await StorageAdapter.bulk_create(
                    [StorageAdapter(**adapter) for adapter in payload.storage_adapters],
                    using_db=conn,
                )

            if payload.automation_tasks:
                await AutomationTask.bulk_create(
                    [AutomationTask(**task) for task in payload.automation_tasks],
                    using_db=conn,
                )

            if payload.share_links:
                await ShareLink.bulk_create(
                    [
                        ShareLink(**share)
                        for share in cls._parse_datetime_fields(
                            payload.share_links, ["created_at", "expires_at"]
                        )
                    ],
                    using_db=conn,
                )

            if payload.ai_providers:
                await AIProvider.bulk_create(
                    [
                        AIProvider(**item)
                        for item in cls._parse_datetime_fields(
                            payload.ai_providers, ["created_at", "updated_at"]
                        )
                    ],
                    using_db=conn,
                )

            if payload.ai_models:
                await AIModel.bulk_create(
                    [
                        AIModel(**item)
                        for item in cls._parse_datetime_fields(
                            payload.ai_models, ["created_at", "updated_at"]
                        )
                    ],
                    using_db=conn,
                )

            if payload.ai_default_models:
                await AIDefaultModel.bulk_create(
                    [
                        AIDefaultModel(**item)
                        for item in cls._parse_datetime_fields(
                            payload.ai_default_models, ["created_at", "updated_at"]
                        )
                    ],
                    using_db=conn,
                )

            if payload.plugins:
                await Plugin.bulk_create(
                    [
                        Plugin(**item)
                        for item in cls._parse_datetime_fields(
                            payload.plugins, ["created_at", "updated_at"]
                        )
                    ],
                    using_db=conn,
                )

    @staticmethod
    def _serialize_datetime_fields(
        records: list[dict], fields: list[str]
    ) -> list[dict]:
        serialized: list[dict] = []
        for record in records:
            item = dict(record)
            for field in fields:
                value = item.get(field)
                if isinstance(value, datetime):
                    item[field] = value.isoformat()
            serialized.append(item)
        return serialized

    @staticmethod
    def _parse_datetime_fields(
        records: list[dict], fields: list[str]
    ) -> list[dict]:
        parsed: list[dict] = []
        for record in records:
            item = dict(record)
            for field in fields:
                value = item.get(field)
                if isinstance(value, str):
                    item[field] = BackupService._from_iso(value)
            parsed.append(item)
        return parsed

    @staticmethod
    def _from_iso(value: str) -> datetime | None:
        if not value:
            return None
        normalized = value.replace("Z", "+00:00")
        try:
            return datetime.fromisoformat(normalized)
        except ValueError as exc:  # noqa: BLE001
            raise HTTPException(status_code=400, detail="无效的日期格式") from exc
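The Z-to-offset rewrite in _from_iso matters because `datetime.fromisoformat` (before Python 3.11) rejects a trailing "Z". A pure-stdlib sketch of the export/import round trip for a single timestamp:

    from datetime import datetime, timezone

    # Export side: _serialize_datetime_fields emits isoformat() strings.
    # Import side: _from_iso maps "Z" to "+00:00" before parsing.
    value = "2024-05-01T12:00:00Z"
    parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    assert parsed == datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)
    assert parsed.isoformat() == "2024-05-01T12:00:00+00:00"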
16
domain/backup/types.py
Normal file
@@ -0,0 +1,16 @@
from typing import Any

from pydantic import BaseModel, Field


class BackupData(BaseModel):
    version: str | None = None
    storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
    user_accounts: list[dict[str, Any]] = Field(default_factory=list)
    automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
    share_links: list[dict[str, Any]] = Field(default_factory=list)
    configurations: list[dict[str, Any]] = Field(default_factory=list)
    ai_providers: list[dict[str, Any]] = Field(default_factory=list)
    ai_models: list[dict[str, Any]] = Field(default_factory=list)
    ai_default_models: list[dict[str, Any]] = Field(default_factory=list)
    plugins: list[dict[str, Any]] = Field(default_factory=list)
59
domain/config/api.py
Normal file
@@ -0,0 +1,59 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Form, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.config.types import ConfigItem

router = APIRouter(prefix="/api/config", tags=["config"])


@router.get("/")
@audit(action=AuditAction.READ, description="获取配置")
async def get_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str,
):
    value = await ConfigService.get(key)
    return success(ConfigItem(key=key, value=value).model_dump())


@router.post("/")
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
async def set_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    key: str = Form(...),
    value: str = Form(""),
):
    await ConfigService.set(key, value)
    return success(ConfigItem(key=key, value=value).model_dump())


@router.get("/all")
@audit(action=AuditAction.READ, description="获取全部配置")
async def get_all_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    configs = await ConfigService.get_all()
    return success(configs)


@router.get("/status")
@audit(action=AuditAction.READ, description="获取系统状态")
async def get_system_status(request: Request):
    status_data = await ConfigService.get_system_status()
    return success(status_data.model_dump())


@router.get("/latest-version")
@audit(action=AuditAction.READ, description="获取最新版本")
async def get_latest_version(request: Request):
    info = await ConfigService.get_latest_version()
    return success(info.model_dump())
111
domain/config/service.py
Normal file
@@ -0,0 +1,111 @@
import os
import time
from typing import Any, Dict, Optional

import httpx
from dotenv import load_dotenv

from domain.config.types import LatestVersionInfo, SystemStatus
from models.database import Configuration, UserAccount

load_dotenv(dotenv_path=".env")

VERSION = "v1.5.2"


class ConfigService:
    _cache: Dict[str, Any] = {}
    _latest_version_cache: Dict[str, Any] = {"timestamp": 0.0, "data": None}

    @classmethod
    async def get(cls, key: str, default: Optional[Any] = None) -> Any:
        if key in cls._cache:
            return cls._cache[key]
        try:
            config = await Configuration.get_or_none(key=key)
            if config:
                cls._cache[key] = config.value
                return config.value
        except Exception:
            pass

        env_value = os.getenv(key)
        if env_value is not None:
            cls._cache[key] = env_value
            return env_value
        return default

    @classmethod
    async def get_secret_key(cls, key: str, default: Optional[Any] = None) -> bytes:
        value = await cls.get(key, default)
        if isinstance(value, bytes):
            return value
        if isinstance(value, str):
            return value.encode("utf-8")
        if value is None:
            raise ValueError(f"Secret key '{key}' not found in config or environment.")
        return str(value).encode("utf-8")

    @classmethod
    async def set(cls, key: str, value: Any):
        obj, _ = await Configuration.get_or_create(key=key, defaults={"value": value})
        obj.value = value
        await obj.save()
        cls._cache[key] = value

    @classmethod
    async def get_all(cls) -> Dict[str, Any]:
        try:
            configs = await Configuration.all()
            result = {}
            for config in configs:
                result[config.key] = config.value
                cls._cache[config.key] = config.value
            return result
        except Exception:
            return {}

    @classmethod
    def clear_cache(cls):
        cls._cache.clear()

    @classmethod
    async def get_system_status(cls) -> SystemStatus:
        logo = await cls.get("APP_LOGO", "/logo.svg")
        favicon = await cls.get("APP_FAVICON", logo)
        user_count = await UserAccount.all().count()
        return SystemStatus(
            version=VERSION,
            title=await cls.get("APP_NAME", "Foxel"),
            logo=logo,
            favicon=favicon,
            is_initialized=user_count > 0,
            app_domain=await cls.get("APP_DOMAIN"),
            file_domain=await cls.get("FILE_DOMAIN"),
        )

    @classmethod
    async def get_latest_version(cls) -> LatestVersionInfo:
        current_time = time.time()
        cache = cls._latest_version_cache
        if current_time - cache["timestamp"] < 3600 and cache["data"]:
            return cache["data"]
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(
                    "https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
                    follow_redirects=True,
                )
                resp.raise_for_status()
                data = resp.json()
            version_info = LatestVersionInfo(
                latest_version=data.get("tag_name"),
                body=data.get("body"),
            )
            cache["timestamp"] = current_time
            cache["data"] = version_info
            return version_info
        except httpx.RequestError:
            if cache["data"]:
                return cache["data"]
            return LatestVersionInfo()
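ConfigService.get resolves a key in a fixed order: process cache, then the Configuration table, then the environment, then the caller's default. A hedged sketch of that precedence, run without an initialized Tortoise connection so the DB lookup raises and is swallowed by the bare except:

    import asyncio
    import os

    from domain.config.service import ConfigService  # assumes the module above is on the path

    async def demo():
        ConfigService.clear_cache()
        os.environ["DEMO_KEY"] = "from-env"  # hypothetical key for illustration
        # No DB row is reachable, so the environment wins over the default;
        # the result is then memoized in the class-level cache.
        assert await ConfigService.get("DEMO_KEY", "fallback") == "from-env"
        assert await ConfigService.get("MISSING_KEY", "fallback") == "fallback"

    asyncio.run(demo())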
23
domain/config/types.py
Normal file
@@ -0,0 +1,23 @@
from typing import Any, Optional

from pydantic import BaseModel


class ConfigItem(BaseModel):
    key: str
    value: Optional[Any] = None


class SystemStatus(BaseModel):
    version: str
    title: str
    logo: str
    favicon: str
    is_initialized: bool
    app_domain: Optional[str] = None
    file_domain: Optional[str] = None


class LatestVersionInfo(BaseModel):
    latest_version: Optional[str] = None
    body: Optional[str] = None
92
domain/email/api.py
Normal file
@@ -0,0 +1,92 @@
from fastapi import APIRouter, Depends, HTTPException, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.email.service import EmailService, EmailTemplateRenderer
from domain.email.types import (
    EmailTemplatePreviewPayload,
    EmailTemplateUpdate,
    EmailTestRequest,
)


router = APIRouter(prefix="/api/email", tags=["email"])


@router.post("/test")
@audit(action=AuditAction.CREATE, description="发送测试邮件")
async def trigger_test_email(
    request: Request,
    payload: EmailTestRequest,
    current_user: User = Depends(get_current_active_user),
):
    try:
        task = await EmailService.enqueue_email(
            recipients=[str(payload.to)],
            subject=payload.subject,
            template=payload.template,
            context=payload.context,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return success({"task_id": task.id})


@router.get("/templates")
@audit(action=AuditAction.READ, description="获取邮件模板列表")
async def list_email_templates(
    request: Request,
    current_user: User = Depends(get_current_active_user),
):
    templates = await EmailTemplateRenderer.list_templates()
    return success({"templates": templates})


@router.get("/templates/{name}")
@audit(action=AuditAction.READ, description="查看邮件模板")
async def get_email_template(
    request: Request,
    name: str,
    current_user: User = Depends(get_current_active_user),
):
    try:
        content = await EmailTemplateRenderer.load(name)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="模板不存在")
    return success({"name": name, "content": content})


@router.post("/templates/{name}")
@audit(action=AuditAction.UPDATE, description="更新邮件模板")
async def update_email_template(
    request: Request,
    name: str,
    payload: EmailTemplateUpdate,
    current_user: User = Depends(get_current_active_user),
):
    try:
        await EmailTemplateRenderer.save(name, payload.content)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return success({"name": name})


@router.post("/templates/{name}/preview")
@audit(action=AuditAction.READ, description="预览邮件模板")
async def preview_email_template(
    request: Request,
    name: str,
    payload: EmailTemplatePreviewPayload,
    current_user: User = Depends(get_current_active_user),
):
    try:
        html = await EmailTemplateRenderer.render(name, payload.context)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="模板不存在")
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return success({"html": html})
151
domain/email/service.py
Normal file
@@ -0,0 +1,151 @@
import asyncio
import re
import smtplib
from email.message import EmailMessage
from email.utils import formataddr
from pathlib import Path
from string import Template
from typing import Any, Dict, List, Optional

from domain.config.service import ConfigService
from domain.email.types import EmailConfig, EmailSecurity, EmailSendPayload


class EmailTemplateRenderer:
    ROOT = Path("templates/email")

    @classmethod
    def _resolve_path(cls, template_name: str) -> Path:
        if not re.fullmatch(r"[A-Za-z0-9_\-]+", template_name):
            raise ValueError("Invalid template name")
        return cls.ROOT / f"{template_name}.html"

    @classmethod
    async def list_templates(cls) -> list[str]:
        cls.ROOT.mkdir(parents=True, exist_ok=True)
        return sorted(
            path.stem for path in cls.ROOT.glob("*.html") if path.is_file()
        )

    @classmethod
    async def load(cls, template_name: str) -> str:
        path = cls._resolve_path(template_name)
        if not path.is_file():
            raise FileNotFoundError(f"Email template '{template_name}' not found")
        return await asyncio.to_thread(path.read_text, encoding="utf-8")

    @classmethod
    async def save(cls, template_name: str, content: str) -> None:
        path = cls._resolve_path(template_name)
        path.parent.mkdir(parents=True, exist_ok=True)
        await asyncio.to_thread(path.write_text, content, encoding="utf-8")

    @classmethod
    async def render(cls, template_name: str, context: Dict[str, Any]) -> str:
        raw = await cls.load(template_name)
        context = {k: str(v) for k, v in (context or {}).items()}
        return Template(raw).safe_substitute(context)


class EmailService:
    CONFIG_KEY = "EMAIL_CONFIG"

    @classmethod
    async def _load_config(cls) -> EmailConfig:
        raw_config = await ConfigService.get(cls.CONFIG_KEY)
        return EmailConfig.parse_config(raw_config)

    @staticmethod
    def _html_to_text(html: str) -> str:
        stripped = re.sub(r"<[^>]+>", " ", html)
        return " ".join(stripped.split())

    @classmethod
    async def _deliver(cls, config: EmailConfig, payload: EmailSendPayload, html_body: str):
        message = EmailMessage()
        message["Subject"] = payload.subject
        message["From"] = formataddr(
            (config.sender_name or str(config.sender_email), str(config.sender_email))
        )
        message["To"] = ", ".join([str(addr) for addr in payload.recipients])

        plain_body = cls._html_to_text(html_body)
        message.set_content(plain_body or html_body)
        message.add_alternative(html_body, subtype="html")

        await asyncio.to_thread(cls._deliver_sync, config, message)

    @staticmethod
    def _deliver_sync(config: EmailConfig, message: EmailMessage):
        if config.security == EmailSecurity.SSL:
            smtp: smtplib.SMTP = smtplib.SMTP_SSL(
                config.host, config.port, timeout=config.timeout
            )
        else:
            smtp = smtplib.SMTP(config.host, config.port, timeout=config.timeout)

        try:
            if config.security == EmailSecurity.STARTTLS:
                smtp.starttls()
            if config.username and config.password:
                smtp.login(config.username, config.password)
            smtp.send_message(message)
        finally:
            try:
                smtp.quit()
            except Exception:
                pass

    @classmethod
    async def enqueue_email(
        cls,
        recipients: List[str],
        subject: str,
        template: str,
        context: Optional[Dict[str, Any]] = None,
    ):
        from domain.tasks.task_queue import TaskProgress, task_queue_service

        payload = EmailSendPayload(
            recipients=recipients,
            subject=subject,
            template=template,
            context=context or {},
        )

        task = await task_queue_service.add_task(
            "send_email",
            payload.model_dump(mode="json"),
        )

        await task_queue_service.update_progress(
            task.id,
            TaskProgress(stage="queued", percent=0.0, detail="Waiting to send"),
        )
        return task

    @classmethod
    async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
        from domain.tasks.task_queue import TaskProgress, task_queue_service

        payload = EmailSendPayload(**data)

        await task_queue_service.update_progress(
            task_id,
            TaskProgress(stage="preparing", percent=10.0, detail="Rendering template"),
        )

        config = await cls._load_config()
        html_body = await EmailTemplateRenderer.render(payload.template, payload.context)

        await task_queue_service.update_progress(
            task_id,
            TaskProgress(stage="sending", percent=60.0, detail="Sending message"),
        )

        await cls._deliver(config, payload, html_body)

        await task_queue_service.update_progress(
            task_id,
            TaskProgress(stage="completed", percent=100.0, detail="Email sent"),
        )
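The renderer deliberately uses stdlib string.Template rather than a full template engine: values are stringified and safe_substitute leaves unknown $placeholders in place instead of raising, so a template can reference context keys a given mail does not supply. A pure-stdlib sketch mirroring EmailTemplateRenderer.render:

    from string import Template

    raw = "<p>Hi $username, your link: $reset_link (valid $expire_minutes min; $unknown stays literal)</p>"
    context = {"username": "alice", "reset_link": "https://foxel.example/reset", "expire_minutes": 10}
    print(Template(raw).safe_substitute({k: str(v) for k, v in context.items()}))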
63
domain/email/types.py
Normal file
@@ -0,0 +1,63 @@
import json
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, EmailStr, Field, ValidationError


class EmailSecurity(str, Enum):
    NONE = "none"
    SSL = "ssl"
    STARTTLS = "starttls"


class EmailConfig(BaseModel):
    host: str
    port: int = Field(..., gt=0)
    username: Optional[str] = None
    password: Optional[str] = None
    sender_email: EmailStr
    sender_name: Optional[str] = None
    security: EmailSecurity = EmailSecurity.NONE
    timeout: float = Field(default=30.0, gt=0.0)

    @classmethod
    def parse_config(cls, raw_config: Any) -> "EmailConfig":
        """Accept a string or dict config and parse it into an EmailConfig."""
        if raw_config is None:
            raise ValueError("Email configuration not found")

        if isinstance(raw_config, str):
            raw_config = raw_config.strip()
            data: Any = json.loads(raw_config) if raw_config else {}
        elif isinstance(raw_config, dict):
            data = raw_config
        else:
            raise ValueError("Invalid email configuration format")

        try:
            return cls(**data)
        except ValidationError as exc:
            raise ValueError(f"Invalid email configuration: {exc}") from exc


class EmailSendPayload(BaseModel):
    recipients: List[EmailStr] = Field(..., min_length=1)
    subject: str = Field(..., min_length=1)
    template: str = Field(..., min_length=1)
    context: Dict[str, Any] = Field(default_factory=dict)


class EmailTestRequest(BaseModel):
    to: EmailStr
    subject: str = Field(..., min_length=1)
    template: str = Field(default="test", min_length=1)
    context: Dict[str, Any] = Field(default_factory=dict)


class EmailTemplateUpdate(BaseModel):
    content: str


class EmailTemplatePreviewPayload(BaseModel):
    context: Dict[str, Any] = Field(default_factory=dict)
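A sketch of parse_config's two accepted shapes: the EMAIL_CONFIG value may arrive as a JSON string (from the Configuration table or the environment) or as an already-decoded dict. Assumes the email-validator package is installed for EmailStr; the hostnames are illustrative.

    from domain.email.types import EmailConfig  # assumes the module above is on the path

    raw = '{"host": "smtp.example.com", "port": 465, "sender_email": "noreply@example.com", "security": "ssl"}'
    config = EmailConfig.parse_config(raw)
    assert config.security.value == "ssl"
    assert config.timeout == 30.0  # default applies when the JSON omits it

    try:
        EmailConfig.parse_config(None)
    except ValueError as exc:
        print(exc)  # "Email configuration not found"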
42
domain/offline_downloads/api.py
Normal file
@@ -0,0 +1,42 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.service import OfflineDownloadService
from domain.offline_downloads.types import OfflineDownloadCreate

CurrentUser = Annotated[User, Depends(get_current_active_user)]

router = APIRouter(
    prefix="/api/offline-downloads",
    tags=["OfflineDownloads"],
)


@router.post("/")
@audit(
    action=AuditAction.CREATE,
    description="创建离线下载任务",
    body_fields=["url", "dest_dir", "filename"],
)
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
    data = await OfflineDownloadService.create_download(payload, current_user)
    return success(data)


@router.get("/")
@audit(action=AuditAction.READ, description="获取离线下载列表")
async def list_offline_downloads(request: Request, current_user: CurrentUser):
    data = OfflineDownloadService.list_downloads()
    return success(data)


@router.get("/{task_id}")
@audit(action=AuditAction.READ, description="获取离线下载详情")
async def get_offline_download(task_id: str, request: Request, current_user: CurrentUser):
    data = OfflineDownloadService.get_download(task_id)
    return success(data)
252
domain/offline_downloads/service.py
Normal file
@@ -0,0 +1,252 @@
import os
import time
from pathlib import Path
from typing import Annotated, AsyncIterator

import aiofiles
import aiohttp
from fastapi import Depends, HTTPException

from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.offline_downloads.types import OfflineDownloadCreate
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import Task, TaskProgress, task_queue_service


class OfflineDownloadService:
    current_user_dep = Annotated[User, Depends(get_current_active_user)]
    temp_root = Path("data/tmp/offline_downloads")

    @classmethod
    async def create_download(cls, payload: OfflineDownloadCreate, current_user: User) -> dict:
        await cls._ensure_destination(payload.dest_dir)
        task = await task_queue_service.add_task(
            "offline_http_download",
            {
                "url": str(payload.url),
                "dest_dir": payload.dest_dir,
                "filename": payload.filename,
            },
        )

        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="queued",
                percent=0.0,
                bytes_total=None,
                bytes_done=0,
                detail="Waiting to start",
            ),
        )

        return {"task_id": task.id}

    @classmethod
    def list_downloads(cls) -> list[dict]:
        tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
        return [t.dict() for t in tasks]

    @classmethod
    def get_download(cls, task_id: str) -> dict:
        task = task_queue_service.get_task(task_id)
        if not task or task.name != "offline_http_download":
            raise HTTPException(status_code=404, detail="Task not found")
        return task.dict()

    @classmethod
    async def run_http_download(cls, task: Task):
        params = task.task_info
        url = params.get("url")
        dest_dir = params.get("dest_dir")
        filename = params.get("filename")

        if not url or not dest_dir or not filename:
            raise ValueError("Missing required parameters for offline download")

        cls.temp_root.mkdir(parents=True, exist_ok=True)
        temp_dir = cls.temp_root / task.id
        temp_dir.mkdir(parents=True, exist_ok=True)
        temp_file = temp_dir / "payload"

        bytes_total: int | None = None
        bytes_done = 0
        last_update = time.monotonic()

        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="downloading",
                percent=0.0,
                bytes_total=None,
                bytes_done=0,
                detail="HTTP downloading",
            ),
        )

        async def report_download(delta: int, total: int | None):
            nonlocal bytes_done, bytes_total, last_update
            if total is not None:
                bytes_total = total
            bytes_done += delta
            now = time.monotonic()
            if delta and now - last_update < 0.5:
                return
            last_update = now
            percent = None
            total_for_display = bytes_total if bytes_total is not None else None
            if bytes_total:
                percent = min(100.0, round(bytes_done / bytes_total * 100, 2))
            await task_queue_service.update_progress(
                task.id,
                TaskProgress(
                    stage="downloading",
                    percent=percent,
                    bytes_total=total_for_display,
                    bytes_done=bytes_done,
                    detail="HTTP downloading",
                ),
            )

        timeout = aiohttp.ClientTimeout(total=None, connect=30)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as resp:
                if resp.status != 200:
                    raise ValueError(f"HTTP {resp.status} for {url}")
                content_length = resp.headers.get("Content-Length")
                total_size = int(content_length) if content_length else None
                bytes_done = 0
                async with aiofiles.open(temp_file, "wb") as f:
                    async for chunk in resp.content.iter_chunked(512 * 1024):
                        if not chunk:
                            continue
                        await f.write(chunk)
                        await report_download(len(chunk), total_size)
                await report_download(0, total_size)

        file_size = os.path.getsize(temp_file)
        bytes_done_transfer = 0

        async def report_transfer(delta: int):
            nonlocal bytes_done_transfer
            bytes_done_transfer += delta
            percent = min(100.0, round(bytes_done_transfer / file_size * 100, 2)) if file_size else None
            await task_queue_service.update_progress(
                task.id,
                TaskProgress(
                    stage="transferring",
                    percent=percent,
                    bytes_total=file_size or None,
                    bytes_done=bytes_done_transfer,
                    detail="Saving to storage",
                ),
            )

        async def chunk_iter() -> AsyncIterator[bytes]:
            async for chunk in cls._iter_file(temp_file, 512 * 1024, report_transfer):
                yield chunk

        final_path, resolved_name = await cls._allocate_destination(dest_dir, filename)

        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="transferring",
                percent=0.0,
                bytes_total=file_size or None,
                bytes_done=0,
                detail="Saving to storage",
            ),
        )

        await VirtualFSService.write_file_stream(final_path, chunk_iter())

        await task_queue_service.update_progress(
            task.id,
            TaskProgress(
                stage="completed",
                percent=100.0,
                bytes_total=file_size or None,
                bytes_done=file_size,
                detail="Completed",
            ),
        )
        await task_queue_service.update_meta(task.id, {"final_path": final_path, "filename": resolved_name})

        try:
            os.remove(temp_file)
            temp_dir.rmdir()
        except Exception:
            pass

        return final_path

    @classmethod
    async def _ensure_destination(cls, dest_dir: str) -> None:
        try:
            is_dir = await VirtualFSService.path_is_directory(dest_dir)
        except HTTPException:
            is_dir = False
        if not is_dir:
            raise HTTPException(400, detail="Destination directory not found")

    @staticmethod
    def _normalize_path(path: str) -> str:
        if not path:
            return "/"
        if not path.startswith("/"):
            path = "/" + path
        if len(path) > 1 and path.endswith("/"):
            path = path.rstrip("/")
        return path or "/"

    @staticmethod
    async def _path_exists(full_path: str) -> bool:
        try:
            await VirtualFSService.stat_file(full_path)
            return True
        except FileNotFoundError:
            return False
        except HTTPException as exc:  # noqa: PERF203
            if exc.status_code == 404:
                return False
            raise

    @classmethod
    async def _allocate_destination(cls, dest_dir: str, filename: str) -> tuple[str, str]:
        dest_dir = cls._normalize_path(dest_dir)
        stem, suffix = cls._split_filename(filename)
        candidate = filename
        base = "" if dest_dir == "/" else dest_dir
        attempt = 0
        while await cls._path_exists(f"{base}/{candidate}" if base else f"/{candidate}"):
            attempt += 1
            if stem:
                candidate = f"{stem} ({attempt}){suffix}"
            else:
                candidate = f"file ({attempt}){suffix}" if suffix else f"file ({attempt})"
        full_path = f"{base}/{candidate}" if base else f"/{candidate}"
        return full_path, candidate

    @staticmethod
    def _split_filename(filename: str) -> tuple[str, str]:
        if not filename:
            return "", ""
        if filename.startswith(".") and filename.count(".") == 1:
            return filename, ""
        if "." not in filename:
            return filename, ""
        stem, ext = filename.rsplit(".", 1)
        return stem, f".{ext}"

    @staticmethod
    async def _iter_file(path: Path, chunk_size: int, report_cb) -> AsyncIterator[bytes]:
        async with aiofiles.open(path, "rb") as f:
            while True:
                chunk = await f.read(chunk_size)
                if not chunk:
                    break
                await report_cb(len(chunk))
                yield chunk
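A sketch of the collision-avoiding naming these helpers implement. _split_filename keeps only the last extension (so multi-suffix archives split before the final dot) and dotfiles keep their name whole; _allocate_destination then probes "name", "name (1).ext", "name (2).ext", ... until a free path is found. Importing the service pulls in its FastAPI and virtual-fs dependencies, so this assumes the full package is installed.

    from domain.offline_downloads.service import OfflineDownloadService

    assert OfflineDownloadService._split_filename("movie.tar.gz") == ("movie.tar", ".gz")
    assert OfflineDownloadService._split_filename(".env") == (".env", "")
    assert OfflineDownloadService._split_filename("README") == ("README", "")

    # With "/downloads/movie.tar.gz" already taken, _allocate_destination would
    # next try "/downloads/movie.tar (1).gz" (illustrative directory name).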
7
domain/offline_downloads/types.py
Normal file
@@ -0,0 +1,7 @@
from pydantic import BaseModel, HttpUrl, Field


class OfflineDownloadCreate(BaseModel):
    url: HttpUrl
    dest_dir: str = Field(..., min_length=1)
    filename: str = Field(..., min_length=1)
1
domain/plugins/__init__.py
Normal file
@@ -0,0 +1 @@

76
domain/plugins/api.py
Normal file
@@ -0,0 +1,76 @@
from typing import List

from fastapi import APIRouter, Body, Request
from fastapi.responses import FileResponse

from domain.audit import AuditAction, audit
from domain.plugins.service import PluginService
from domain.plugins.routes import video_player as video_player_routes
from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut

router = APIRouter(prefix="/api/plugins", tags=["plugins"])
router.include_router(video_player_routes.router)


@router.post("", response_model=PluginOut)
@audit(
    action=AuditAction.CREATE,
    description="创建插件",
    body_fields=["url", "enabled"],
)
async def create_plugin(request: Request, payload: PluginCreate):
    return await PluginService.create(payload)


@router.get("", response_model=List[PluginOut])
@audit(action=AuditAction.READ, description="获取插件列表")
async def list_plugins(request: Request):
    return await PluginService.list_plugins()


@router.delete("/{plugin_id}")
@audit(action=AuditAction.DELETE, description="删除插件")
async def delete_plugin(request: Request, plugin_id: int):
    await PluginService.delete(plugin_id)
    return {"code": 0, "msg": "ok"}


@router.put("/{plugin_id}", response_model=PluginOut)
@audit(
    action=AuditAction.UPDATE,
    description="更新插件",
    body_fields=["url", "enabled"],
)
async def update_plugin(request: Request, plugin_id: int, payload: PluginCreate):
    return await PluginService.update(plugin_id, payload)


@router.post("/{plugin_id}/metadata", response_model=PluginOut)
@audit(
    action=AuditAction.UPDATE,
    description="更新插件 manifest",
    body_fields=[
        "key",
        "name",
        "version",
        "open_app",
        "supported_exts",
        "default_bounds",
        "default_maximized",
        "icon",
        "description",
        "author",
        "website",
        "github",
    ],
)
async def update_manifest(
    request: Request, plugin_id: int, manifest: PluginManifestUpdate = Body(...)
):
    return await PluginService.update_manifest(plugin_id, manifest)


@router.get("/{plugin_id}/bundle.js")
async def get_bundle(request: Request, plugin_id: int):
    path = await PluginService.get_bundle_path(plugin_id)
    return FileResponse(path, media_type="application/javascript", headers={"Cache-Control": "no-store"})
2
domain/plugins/routes/__init__.py
Normal file
@@ -0,0 +1,2 @@
"""Server-side routes dedicated to individual plugins."""
142
domain/plugins/routes/video_player.py
Normal file
@@ -0,0 +1,142 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query

from api.response import success
from domain.auth.service import get_current_active_user


router = APIRouter(
    prefix="/video-player",
    tags=["plugins"],
    dependencies=[Depends(get_current_active_user)],
)

DATA_ROOT = Path("data/.video")


def _read_json(path: Path) -> Dict[str, Any]:
    return json.loads(path.read_text(encoding="utf-8"))


def _file_mtime_iso(path: Path) -> str:
    try:
        ts = path.stat().st_mtime
    except FileNotFoundError:
        return ""
    return datetime.fromtimestamp(ts, tz=UTC).isoformat()


def _extract_title(payload: Dict[str, Any]) -> str:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    if payload.get("type") == "tv":
        return str(detail.get("name") or detail.get("original_name") or "")
    return str(detail.get("title") or detail.get("original_title") or "")


def _extract_year(payload: Dict[str, Any]) -> Optional[str]:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    value = detail.get("first_air_date") if payload.get("type") == "tv" else detail.get("release_date")
    if not value or not isinstance(value, str):
        return None
    return value[:4] if len(value) >= 4 else value


def _extract_genres(payload: Dict[str, Any]) -> List[str]:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    genres = detail.get("genres") or []
    out: List[str] = []
    if isinstance(genres, list):
        for g in genres:
            if isinstance(g, dict) and g.get("name"):
                out.append(str(g["name"]))
    return out


def _summarize(item_id: str, payload: Dict[str, Any], mtime_iso: str) -> Dict[str, Any]:
    detail = (payload.get("tmdb") or {}).get("detail") or {}
    media_type = payload.get("type") or "unknown"
    episodes = payload.get("episodes") or []
    seasons = {e.get("season") for e in episodes if isinstance(e, dict) and e.get("season") is not None}

    return {
        "id": item_id,
        "type": media_type,
        "title": _extract_title(payload),
        "year": _extract_year(payload),
        "overview": detail.get("overview"),
        "poster_path": detail.get("poster_path"),
        "backdrop_path": detail.get("backdrop_path"),
        "genres": _extract_genres(payload),
        "tmdb_id": (payload.get("tmdb") or {}).get("id"),
        "source_path": payload.get("source_path"),
        "scraped_at": payload.get("scraped_at"),
        "updated_at": mtime_iso,
        "episodes_count": len(episodes) if isinstance(episodes, list) else 0,
        "seasons_count": len(seasons),
        "vote_average": detail.get("vote_average"),
        "vote_count": detail.get("vote_count"),
    }


def _iter_library_files() -> List[tuple[str, Path]]:
    files: List[tuple[str, Path]] = []
    for sub in ("tv", "movie"):
        folder = DATA_ROOT / sub
        if not folder.exists():
            continue
        for p in folder.glob("*.json"):
            if not p.is_file():
                continue
            files.append((sub, p))
    return files


@router.get("/library")
async def list_library(
    q: str | None = Query(None, description="搜索关键字(标题/简介)"),
    media_type: str | None = Query(None, alias="type", description="tv 或 movie"),
):
    items: List[Dict[str, Any]] = []
    keyword = (q or "").strip().lower()
    type_filter = (media_type or "").strip().lower()
    if type_filter and type_filter not in {"tv", "movie"}:
        raise HTTPException(status_code=400, detail="type must be tv or movie")

    for _sub, path in _iter_library_files():
        item_id = path.stem
        try:
            payload = _read_json(path)
        except Exception:
            continue
        if type_filter and str(payload.get("type") or "").lower() != type_filter:
            continue
        summary = _summarize(item_id, payload, _file_mtime_iso(path))
        if keyword:
            haystack = f"{summary.get('title') or ''} {summary.get('overview') or ''}".lower()
            if keyword not in haystack:
                continue
        items.append(summary)

    items.sort(key=lambda x: x.get("updated_at") or "", reverse=True)
    return success(items)


@router.get("/library/{item_id}")
async def get_library_item(item_id: str):
    candidates = [
        DATA_ROOT / "tv" / f"{item_id}.json",
        DATA_ROOT / "movie" / f"{item_id}.json",
    ]
    path = next((p for p in candidates if p.exists()), None)
    if not path:
        raise HTTPException(status_code=404, detail="Item not found")

    payload = _read_json(path)
    payload["id"] = item_id
    payload["updated_at"] = _file_mtime_iso(path)
    return success(payload)
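A sketch of what the summarization helpers extract from one scraped TMDB payload; the payload below is a made-up example with the minimal fields the helpers read, and importing the module assumes its FastAPI dependencies are installed.

    from domain.plugins.routes.video_player import _extract_title, _extract_year, _summarize

    payload = {
        "type": "tv",
        "tmdb": {
            "id": 1399,  # illustrative TMDB id
            "detail": {"name": "Demo Show", "first_air_date": "2011-04-17", "genres": [{"name": "Drama"}]},
        },
        "episodes": [{"season": 1}, {"season": 1}, {"season": 2}],
    }
    assert _extract_title(payload) == "Demo Show"
    assert _extract_year(payload) == "2011"          # first four chars of first_air_date
    summary = _summarize("demo", payload, "2024-01-01T00:00:00+00:00")
    assert summary["episodes_count"] == 3
    assert summary["seasons_count"] == 2             # distinct season numbers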
138
domain/plugins/service.py
Normal file
@@ -0,0 +1,138 @@
import contextlib
import re
import shutil
from pathlib import Path

import aiofiles
import httpx
from fastapi import HTTPException

from domain.plugins.types import PluginCreate, PluginManifestUpdate, PluginOut
from models.database import Plugin


class PluginService:
    _plugins_root = Path("data/plugins")

    @classmethod
    def _folder_name(cls, rec: Plugin) -> str:
        if rec.key:
            safe = re.sub(r"[^A-Za-z0-9_.-]", "_", rec.key)
            return safe or str(rec.id)
        return str(rec.id)

    @classmethod
    def _bundle_dir_from_rec(cls, rec: Plugin) -> Path:
        return cls._plugins_root / cls._folder_name(rec) / "current"

    @classmethod
    def _bundle_path_from_rec(cls, rec: Plugin) -> Path:
        return cls._bundle_dir_from_rec(rec) / "index.js"

    @classmethod
    async def _download_bundle(cls, rec: Plugin, url: str) -> None:
        dest_dir = cls._bundle_dir_from_rec(rec)
        dest_dir.mkdir(parents=True, exist_ok=True)
        dest_path = cls._bundle_path_from_rec(rec)
        tmp_path = dest_path.with_suffix(".tmp")
        try:
            async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
                async with client.stream("GET", url) as resp:
                    resp.raise_for_status()
                    async with aiofiles.open(tmp_path, "wb") as f:
                        async for chunk in resp.aiter_bytes(chunk_size=65536):
                            if not chunk:
                                continue
                            await f.write(chunk)
            tmp_path.replace(dest_path)
        except Exception:
            with contextlib.suppress(Exception):
                if tmp_path.exists():
                    tmp_path.unlink()
            raise

    @classmethod
    async def _ensure_bundle(cls, plugin_id: int) -> Path:
        rec = await cls._get_or_404(plugin_id)
        bundle_path = cls._bundle_path_from_rec(rec)
        if bundle_path.exists():
            return bundle_path

        legacy = cls._plugins_root / str(rec.id) / "current" / "index.js"
        if legacy.exists():
            return legacy

        raise HTTPException(status_code=404, detail="Plugin bundle not found")

    @classmethod
    async def get_bundle_path(cls, plugin_id: int) -> Path:
        return await cls._ensure_bundle(plugin_id)

    @classmethod
    async def create(cls, payload: PluginCreate) -> PluginOut:
        rec = await Plugin.create(**payload.model_dump())
        try:
            await cls._download_bundle(rec, rec.url)
        except Exception as exc:
            with contextlib.suppress(Exception):
                await rec.delete()
            raise HTTPException(status_code=400, detail=f"Failed to fetch plugin: {exc}")
        return PluginOut.model_validate(rec)

    @classmethod
    async def list_plugins(cls) -> list[PluginOut]:
        rows = await Plugin.all().order_by("-id")
        return [PluginOut.model_validate(r) for r in rows]

    @classmethod
    async def _get_or_404(cls, plugin_id: int) -> Plugin:
        rec = await Plugin.get_or_none(id=plugin_id)
        if not rec:
            raise HTTPException(status_code=404, detail="Plugin not found")
        return rec

    @classmethod
    async def delete(cls, plugin_id: int) -> None:
        rec = await cls._get_or_404(plugin_id)
        await rec.delete()
        with contextlib.suppress(Exception):
            dirs = {cls._bundle_dir_from_rec(rec).parent, cls._plugins_root / str(rec.id)}
            for plugin_dir in dirs:
                if plugin_dir.exists():
                    shutil.rmtree(plugin_dir)

    @classmethod
    async def update(cls, plugin_id: int, payload: PluginCreate) -> PluginOut:
        rec = await cls._get_or_404(plugin_id)
        url_changed = rec.url != payload.url
        if url_changed:
            try:
                await cls._download_bundle(rec, payload.url)
            except Exception as exc:
                raise HTTPException(status_code=400, detail=f"Failed to fetch plugin: {exc}")
        rec.url = payload.url
        rec.enabled = payload.enabled
        await rec.save()
        return PluginOut.model_validate(rec)

    @classmethod
    async def update_manifest(
        cls, plugin_id: int, manifest: PluginManifestUpdate
    ) -> PluginOut:
        rec = await cls._get_or_404(plugin_id)
        old_dir = cls._bundle_dir_from_rec(rec).parent
        updates = manifest.model_dump(exclude_none=True)
        if updates:
            for key, value in updates.items():
                setattr(rec, key, value)
            await rec.save()
        new_dir = cls._bundle_dir_from_rec(rec).parent
        if rec.key and new_dir != old_dir:
            candidate_dir = old_dir if old_dir.exists() else (cls._plugins_root / str(rec.id))
            if candidate_dir.exists():
                new_dir.parent.mkdir(parents=True, exist_ok=True)
                with contextlib.suppress(Exception):
                    if new_dir.exists():
                        shutil.rmtree(new_dir)
                    shutil.move(str(candidate_dir), str(new_dir))
        return PluginOut.model_validate(rec)
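Note: _download_bundle above streams into a sibling .tmp file and promotes it with Path.replace(). A condensed sketch of that write-then-swap pattern:

from pathlib import Path

def atomic_write(dest: Path, data: bytes) -> None:
    tmp = dest.with_suffix(".tmp")
    tmp.write_bytes(data)  # a failure here leaves dest untouched
    tmp.replace(dest)      # atomic swap on the same filesystem

Since os.replace semantics guarantee atomicity only within a single filesystem, the temporary file is deliberately placed next to the destination rather than in a system temp directory.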
57
domain/plugins/types.py
Normal file
@@ -0,0 +1,57 @@
from typing import Any, Dict, List, Optional

from pydantic import AliasChoices, BaseModel, ConfigDict, Field


class PluginCreate(BaseModel):
    url: str = Field(min_length=1)
    enabled: bool = True


class PluginManifestUpdate(BaseModel):
    model_config = ConfigDict(populate_by_name=True, extra="ignore")

    key: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None
    open_app: Optional[bool] = Field(
        default=None,
        validation_alias=AliasChoices("open_app", "openApp"),
    )
    supported_exts: Optional[List[str]] = Field(
        default=None,
        validation_alias=AliasChoices("supported_exts", "supportedExts"),
    )
    default_bounds: Optional[Dict[str, Any]] = Field(
        default=None,
        validation_alias=AliasChoices("default_bounds", "defaultBounds"),
    )
    default_maximized: Optional[bool] = Field(
        default=None,
        validation_alias=AliasChoices("default_maximized", "defaultMaximized"),
    )
    icon: Optional[str] = None
    description: Optional[str] = None
    author: Optional[str] = None
    website: Optional[str] = None
    github: Optional[str] = None


class PluginOut(BaseModel):
    id: int
    url: str
    enabled: bool
    open_app: bool = False
    key: Optional[str] = None
    name: Optional[str] = None
    version: Optional[str] = None
    supported_exts: Optional[List[str]] = None
    default_bounds: Optional[Dict[str, Any]] = None
    default_maximized: Optional[bool] = None
    icon: Optional[str] = None
    description: Optional[str] = None
    author: Optional[str] = None
    website: Optional[str] = None
    github: Optional[str] = None

    model_config = ConfigDict(from_attributes=True)
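Note: because PluginManifestUpdate combines populate_by_name with AliasChoices, a manifest produced by JavaScript tooling validates without renaming its camelCase keys. A quick sketch:

upd = PluginManifestUpdate.model_validate({"openApp": True, "supportedExts": ["png"]})
assert upd.open_app is True and upd.supported_exts == ["png"]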
89
domain/processors/api.py
Normal file
@@ -0,0 +1,89 @@
from typing import Annotated

from fastapi import APIRouter, Body, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.processors.service import ProcessorService
from domain.processors.types import (
    ProcessDirectoryRequest,
    ProcessRequest,
    UpdateSourceRequest,
)

router = APIRouter(prefix="/api/processors", tags=["processors"])


@router.get("")
@audit(action=AuditAction.READ, description="获取处理器列表")
async def list_processors(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = ProcessorService.list_processors()
    return success(data)


@router.post("/process")
@audit(
    action=AuditAction.CREATE,
    description="处理单个文件",
    body_fields=["path", "processor_type", "save_to", "overwrite"],
)
async def process_file_with_processor(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    req: ProcessRequest = Body(...),
):
    data = await ProcessorService.process_file(req)
    return success(data)


@router.post("/process-directory")
@audit(
    action=AuditAction.CREATE,
    description="批量处理目录",
    body_fields=["path", "processor_type", "overwrite", "max_depth", "suffix"],
)
async def process_directory_with_processor(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    req: ProcessDirectoryRequest = Body(...),
):
    data = await ProcessorService.process_directory(req)
    return success(data)


@router.get("/source/{processor_type}")
@audit(action=AuditAction.READ, description="获取处理器源码")
async def get_processor_source(
    request: Request,
    processor_type: str,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = await ProcessorService.get_source(processor_type)
    return success(data)


@router.put("/source/{processor_type}")
@audit(action=AuditAction.UPDATE, description="更新处理器源码")
async def update_processor_source(
    request: Request,
    processor_type: str,
    req: UpdateSourceRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = await ProcessorService.update_source(processor_type, req)
    return success(data)


@router.post("/reload")
@audit(action=AuditAction.UPDATE, description="重载处理器模块")
async def reload_processor_modules(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    data = ProcessorService.reload()
    return success(data)
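Note: a hedged client sketch for the single-file endpoint above. The bearer-token auth scheme, the base URL, the virtual path, and the success() envelope shape are all assumptions for illustration.

import httpx

async def watermark_one(token: str) -> str:
    payload = {
        "path": "/mount/photos/cat.jpg",          # hypothetical virtual path
        "processor_type": "image_watermark",
        "config": {"text": "foxel", "position": "bottom-right"},
        "save_to": "/mount/photos/cat_wm.jpg",
        "overwrite": False,
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            "/api/processors/process",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()["data"]["task_id"]     # id of the queued task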
@@ -6,9 +6,11 @@ class BaseProcessor(Protocol):
    supported_exts: list
    config_schema: list
    produces_file: bool
    supports_directory: bool
    requires_input_bytes: bool

    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> bytes:
        """Process the file content and return the processed content."""
    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Any:
        """Process the file content or path and return a result; when produces_file=True this should be bytes/Response."""
        ...

# Convention: each processor must define
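Note: a minimal, purely illustrative processor that satisfies the protocol above together with the module-level contract the registry scans for (see domain/processors/registry.py later in this diff); the processor itself is hypothetical.

from typing import Any, Dict


class UppercaseProcessor:
    name = "Uppercase"
    supported_exts = ["txt"]
    config_schema: list = []
    produces_file = True            # process() returns the new file bytes
    supports_directory = False
    requires_input_bytes = True

    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> bytes:
        return input_bytes.decode("utf-8").upper().encode("utf-8")


PROCESSOR_TYPE = "uppercase"
PROCESSOR_NAME = UppercaseProcessor.name
SUPPORTED_EXTS = UppercaseProcessor.supported_exts
CONFIG_SCHEMA = UppercaseProcessor.config_schema
def PROCESSOR_FACTORY(): return UppercaseProcessor()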
1
domain/processors/builtin/__init__.py
Normal file
@@ -0,0 +1 @@
# Built-in processor package
@@ -1,9 +1,11 @@
from .base import BaseProcessor
from typing import Dict, Any
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO

from PIL import Image, ImageDraw, ImageFont
from fastapi.responses import Response
from services.logging import LogService

from ..base import BaseProcessor


class ImageWatermarkProcessor:
    name = "图片水印"
@@ -26,10 +28,11 @@ class ImageWatermarkProcessor:
    ]
    produces_file = True

    async def process(self, input_bytes: bytes,path: str, config: Dict[str, Any]) -> Response:
    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
        text = config.get("text", "")
        position = config.get("position", "bottom-right")
        font_size = int(config.get("font_size", 24))

        img = Image.open(BytesIO(input_bytes)).convert("RGBA")
        watermark = Image.new("RGBA", img.size)
        draw = ImageDraw.Draw(watermark)
@@ -37,29 +40,29 @@ class ImageWatermarkProcessor:
            font = ImageFont.truetype("arial.ttf", font_size)
        except Exception:
            font = ImageFont.load_default()

        w, h = img.size
        try:
            text_w, text_h = font.getsize(text)
        except AttributeError:
            bbox = draw.textbbox((0, 0), text, font=font)
            text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]

        if position == "bottom-right":
            xy = (w - text_w - 10, h - text_h - 10)
        elif position == "top-left":
            xy = (10, 10)
        else:
            xy = (w // 2 - text_w // 2, h // 2 - text_h // 2)

        draw.text(xy, text, font=font, fill=(255, 255, 255, 128))
        out = Image.alpha_composite(img, watermark)
        buf = BytesIO()
        out.convert("RGB").save(buf, format="JPEG")
        await LogService.info(
            "processor:image_watermark",
            f"Watermarked image {path}",
            details={"path": path, "config": config},
        )

        return Response(content=buf.getvalue(), media_type="image/jpeg")


PROCESSOR_TYPE = "image_watermark"
PROCESSOR_NAME = ImageWatermarkProcessor.name
SUPPORTED_EXTS = ImageWatermarkProcessor.supported_exts
233
domain/processors/builtin/vector_index.py
Normal file
@@ -0,0 +1,233 @@
import base64
import mimetypes
import os
from io import BytesIO
from typing import Dict, Any, List, Tuple

from fastapi.responses import Response
from PIL import Image

from ..base import BaseProcessor
from domain.ai.inference import describe_image_base64, get_text_embedding, provider_service
from domain.ai.service import VectorDBService, DEFAULT_VECTOR_DIMENSION


CHUNK_SIZE = 800
CHUNK_OVERLAP = 200
MAX_IMAGE_EDGE = 1600
JPEG_QUALITY = 85


def _chunk_text(content: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[Tuple[int, str, int, int]]:
    """Split text into fixed-size windows; returns (chunk_id, chunk_text, start, end) tuples."""
    if chunk_size <= 0:
        chunk_size = CHUNK_SIZE
    if overlap >= chunk_size:
        overlap = max(chunk_size // 4, 1)

    chunks: List[Tuple[int, str, int, int]] = []
    step = chunk_size - overlap
    idx = 0
    start = 0
    length = len(content)

    while start < length:
        end = min(length, start + chunk_size)
        chunk = content[start:end].strip()
        if chunk:
            chunks.append((idx, chunk, start, end))
            idx += 1
        if end >= length:
            break
        start += step
    return chunks


def _guess_mime(path: str) -> str:
    mime, _ = mimetypes.guess_type(path)
    return mime or "application/octet-stream"


def _chunk_key(path: str, chunk_id: str) -> str:
    return f"{path}#chunk={chunk_id}"


def _compress_image_for_embedding(input_bytes: bytes) -> Tuple[bytes, Dict[str, Any] | None]:
    """Compress an image to reduce the payload sent to the vision model."""
    if Image is None:
        return input_bytes, None

    try:
        with Image.open(BytesIO(input_bytes)) as img:
            img = img.convert("RGB")
            width, height = img.size
            longest_edge = max(width, height)
            scale = 1.0
            if longest_edge > MAX_IMAGE_EDGE:
                scale = MAX_IMAGE_EDGE / float(longest_edge)
                new_size = (max(int(width * scale), 1), max(int(height * scale), 1))
                resample_mode = getattr(getattr(Image, "Resampling", Image), "LANCZOS")
                img = img.resize(new_size, resample=resample_mode)

            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=JPEG_QUALITY, optimize=True)
            compressed = buffer.getvalue()

            if len(compressed) < len(input_bytes):
                return compressed, {
                    "original_bytes": len(input_bytes),
                    "compressed_bytes": len(compressed),
                    "scaled": scale < 1.0,
                    "width": img.width,
                    "height": img.height,
                }
    except Exception:  # pragma: no cover - fall back on any image-processing error
        return input_bytes, None

    return input_bytes, None


class VectorIndexProcessor:
    name = "向量索引"
    supported_exts: List[str] = []  # empty list means no extension restriction
    config_schema = [
        {
            "key": "action", "label": "操作", "type": "select", "required": True, "default": "create",
            "options": [
                {"value": "create", "label": "创建索引"},
                {"value": "destroy", "label": "销毁索引"},
            ]
        },
        {
            "key": "index_type", "label": "索引类型", "type": "select", "required": True, "default": "vector",
            "options": [
                {"value": "vector", "label": "向量索引"},
                {"value": "simple", "label": "普通索引"},
            ]
        }
    ]
    produces_file = False

    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Response:
        action = config.get("action", "create")
        index_type = config.get("index_type", "vector")
        vector_db = VectorDBService()
        collection_name = "vector_collection"

        if action == "destroy":
            await vector_db.delete_vector(collection_name, path)
            return Response(content=f"文件 {path} 的 {index_type} 索引已销毁", media_type="text/plain")

        mime_type = _guess_mime(path)

        if index_type == "simple":
            await vector_db.ensure_collection(collection_name, vector=False)
            await vector_db.delete_vector(collection_name, path)
            await vector_db.upsert_vector(collection_name, {
                "path": path,
                "source_path": path,
                "chunk_id": "filename",
                "mime": mime_type,
                "type": "filename",
                "name": os.path.basename(path),
            })
            return Response(content=f"文件 {path} 的普通索引已创建", media_type="text/plain")

        file_ext = path.split('.')[-1].lower()
        details: Dict[str, Any] = {"path": path, "action": "create", "index_type": "vector"}

        embedding_model = await provider_service.get_default_model("embedding")
        vector_dim = DEFAULT_VECTOR_DIMENSION
        if embedding_model and getattr(embedding_model, "embedding_dimensions", None):
            try:
                vector_dim = int(embedding_model.embedding_dimensions)
            except (TypeError, ValueError):
                vector_dim = DEFAULT_VECTOR_DIMENSION
        if vector_dim <= 0:
            vector_dim = DEFAULT_VECTOR_DIMENSION

        await vector_db.ensure_collection(collection_name, vector=True, dim=vector_dim)
        await vector_db.delete_vector(collection_name, path)

        if file_ext in ["jpg", "jpeg", "png", "bmp"]:
            processed_bytes, compression = _compress_image_for_embedding(input_bytes)
            base64_image = base64.b64encode(processed_bytes).decode("utf-8")
            description = await describe_image_base64(base64_image)
            embedding = await get_text_embedding(description)
            image_mime = "image/jpeg" if compression else mime_type
            await vector_db.upsert_vector(collection_name, {
                "path": _chunk_key(path, "image"),
                "source_path": path,
                "chunk_id": "image",
                "embedding": embedding,
                "text": description,
                "mime": image_mime,
                "type": "image",
            })
            details["description"] = description
            if compression:
                details["image_compression"] = compression
            return Response(content=f"图片已索引,描述:{description}", media_type="text/plain")

        if file_ext in ["txt", "md"]:
            try:
                text = input_bytes.decode("utf-8")
            except UnicodeDecodeError:
                return Response(content="文本文件解码失败", status_code=400)

            chunks = _chunk_text(text)
            if not chunks:
                await vector_db.upsert_vector(collection_name, {
                    "path": _chunk_key(path, "0"),
                    "source_path": path,
                    "chunk_id": "0",
                    "embedding": await get_text_embedding(text or path),
                    "text": text,
                    "mime": mime_type,
                    "type": "text",
                    "start_offset": 0,
                    "end_offset": len(text),
                })
                details["chunks"] = 1
                return Response(content="文本文件已索引", media_type="text/plain")

            chunk_count = 0
            for chunk_id, chunk_text, start, end in chunks:
                embedding = await get_text_embedding(chunk_text)
                await vector_db.upsert_vector(collection_name, {
                    "path": _chunk_key(path, str(chunk_id)),
                    "source_path": path,
                    "chunk_id": str(chunk_id),
                    "embedding": embedding,
                    "text": chunk_text,
                    "mime": mime_type,
                    "type": "text",
                    "start_offset": start,
                    "end_offset": end,
                })
                chunk_count += 1

            details["chunks"] = chunk_count
            sample = chunks[0][1]
            details["sample"] = sample[:120]
            return Response(content="文本文件已索引", media_type="text/plain")

        # Other types do not support vector indexing yet; fall back to a filename index
        await vector_db.delete_vector(collection_name, path)
        await vector_db.upsert_vector(collection_name, {
            "path": _chunk_key(path, "fallback"),
            "source_path": path,
            "chunk_id": "filename",
            "mime": mime_type,
            "type": "filename",
            "name": os.path.basename(path),
            "embedding": [0.0] * vector_dim,
        })
        return Response(content="暂不支持该类型的向量索引,已创建文件名索引", media_type="text/plain")


PROCESSOR_TYPE = "vector_index"
PROCESSOR_NAME = VectorIndexProcessor.name
SUPPORTED_EXTS = VectorIndexProcessor.supported_exts
CONFIG_SCHEMA = VectorIndexProcessor.config_schema
def PROCESSOR_FACTORY(): return VectorIndexProcessor()
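Note: a worked example of the sliding-window chunking in _chunk_text above, with small numbers so the overlap is visible (chunk_size=10, overlap=4 gives a step of 6):

text = "abcdefghijklmnopqrst"  # 20 characters
chunks = _chunk_text(text, chunk_size=10, overlap=4)
# [(0, 'abcdefghij', 0, 10), (1, 'ghijklmnop', 6, 16), (2, 'mnopqrst', 12, 20)]
# each window re-reads the last 4 characters of its predecessor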
396
domain/processors/builtin/video_library.py
Normal file
@@ -0,0 +1,396 @@
import hashlib
import json
import os
import re
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import httpx

from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.thumbnail import VIDEO_EXT, is_video_filename


DATA_ROOT = Path("data/.video")
TMDB_BASE_URL = "https://api.themoviedb.org/3"


def _sha1(text: str) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()


def _store_path(media_type: str, source_path: str) -> Path:
    subdir = "tv" if media_type == "tv" else "movie"
    return DATA_ROOT / subdir / f"{_sha1(source_path)}.json"


def _write_json(path: Path, payload: dict) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")


_CLEAN_TAGS_RE = re.compile(
    r"\b("
    r"2160p|1080p|720p|480p|4k|hdr|dv|dolby|atmos|"
    r"x264|x265|h264|h265|hevc|av1|aac|dts|flac|"
    r"bluray|bdrip|web[- ]?dl|webrip|dvdrip|remux|proper|repack"
    r")\b",
    re.IGNORECASE,
)


def _clean_query_name(raw: str) -> str:
    name = raw
    name = name.replace(".", " ").replace("_", " ")
    name = re.sub(r"\[[^\]]*\]", " ", name)
    name = re.sub(r"\([^\)]*\)", " ", name)
    name = _CLEAN_TAGS_RE.sub(" ", name)
    name = re.sub(r"\s+", " ", name).strip()
    return name


def _guess_name_from_path(path: str, is_dir: bool) -> str:
    norm = path.rstrip("/") if is_dir else path
    p = Path(norm)
    raw = p.name if is_dir else p.stem
    return _clean_query_name(raw)


def _as_bool(value: Any, default: bool) -> bool:
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value != 0
    if isinstance(value, str):
        v = value.strip().lower()
        if v in {"1", "true", "yes", "y", "on"}:
            return True
        if v in {"0", "false", "no", "n", "off"}:
            return False
    return default


_SXXEYY_RE = re.compile(r"[Ss](\d{1,2})\s*[.\-_ ]*\s*[Ee](\d{1,3})")
_X_RE = re.compile(r"(\d{1,2})x(\d{1,3})", re.IGNORECASE)
_CN_EP_RE = re.compile(r"第\s*(\d{1,3})\s*[集话]")
_CN_SEASON_RE = re.compile(r"第\s*(\d{1,2})\s*季")
_SEASON_WORD_RE = re.compile(r"Season\s*(\d{1,2})", re.IGNORECASE)
_S_RE = re.compile(r"[Ss](\d{1,2})")


def _parse_season_episode(rel_path: str) -> Tuple[Optional[int], Optional[int]]:
    stem = Path(rel_path).stem

    m = _SXXEYY_RE.search(stem) or _SXXEYY_RE.search(rel_path)
    if m:
        return int(m.group(1)), int(m.group(2))

    m = _X_RE.search(stem)
    if m:
        return int(m.group(1)), int(m.group(2))

    m = _CN_EP_RE.search(stem)
    if m:
        episode = int(m.group(1))
        season = None
        for part in reversed(Path(rel_path).parts[:-1]):
            sm = _CN_SEASON_RE.search(part) or _SEASON_WORD_RE.search(part) or _S_RE.search(part)
            if sm:
                season = int(sm.group(1))
                break
        return season or 1, episode

    m = re.match(r"^(\d{1,3})(?!\d)", stem)
    if m:
        episode = int(m.group(1))
        season = None
        for part in reversed(Path(rel_path).parts[:-1]):
            sm = _CN_SEASON_RE.search(part) or _SEASON_WORD_RE.search(part) or _S_RE.search(part)
            if sm:
                season = int(sm.group(1))
                break
        return season or 1, episode

    return None, None


class TMDBClient:
    def __init__(self, access_token: str | None, api_key: str | None):
        self._access_token = access_token
        self._api_key = api_key

    @classmethod
    def from_env(cls) -> "TMDBClient":
        access_token = os.getenv("TMDB_ACCESS_TOKEN")
        api_key = os.getenv("TMDB_API_KEY")
        if not access_token and not api_key:
            raise RuntimeError("缺少 TMDB_ACCESS_TOKEN 或 TMDB_API_KEY")
        return cls(access_token=access_token, api_key=api_key)

    def _headers(self) -> dict:
        headers = {"Accept": "application/json"}
        if self._access_token:
            headers["Authorization"] = f"Bearer {self._access_token}"
        return headers

    def _merge_params(self, params: dict) -> dict:
        merged = dict(params or {})
        if self._api_key:
            merged.setdefault("api_key", self._api_key)
        return merged

    async def get(self, path: str, params: dict) -> dict:
        url = f"{TMDB_BASE_URL}{path}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url, headers=self._headers(), params=self._merge_params(params))
            resp.raise_for_status()
            return resp.json()


class VideoLibraryProcessor:
    name = "影视入库"
    supported_exts = sorted(VIDEO_EXT)
    config_schema = [
        {
            "key": "name",
            "label": "手动名称(可选)",
            "type": "string",
            "required": False,
            "placeholder": "留空则从路径提取",
        },
        {
            "key": "language",
            "label": "语言",
            "type": "string",
            "required": False,
            "default": "zh-CN",
        },
        {
            "key": "include_episodes",
            "label": "电视剧:保存每集",
            "type": "select",
            "required": False,
            "default": 1,
            "options": [
                {"label": "是", "value": 1},
                {"label": "否", "value": 0},
            ],
        },
    ]
    produces_file = False
    supports_directory = True
    requires_input_bytes = False

    async def process(self, input_bytes: bytes, path: str, config: Dict[str, Any]) -> Dict[str, Any]:
        tmdb = TMDBClient.from_env()
        is_dir = await VirtualFSService.path_is_directory(path)
        language = str(config.get("language") or "zh-CN")
        manual_name = str(config.get("name") or "").strip()
        query_name = manual_name or _guess_name_from_path(path, is_dir=is_dir)
        scraped_at = datetime.now(UTC).isoformat()

        if is_dir:
            payload, saved_to = await self._process_tv_dir(tmdb, path, query_name, language, scraped_at, config)
            return {
                "ok": True,
                "type": "tv",
                "path": path,
                "tmdb_id": payload.get("tmdb", {}).get("id"),
                "saved_to": str(saved_to),
            }

        payload, saved_to = await self._process_movie_file(tmdb, path, query_name, language, scraped_at)
        return {
            "ok": True,
            "type": "movie",
            "path": path,
            "tmdb_id": payload.get("tmdb", {}).get("id"),
            "saved_to": str(saved_to),
        }

    async def _process_movie_file(
        self,
        tmdb: TMDBClient,
        path: str,
        query_name: str,
        language: str,
        scraped_at: str,
    ) -> Tuple[dict, Path]:
        search = await tmdb.get("/search/movie", {"query": query_name, "language": language})
        results = search.get("results") or []
        if not results:
            raise RuntimeError(f"未找到电影条目:{query_name}")

        chosen = results[0] or {}
        movie_id = chosen.get("id")
        if not movie_id:
            raise RuntimeError("TMDB 搜索结果缺少 id")

        detail = await tmdb.get(
            f"/movie/{movie_id}",
            {
                "language": language,
                "append_to_response": "credits,images,external_ids,videos",
            },
        )

        payload = {
            "type": "movie",
            "source_path": path,
            "query": {"name": query_name, "language": language},
            "scraped_at": scraped_at,
            "tmdb": {
                "id": movie_id,
                "search": {"page": search.get("page"), "total_results": search.get("total_results"), "results": results[:5]},
                "detail": detail,
            },
        }
        saved_to = _store_path("movie", path)
        _write_json(saved_to, payload)
        return payload, saved_to

    async def _process_tv_dir(
        self,
        tmdb: TMDBClient,
        path: str,
        query_name: str,
        language: str,
        scraped_at: str,
        config: Dict[str, Any],
    ) -> Tuple[dict, Path]:
        search = await tmdb.get("/search/tv", {"query": query_name, "language": language})
        results = search.get("results") or []
        if not results:
            raise RuntimeError(f"未找到电视剧条目:{query_name}")

        chosen = results[0] or {}
        tv_id = chosen.get("id")
        if not tv_id:
            raise RuntimeError("TMDB 搜索结果缺少 id")

        detail = await tmdb.get(
            f"/tv/{tv_id}",
            {
                "language": language,
                "append_to_response": "credits,images,external_ids,videos",
            },
        )

        include_episodes = _as_bool(config.get("include_episodes"), True)
        episodes: List[dict] = []
        seasons_detail: Dict[str, Any] = {}
        if include_episodes:
            episodes = await self._collect_episode_files(path)
            seasons = sorted({ep["season"] for ep in episodes if ep.get("season") is not None})
            for season in seasons:
                seasons_detail[str(season)] = await tmdb.get(
                    f"/tv/{tv_id}/season/{int(season)}",
                    {"language": language},
                )
            self._attach_tmdb_episode_detail(episodes, seasons_detail)

        payload = {
            "type": "tv",
            "source_path": path,
            "query": {"name": query_name, "language": language},
            "scraped_at": scraped_at,
            "tmdb": {
                "id": tv_id,
                "search": {"page": search.get("page"), "total_results": search.get("total_results"), "results": results[:5]},
                "detail": detail,
                "seasons": seasons_detail,
            },
            "episodes": episodes,
        }

        saved_to = _store_path("tv", path)
        _write_json(saved_to, payload)
        return payload, saved_to

    async def _collect_episode_files(self, dir_path: str) -> List[dict]:
        adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(dir_path)
        rel = rel.rstrip("/")
        list_dir = await VirtualFSService._ensure_method(adapter_instance, "list_dir")

        stack: List[str] = [rel]
        page_size = 200
        out: List[dict] = []

        while stack:
            current_rel = stack.pop()
            page = 1
            while True:
                entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
                entries = entries or []
                if not entries and (total or 0) == 0:
                    break

                for entry in entries:
                    name = entry.get("name")
                    if not name:
                        continue
                    child_rel = VirtualFSService._join_rel(current_rel, name)
                    if entry.get("is_dir"):
                        stack.append(child_rel.rstrip("/"))
                        continue
                    if not is_video_filename(name):
                        continue

                    absolute_path = VirtualFSService._build_absolute_path(adapter_model.path, child_rel)
                    rel_in_show = child_rel
                    if rel and child_rel.startswith(rel.rstrip("/") + "/"):
                        rel_in_show = child_rel[len(rel.rstrip("/")) + 1 :]

                    season, episode = _parse_season_episode(rel_in_show)
                    out.append(
                        {
                            "path": absolute_path,
                            "rel": rel_in_show,
                            "name": name,
                            "size": entry.get("size"),
                            "mtime": entry.get("mtime"),
                            "season": season,
                            "episode": episode,
                        }
                    )

                if total is None or page * page_size >= total:
                    break
                page += 1

        return out

    def _attach_tmdb_episode_detail(self, episodes: List[dict], seasons_detail: Dict[str, Any]) -> None:
        episode_maps: Dict[str, Dict[int, Any]] = {}
        for season_str, season_payload in (seasons_detail or {}).items():
            items = (season_payload or {}).get("episodes") or []
            m: Dict[int, Any] = {}
            for item in items:
                try:
                    number = int(item.get("episode_number"))
                except Exception:
                    continue
                m[number] = item
            episode_maps[season_str] = m

        for ep in episodes:
            season = ep.get("season")
            episode = ep.get("episode")
            if season is None or episode is None:
                continue
            m = episode_maps.get(str(season))
            if not m:
                continue
            detail = m.get(int(episode))
            if detail:
                ep["tmdb_episode"] = detail


PROCESSOR_TYPE = "video_library"
PROCESSOR_NAME = VideoLibraryProcessor.name
SUPPORTED_EXTS = VideoLibraryProcessor.supported_exts
CONFIG_SCHEMA = VideoLibraryProcessor.config_schema
PROCESSOR_FACTORY = lambda: VideoLibraryProcessor()
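Note: spot checks for _parse_season_episode above, covering three of the patterns it tries in order (SxxEyy, NxM, Chinese episode markers with a season folder) plus the no-match case; illustrative, not a test suite.

_parse_season_episode("Show.S01E02.1080p.mkv")  # -> (1, 2)
_parse_season_episode("Season 2/3x07.mkv")      # -> (3, 7)  the NxM match beats the folder name
_parse_season_episode("第2季/第05集.mp4")        # -> (2, 5)  season recovered from the parent folder
_parse_season_episode("Specials/extras.mkv")    # -> (None, None)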
145
domain/processors/registry.py
Normal file
@@ -0,0 +1,145 @@
import inspect
import pkgutil
from importlib import import_module, reload
from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, Optional

from domain.processors.base import BaseProcessor

ProcessorFactory = Callable[[], BaseProcessor]
TYPE_MAP: Dict[str, ProcessorFactory] = {}
CONFIG_SCHEMAS: Dict[str, dict] = {}
MODULE_MAP: Dict[str, ModuleType] = {}
LAST_DISCOVERY_ERRORS: list[str] = []


def discover_processors(force_reload: bool = False) -> list[str]:
    """Scan for available processor modules and cache them."""
    from domain.processors import builtin as processors_pkg

    TYPE_MAP.clear()
    CONFIG_SCHEMAS.clear()
    MODULE_MAP.clear()

    global LAST_DISCOVERY_ERRORS
    LAST_DISCOVERY_ERRORS = []

    for modinfo in pkgutil.iter_modules(processors_pkg.__path__):
        if modinfo.name.startswith("_"):
            continue

        full_name = f"{processors_pkg.__name__}.{modinfo.name}"
        try:
            module = import_module(full_name)
            if force_reload:
                module = reload(module)
        except Exception as exc:
            LAST_DISCOVERY_ERRORS.append(f"Failed to import {full_name}: {exc}")
            continue

        processor_type = getattr(module, "PROCESSOR_TYPE", None)
        processor_name = getattr(module, "PROCESSOR_NAME", None)
        supported_exts = getattr(module, "SUPPORTED_EXTS", None)
        schema = getattr(module, "CONFIG_SCHEMA", None)
        factory = getattr(module, "PROCESSOR_FACTORY", None)

        if not processor_type:
            continue

        if factory is None:
            for attr in module.__dict__.values():
                if inspect.isclass(attr) and attr.__name__.endswith("Processor"):

                    def _mk(cls=attr):
                        return lambda: cls()

                    factory = _mk()
                    break

        if not callable(factory):
            LAST_DISCOVERY_ERRORS.append(f"Processor {full_name} missing factory")
            continue

        try:
            sample = factory()
        except Exception as exc:
            LAST_DISCOVERY_ERRORS.append(f"Failed to instantiate processor {processor_type}: {exc}")
            continue

        TYPE_MAP[processor_type] = factory
        MODULE_MAP[processor_type] = module

        produces_file = getattr(module, "produces_file", None)
        if produces_file is None and hasattr(sample, "produces_file"):
            produces_file = getattr(sample, "produces_file")

        supports_directory = getattr(module, "supports_directory", None)
        if supports_directory is None and hasattr(sample, "supports_directory"):
            supports_directory = getattr(sample, "supports_directory")

        module_file = getattr(module, "__file__", None)
        module_path: Optional[str] = None
        if module_file:
            try:
                module_path = str(Path(module_file).resolve())
            except Exception:
                module_path = module_file

        if isinstance(supported_exts, list):
            normalized_exts = [str(ext) for ext in supported_exts]
        elif supported_exts:
            normalized_exts = [str(supported_exts)]
        else:
            normalized_exts = []

        if not normalized_exts and hasattr(sample, "supported_exts"):
            sample_exts = getattr(sample, "supported_exts") or []
            if isinstance(sample_exts, list):
                normalized_exts = [str(ext) for ext in sample_exts]

        if isinstance(schema, list):
            CONFIG_SCHEMAS[processor_type] = {
                "type": processor_type,
                "name": processor_name or processor_type,
                "supported_exts": normalized_exts,
                "config_schema": schema,
                "produces_file": produces_file if produces_file is not None else False,
                "supports_directory": supports_directory if supports_directory is not None else False,
                "module_path": module_path,
            }

    return LAST_DISCOVERY_ERRORS


def get_config_schemas() -> Dict[str, dict]:
    return CONFIG_SCHEMAS


def get_config_schema(processor_type: str):
    return CONFIG_SCHEMAS.get(processor_type)


def get(processor_type: str) -> BaseProcessor | None:
    factory = TYPE_MAP.get(processor_type)
    if factory:
        return factory()
    return None


def get_module_path(processor_type: str) -> Optional[str]:
    meta = CONFIG_SCHEMAS.get(processor_type)
    if not meta:
        return None
    return meta.get("module_path")


def get_last_discovery_errors() -> list[str]:
    return LAST_DISCOVERY_ERRORS


def reload_processors() -> list[str]:
    return discover_processors(force_reload=True)


discover_processors()
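Note: the def _mk(cls=attr) wrapper in the discovery loop above exists because Python closures capture variables by reference; binding the class as a default argument freezes it per iteration. A standalone illustration:

def make_factories_buggy(classes):
    return [lambda: cls() for cls in classes]          # every lambda sees the last cls

def make_factories_fixed(classes):
    return [lambda cls=cls: cls() for cls in classes]  # each lambda binds its own cls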
223
domain/processors/service.py
Normal file
@@ -0,0 +1,223 @@
from pathlib import Path
from typing import List, Tuple

from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from domain.processors.registry import (
    get,
    get_config_schema,
    get_config_schemas,
    get_module_path,
    reload_processors,
)
from domain.processors.types import (
    ProcessDirectoryRequest,
    ProcessRequest,
    UpdateSourceRequest,
)
from domain.virtual_fs.service import VirtualFSService
from domain.tasks.task_queue import task_queue_service


class ProcessorService:
    @classmethod
    def get_processor(cls, processor_type: str):
        return get(processor_type)

    @classmethod
    def list_processors(cls):
        schemas = get_config_schemas()
        out = []
        for t, meta in schemas.items():
            out.append({
                "type": meta["type"],
                "name": meta["name"],
                "supported_exts": meta.get("supported_exts", []),
                "config_schema": meta["config_schema"],
                "produces_file": meta.get("produces_file", False),
                "supports_directory": meta.get("supports_directory", False),
                "module_path": meta.get("module_path"),
            })
        return out

    @classmethod
    async def process_file(cls, req: ProcessRequest):
        processor = cls.get_processor(req.processor_type)
        if not processor:
            raise HTTPException(404, detail="Processor not found")

        is_dir = await VirtualFSService.path_is_directory(req.path)
        supports_directory = bool(getattr(processor, "supports_directory", False))
        if is_dir and not supports_directory and not req.overwrite:
            raise HTTPException(400, detail="Directory processing requires overwrite")

        save_to = None if is_dir else (req.path if req.overwrite else req.save_to)
        task = await task_queue_service.add_task(
            "process_file",
            {
                "path": req.path,
                "processor_type": req.processor_type,
                "config": req.config,
                "save_to": save_to,
                "overwrite": req.overwrite,
            },
        )
        return {"task_id": task.id}

    @classmethod
    async def process_directory(cls, req: ProcessDirectoryRequest):
        if req.max_depth is not None and req.max_depth < 0:
            raise HTTPException(400, detail="max_depth must be >= 0")

        is_dir = await VirtualFSService.path_is_directory(req.path)
        if not is_dir:
            raise HTTPException(400, detail="Path must be a directory")

        schema = get_config_schema(req.processor_type)
        _processor = get(req.processor_type)
        if not schema or not _processor:
            raise HTTPException(404, detail="Processor not found")

        produces_file = bool(schema.get("produces_file"))
        raw_suffix = req.suffix if req.suffix is not None else None
        if raw_suffix is not None and raw_suffix.strip() == "":
            raw_suffix = None
        suffix = raw_suffix
        overwrite = req.overwrite

        if produces_file:
            if not overwrite and not suffix:
                raise HTTPException(400, detail="Suffix is required when not overwriting files")
        else:
            overwrite = False
            suffix = None

        supported_exts = schema.get("supported_exts") or []
        allowed_exts = {
            ext.lower().lstrip('.')
            for ext in supported_exts
            if isinstance(ext, str)
        }

        def matches_extension(file_rel: str) -> bool:
            if not allowed_exts:
                return True
            if '.' not in file_rel:
                return '' in allowed_exts
            ext = file_rel.rsplit('.', 1)[-1].lower()
            return ext in allowed_exts or f'.{ext}' in allowed_exts

        adapter_instance, adapter_model, root, rel = await VirtualFSService.resolve_adapter_and_rel(req.path)
        rel = rel.rstrip('/')

        list_dir = getattr(adapter_instance, "list_dir", None)
        if not callable(list_dir):
            raise HTTPException(501, detail="Adapter does not implement list_dir")

        def build_absolute_path(mount_path: str, rel_path: str) -> str:
            rel_norm = rel_path.lstrip('/')
            mount_norm = mount_path.rstrip('/')
            if not mount_norm:
                return '/' + rel_norm if rel_norm else '/'
            return f"{mount_norm}/{rel_norm}" if rel_norm else mount_norm

        def apply_suffix(path_str: str, suffix_str: str) -> str:
            path_obj = Path(path_str)
            name = path_obj.name
            if not name:
                return path_str
            if '.' in name:
                base, ext = name.rsplit('.', 1)
                new_name = f"{base}{suffix_str}.{ext}"
            else:
                new_name = f"{name}{suffix_str}"
            return str(path_obj.with_name(new_name))

        scheduled_tasks: List[str] = []
        stack: List[Tuple[str, int]] = [(rel, 0)]
        page_size = 200

        while stack:
            current_rel, depth = stack.pop()
            page = 1
            while True:
                entries, total = await list_dir(root, current_rel, page, page_size, "name", "asc")
                entries = entries or []
                if not entries and (total or 0) == 0:
                    break

                for entry in entries:
                    name = entry.get("name")
                    if not name:
                        continue
                    child_rel = f"{current_rel}/{name}" if current_rel else name
                    if entry.get("is_dir"):
                        if req.max_depth is None or depth < req.max_depth:
                            stack.append((child_rel.rstrip('/'), depth + 1))
                        continue
                    if not matches_extension(child_rel):
                        continue
                    absolute_path = build_absolute_path(adapter_model.path, child_rel)
                    save_to = None
                    if produces_file and not overwrite and suffix:
                        save_to = apply_suffix(absolute_path, suffix)
                    task = await task_queue_service.add_task(
                        "process_file",
                        {
                            "path": absolute_path,
                            "processor_type": req.processor_type,
                            "config": req.config,
                            "save_to": save_to,
                            "overwrite": overwrite,
                        },
                    )
                    scheduled_tasks.append(task.id)

                if total is None or page * page_size >= total:
                    break
                page += 1

        return {
            "task_ids": scheduled_tasks,
            "scheduled": len(scheduled_tasks),
        }

    @classmethod
    async def get_source(cls, processor_type: str):
        module_path = get_module_path(processor_type)
        if not module_path:
            raise HTTPException(404, detail="Processor not found")
        path_obj = Path(module_path)
        if not path_obj.exists():
            raise HTTPException(404, detail="Processor source not found")
        try:
            content = await run_in_threadpool(path_obj.read_text, encoding='utf-8')
        except Exception as exc:
            raise HTTPException(500, detail=f"Failed to read source: {exc}")
        return {"source": content, "module_path": str(path_obj)}

    @classmethod
    async def update_source(cls, processor_type: str, req: UpdateSourceRequest):
        module_path = get_module_path(processor_type)
        if not module_path:
            raise HTTPException(404, detail="Processor not found")
        path_obj = Path(module_path)
        if not path_obj.exists():
            raise HTTPException(404, detail="Processor source not found")
        try:
            await run_in_threadpool(path_obj.write_text, req.source, encoding='utf-8')
        except Exception as exc:
            raise HTTPException(500, detail=f"Failed to write source: {exc}")
        return True

    @classmethod
    def reload(cls):
        errors = reload_processors()
        if errors:
            raise HTTPException(500, detail="; ".join(errors))
        return True


get_processor = ProcessorService.get_processor
list_processors = ProcessorService.list_processors
reload_processor_modules = ProcessorService.reload
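Note: a condensed, standalone copy of the nested apply_suffix helper above (the original additionally guards against an empty file name), with spot checks; POSIX path separators are assumed.

from pathlib import Path

def apply_suffix(path_str: str, suffix_str: str) -> str:
    name = Path(path_str).name
    if "." in name:
        base, ext = name.rsplit(".", 1)
        return str(Path(path_str).with_name(f"{base}{suffix_str}.{ext}"))
    return str(Path(path_str).with_name(f"{name}{suffix_str}"))

assert apply_suffix("/mnt/docs/report.pdf", "_done") == "/mnt/docs/report_done.pdf"
assert apply_suffix("/mnt/docs/Makefile", "_done") == "/mnt/docs/Makefile_done"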
24
domain/processors/types.py
Normal file
@@ -0,0 +1,24 @@
from typing import Any, Dict, Optional

from pydantic import BaseModel


class ProcessRequest(BaseModel):
    path: str
    processor_type: str
    config: Dict[str, Any]
    save_to: Optional[str] = None
    overwrite: bool = False


class ProcessDirectoryRequest(BaseModel):
    path: str
    processor_type: str
    config: Dict[str, Any]
    overwrite: bool = True
    max_depth: Optional[int] = None
    suffix: Optional[str] = None


class UpdateSourceRequest(BaseModel):
    source: str
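Note: a hypothetical request matching the models above; for a file-producing processor the service layer requires a suffix whenever overwrite is false.

req = ProcessDirectoryRequest(
    path="/mount/photos",            # must resolve to a directory
    processor_type="image_watermark",
    config={"text": "foxel"},
    overwrite=False,
    max_depth=2,                     # None means unlimited recursion
    suffix="_wm",                    # required here because overwrite=False
)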
129
domain/share/api.py
Normal file
@@ -0,0 +1,129 @@
from typing import Annotated, List, Optional

from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.share.service import ShareService
from domain.share.types import (
    ShareCreate,
    ShareInfo,
    ShareInfoWithPassword,
    SharePassword,
)
from models.database import UserAccount

public_router = APIRouter(prefix="/api/s", tags=["Share - Public"])
router = APIRouter(prefix="/api/shares", tags=["Share - Management"])


@router.post("", response_model=ShareInfoWithPassword)
@audit(
    action=AuditAction.SHARE,
    description="创建分享链接",
    body_fields=["name", "paths", "expires_in_days", "access_type"],
)
async def create_share(
    request: Request,
    payload: ShareCreate,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    user_account = await UserAccount.get(id=current_user.id)
    share = await ShareService.create_share_link(
        user=user_account,
        name=payload.name,
        paths=payload.paths,
        expires_in_days=payload.expires_in_days,
        access_type=payload.access_type,
        password=payload.password,
    )
    share_info = ShareInfo.from_orm(share).model_dump()
    if payload.access_type == "password" and payload.password:
        share_info["password"] = payload.password
    return share_info


@router.get("", response_model=List[ShareInfo])
@audit(action=AuditAction.READ, description="获取我的分享列表")
async def get_my_shares(
    request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
):
    user_account = await UserAccount.get(id=current_user.id)
    shares = await ShareService.get_user_shares(user=user_account)
    return [ShareInfo.from_orm(s) for s in shares]


@router.delete("/expired")
@audit(action=AuditAction.DELETE, description="删除已过期分享")
async def delete_expired_shares(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    user_account = await UserAccount.get(id=current_user.id)
    deleted_count = await ShareService.delete_expired_shares(user=user_account)
    return success({"deleted_count": deleted_count})


@router.delete("/{share_id}")
@audit(action=AuditAction.DELETE, description="删除分享链接")
async def delete_share(
    share_id: int,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    user_account = await UserAccount.get(id=current_user.id)
    await ShareService.delete_share_link(user=user_account, share_id=share_id)
    return success(msg="分享已取消")


@public_router.post("/{token}/verify")
@audit(
    action=AuditAction.SHARE,
    description="校验分享密码",
    body_fields=["password"],
    redact_fields=["password"],
)
async def verify_password(request: Request, token: str, payload: SharePassword):
    await ShareService.verify_share_password(token, payload.password)
    return success(msg="验证成功")


@public_router.get("/{token}/ls")
@audit(action=AuditAction.SHARE, description="浏览分享内容")
async def list_share_content(
    request: Request, token: str, path: str = "/", password: Optional[str] = None
):
    share = await ShareService.ensure_share_access(token, password)
    content = await ShareService.get_shared_item_details(share, path)
    return success(
        {
            "path": path,
            "entries": content.get("items", []),
            "pagination": {
                "total": content.get("total", 0),
                "page": content.get("page", 1),
                "page_size": content.get("page_size", 1),
                "pages": content.get("pages", 1),
            },
        }
    )


@public_router.get("/{token}")
@audit(action=AuditAction.SHARE, description="获取分享信息")
async def get_share_info(request: Request, token: str):
    share = await ShareService.get_share_by_token(token)
    return success(ShareInfo.from_orm(share))


@public_router.get("/{token}/download")
@audit(action=AuditAction.DOWNLOAD, description="下载分享文件")
async def download_shared_file(
    token: str,
    path: str,
    request: Request,
    password: Optional[str] = None,
):
    return await ShareService.stream_shared_file(token, path, request.headers.get("Range"), password)
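Note: a hedged end-to-end sketch of the public share flow above (verify the password if set, then list the share root). The base URL and the success() envelope shape are assumptions.

import httpx

async def browse_share(token: str, password: str | None = None) -> list[dict]:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        if password:
            # a wrong password fails here with 403
            resp = await client.post(f"/api/s/{token}/verify", json={"password": password})
            resp.raise_for_status()
        params = {"path": "/"}
        if password:
            params["password"] = password
        resp = await client.get(f"/api/s/{token}/ls", params=params)
        resp.raise_for_status()
        return resp.json()["data"]["entries"]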
187
domain/share/service.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import bcrypt
|
||||
from fastapi import HTTPException, status
from fastapi.responses import Response

from domain.virtual_fs.service import VirtualFSService
from models.database import ShareLink, UserAccount


class ShareService:
    @classmethod
    def _hash_password(cls, password: str) -> str:
        return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")

    @classmethod
    def _verify_password(cls, plain_password: str, hashed_password: str) -> bool:
        return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))

    @classmethod
    def _calc_expires_at(cls, expires_in_days: Optional[int]) -> Optional[datetime]:
        if expires_in_days is None or expires_in_days <= 0:
            return None
        return datetime.now(timezone.utc) + timedelta(days=expires_in_days)

    @classmethod
    def _ensure_password_if_needed(cls, share: ShareLink, password: Optional[str]) -> None:
        if share.access_type != "password":
            return
        if not password:
            raise HTTPException(status_code=401, detail="需要密码")
        if not share.hashed_password:
            raise HTTPException(status_code=403, detail="密码错误")
        if not cls._verify_password(password, share.hashed_password):
            raise HTTPException(status_code=403, detail="密码错误")

    @classmethod
    async def create_share_link(
        cls,
        user: UserAccount,
        name: str,
        paths: List[str],
        expires_in_days: Optional[int] = 7,
        access_type: str = "public",
        password: Optional[str] = None,
    ) -> ShareLink:
        if not paths:
            raise HTTPException(status_code=400, detail="分享路径不能为空")

        if access_type == "password" and not password:
            raise HTTPException(status_code=400, detail="密码不能为空")

        token = secrets.token_urlsafe(16)
        expires_at = cls._calc_expires_at(expires_in_days)

        hashed_password = None
        if access_type == "password" and password:
            hashed_password = cls._hash_password(password)

        share = await ShareLink.create(
            token=token,
            name=name,
            paths=paths,
            user=user,
            expires_at=expires_at,
            access_type=access_type,
            hashed_password=hashed_password,
        )
        return share

    @classmethod
    async def get_share_by_token(cls, token: str) -> ShareLink:
        share = await ShareLink.get_or_none(token=token).prefetch_related("user")
        if not share:
            raise HTTPException(status_code=404, detail="分享链接不存在")

        if share.expires_at and share.expires_at < datetime.now(timezone.utc):
            raise HTTPException(status_code=410, detail="分享链接已过期")

        return share

    @classmethod
    async def verify_share_password(cls, token: str, password: str) -> ShareLink:
        share = await cls.get_share_by_token(token)
        if share.access_type != "password":
            raise HTTPException(status_code=400, detail="此分享不需要密码")
        cls._ensure_password_if_needed(share, password)
        return share

    @classmethod
    async def ensure_share_access(cls, token: str, password: Optional[str]) -> ShareLink:
        share = await cls.get_share_by_token(token)
        cls._ensure_password_if_needed(share, password)
        return share

    @classmethod
    async def get_user_shares(cls, user: UserAccount) -> List[ShareLink]:
        return await ShareLink.filter(user=user).order_by("-created_at")

    @classmethod
    async def delete_share_link(cls, user: UserAccount, share_id: int) -> None:
        share = await ShareLink.get_or_none(id=share_id, user_id=user.id)
        if not share:
            raise HTTPException(status_code=404, detail="分享链接不存在")
        await share.delete()

    @classmethod
    async def delete_expired_shares(cls, user: UserAccount) -> int:
        now = datetime.now(timezone.utc)
        deleted_count = await ShareLink.filter(user=user, expires_at__lte=now).delete()
        return deleted_count

    @classmethod
    async def get_shared_item_details(cls, share: ShareLink, sub_path: str = ""):
        if not share.paths:
            raise HTTPException(status_code=404, detail="分享内容为空")

        base_shared_path = share.paths[0]

        if sub_path and sub_path != "/":
            # Reject traversal segments explicitly (mirrors stream_shared_file);
            # the prefix check below would not catch "..".
            if ".." in sub_path.split("/"):
                raise HTTPException(status_code=403, detail="无权访问此路径")
            full_path = f"{base_shared_path.rstrip('/')}/{sub_path.lstrip('/')}".rstrip("/")
            if not full_path.startswith(base_shared_path):
                raise HTTPException(status_code=403, detail="无权访问此路径")
            try:
                return await VirtualFSService.list_virtual_dir(full_path)
            except FileNotFoundError:
                raise HTTPException(status_code=404, detail="目录未找到")

        try:
            stat = await VirtualFSService.stat_file(base_shared_path)
            if stat.get("is_dir"):
                return await VirtualFSService.list_virtual_dir(base_shared_path)

            stat["name"] = base_shared_path.split("/")[-1]
            return {"items": [stat], "total": 1, "page": 1, "page_size": 1, "pages": 1}
        except HTTPException as e:
            if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
                return await VirtualFSService.list_virtual_dir(base_shared_path)
            raise e

    @classmethod
    async def stream_shared_file(
        cls,
        token: str,
        path: str,
        range_header: str | None,
        password: Optional[str] = None,
    ) -> Response:
        if not path or path == "/" or ".." in path.split("/"):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无效的文件路径")

        share = await cls.ensure_share_access(token, password)
        if not share.paths:
            raise HTTPException(status_code=404, detail="分享的源文件不存在")
        base_shared_path = share.paths[0]

        is_dir = False
        try:
            stat = await VirtualFSService.stat_file(base_shared_path)
            if stat and stat.get("is_dir"):
                is_dir = True
        except HTTPException as e:
            if "Path is a directory" in str(e.detail) or "Not a file" in str(e.detail):
                is_dir = True
            elif e.status_code == 404:
                raise HTTPException(status_code=404, detail="分享的源文件不存在")
            else:
                raise

        if is_dir:
            full_virtual_path = f"{base_shared_path.rstrip('/')}/{path.lstrip('/')}"
            if not full_virtual_path.startswith(base_shared_path):
                raise HTTPException(status_code=403, detail="无权访问此路径")
        else:
            shared_filename = base_shared_path.split("/")[-1]
            request_filename = path.lstrip("/")
            if shared_filename != request_filename:
                raise HTTPException(status_code=403, detail="无权访问此路径")
            full_virtual_path = base_shared_path

        response = await VirtualFSService.stream_file(full_virtual_path, range_header)
        filename = full_virtual_path.split("/")[-1]
        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{quote(filename)}"
        return response
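A minimal usage sketch of ShareService (hypothetical driver code, not part of this diff; it assumes an initialized Tortoise ORM connection and an existing UserAccount): it creates a password-protected share and resolves it the way the public endpoints would.

    async def demo(user: UserAccount):
        # Hypothetical: "photos" / "/local/photos" / "s3cret" are placeholders.
        share = await ShareService.create_share_link(
            user=user,
            name="photos",
            paths=["/local/photos"],
            expires_in_days=3,
            access_type="password",
            password="s3cret",
        )
        # Raises HTTPException 401/403 when the password is missing or wrong.
        resolved = await ShareService.ensure_share_access(share.token, "s3cret")
        print(resolved.token, resolved.expires_at)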
43  domain/share/types.py  Normal file
@@ -0,0 +1,43 @@
from typing import List, Optional

from pydantic import BaseModel

from models.database import ShareLink


class ShareCreate(BaseModel):
    name: str
    paths: List[str]
    expires_in_days: Optional[int] = 7
    access_type: str = "public"
    password: Optional[str] = None


class SharePassword(BaseModel):
    password: str


class ShareInfo(BaseModel):
    id: int
    token: str
    name: str
    paths: List[str]
    created_at: str
    expires_at: Optional[str] = None
    access_type: str

    @classmethod
    def from_orm(cls, obj: ShareLink):
        return cls(
            id=obj.id,
            token=obj.token,
            name=obj.name,
            paths=obj.paths,
            created_at=obj.created_at.isoformat(),
            expires_at=obj.expires_at.isoformat() if obj.expires_at else None,
            access_type=obj.access_type,
        )


class ShareInfoWithPassword(ShareInfo):
    password: Optional[str] = None
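Since the custom `from_orm` classmethod flattens the datetimes to ISO-8601 strings, a JSON-ready payload can be built straight from an ORM row; a sketch (hypothetical usage, assuming a fetched ShareLink):

    async def share_to_payload(share: ShareLink) -> dict:
        # from_orm handles datetime -> ISO string; model_dump gives a plain dict.
        return ShareInfo.from_orm(share).model_dump()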
112  domain/tasks/api.py  Normal file
@@ -0,0 +1,112 @@
from fastapi import APIRouter, Depends, Request

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.tasks.service import TaskService
from domain.tasks.types import (
    AutomationTaskCreate,
    AutomationTaskUpdate,
    TaskQueueSettings,
)

CurrentUser = TaskService.current_user_dep

router = APIRouter(
    prefix="/api/tasks",
    tags=["Tasks"],
    dependencies=[Depends(get_current_active_user)],
    responses={404: {"description": "Not found"}},
)


@router.get("/queue")
@audit(action=AuditAction.READ, description="获取任务队列状态")
async def get_task_queue_status(request: Request, current_user: CurrentUser):
    payload = TaskService.get_queue_tasks()
    return success(payload)


@router.get("/queue/settings")
@audit(action=AuditAction.READ, description="获取任务队列设置")
async def get_task_queue_settings(request: Request, current_user: CurrentUser):
    payload = TaskService.get_queue_settings()
    return success(payload.model_dump())


@router.post("/queue/settings")
@audit(
    action=AuditAction.UPDATE,
    description="更新任务队列设置",
    body_fields=["concurrency"],
)
async def update_task_queue_settings(request: Request, settings: TaskQueueSettings, current_user: CurrentUser):
    payload = await TaskService.update_queue_settings(settings, getattr(current_user, "id", None))
    return success(payload.model_dump())


@router.get("/queue/{task_id}")
@audit(action=AuditAction.READ, description="获取队列任务状态")
async def get_task_status(task_id: str, request: Request, current_user: CurrentUser):
    payload = TaskService.get_queue_task(task_id)
    return success(payload)


@router.post("/")
@audit(
    action=AuditAction.CREATE,
    description="创建自动化任务",
    body_fields=[
        "name",
        "event",
        "path_pattern",
        "filename_regex",
        "processor_type",
        "processor_config",
        "enabled",
    ],
    user_kw="user",
)
async def create_task(request: Request, task_in: AutomationTaskCreate, user: CurrentUser):
    task = await TaskService.create_task(task_in, user)
    return success(task)


@router.get("/{task_id}")
@audit(action=AuditAction.READ, description="获取自动化任务详情")
async def get_task(task_id: int, request: Request, current_user: CurrentUser):
    task = await TaskService.get_task(task_id)
    return success(task)


@router.get("/")
@audit(action=AuditAction.READ, description="获取自动化任务列表")
async def list_tasks(request: Request, current_user: CurrentUser):
    tasks = await TaskService.list_tasks()
    return success(tasks)


@router.put("/{task_id}")
@audit(
    action=AuditAction.UPDATE,
    description="更新自动化任务",
    body_fields=[
        "name",
        "event",
        "path_pattern",
        "filename_regex",
        "processor_type",
        "processor_config",
        "enabled",
    ],
)
async def update_task(request: Request, current_user: CurrentUser, task_id: int, task_in: AutomationTaskUpdate):
    task = await TaskService.update_task(task_id, task_in, current_user)
    return success(task)


@router.delete("/{task_id}")
@audit(action=AuditAction.DELETE, description="删除自动化任务", user_kw="user")
async def delete_task(task_id: int, request: Request, user: CurrentUser):
    await TaskService.delete_task(task_id, user)
    return success(msg="Task deleted")
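Assuming a running instance and a valid bearer token (both placeholders here), the queue-settings endpoint above can be exercised from a small script; a sketch:

    import httpx

    BASE = "http://localhost:8000"                 # assumed dev address
    HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credential

    def set_queue_concurrency(value: int) -> dict:
        # POST /api/tasks/queue/settings with a TaskQueueSettings payload.
        resp = httpx.post(
            f"{BASE}/api/tasks/queue/settings",
            json={"concurrency": value},
            headers=HEADERS,
        )
        resp.raise_for_status()
        return resp.json()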
109  domain/tasks/service.py  Normal file
@@ -0,0 +1,109 @@
import re
from typing import Annotated, Any, Dict, Optional

from fastapi import Depends, HTTPException

from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.config.service import ConfigService
from domain.tasks.types import (
    AutomationTaskCreate,
    AutomationTaskUpdate,
    TaskQueueSettings,
    TaskQueueSettingsResponse,
)
from models.database import AutomationTask
from domain.tasks.task_queue import task_queue_service


class TaskService:
    current_user_dep = Annotated[User, Depends(get_current_active_user)]

    @classmethod
    def get_queue_tasks(cls) -> list[dict[str, Any]]:
        tasks = task_queue_service.get_all_tasks()
        return [task.dict() for task in tasks]

    @classmethod
    def get_queue_settings(cls) -> TaskQueueSettingsResponse:
        return TaskQueueSettingsResponse(
            concurrency=task_queue_service.get_concurrency(),
            active_workers=task_queue_service.get_active_worker_count(),
        )

    @classmethod
    async def update_queue_settings(cls, settings: TaskQueueSettings, user_id: Optional[int]) -> TaskQueueSettingsResponse:
        await task_queue_service.set_concurrency(settings.concurrency)
        await ConfigService.set("TASK_QUEUE_CONCURRENCY", str(task_queue_service.get_concurrency()))
        return cls.get_queue_settings()

    @classmethod
    def get_queue_task(cls, task_id: str) -> dict[str, Any]:
        task = task_queue_service.get_task(task_id)
        if not task:
            raise HTTPException(status_code=404, detail="Task not found")
        return task.dict()

    @classmethod
    async def create_task(cls, payload: AutomationTaskCreate, user: Optional[User]) -> AutomationTask:
        task = await AutomationTask.create(**payload.model_dump())
        return task

    @classmethod
    async def get_task(cls, task_id: int) -> AutomationTask:
        task = await AutomationTask.get_or_none(id=task_id)
        if not task:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
        return task

    @classmethod
    async def list_tasks(cls) -> list[AutomationTask]:
        tasks = await AutomationTask.all()
        return tasks

    @classmethod
    async def update_task(cls, task_id: int, payload: AutomationTaskUpdate, current_user: User) -> AutomationTask:
        task = await AutomationTask.get_or_none(id=task_id)
        if not task:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
        update_data = payload.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(task, key, value)
        await task.save()
        return task

    @classmethod
    async def delete_task(cls, task_id: int, user: Optional[User]) -> None:
        deleted_count = await AutomationTask.filter(id=task_id).delete()
        if not deleted_count:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

    @classmethod
    async def trigger_tasks(cls, event: str, path: str):
        tasks = await AutomationTask.filter(event=event, enabled=True)
        for task in tasks:
            if cls.match(task, path):
                await cls.execute(task, path)

    @classmethod
    def match(cls, task: AutomationTask, path: str) -> bool:
        if task.path_pattern and not path.startswith(task.path_pattern):
            return False
        if task.filename_regex:
            filename = path.split("/")[-1]
            if not re.match(task.filename_regex, filename):
                return False
        return True

    @classmethod
    async def execute(cls, task: AutomationTask, path: str):
        await task_queue_service.add_task(
            task.processor_type,
            {
                "task_id": task.id,
                "path": path,
            },
        )


task_service = TaskService
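Worth noting: `match` uses `re.match`, so `filename_regex` is anchored at the start of the filename but not at the end unless the pattern says so. A small standalone check with the same logic (the patterns below are made up for illustration):

    import re

    def matches(path_pattern: str | None, filename_regex: str | None, path: str) -> bool:
        # Mirrors TaskService.match: prefix test on the path, re.match on the filename.
        if path_pattern and not path.startswith(path_pattern):
            return False
        if filename_regex:
            filename = path.split("/")[-1]
            if not re.match(filename_regex, filename):
                return False
        return True

    print(matches("/photos", r".*\.jpg$", "/photos/2024/cat.jpg"))  # True
    print(matches("/photos", r".*\.jpg$", "/photos/cat.jpeg"))      # False: $ anchors the end
    print(matches("/photos", r"cat", "/photos/cat.png"))            # True: only the start is anchored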
214  domain/tasks/task_queue.py  Normal file
@@ -0,0 +1,214 @@
import asyncio
from typing import Dict, Any
from pydantic import BaseModel, Field
import uuid
from enum import Enum


class TaskStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"


class TaskProgress(BaseModel):
    stage: str | None = None
    percent: float | None = None
    bytes_total: int | None = None
    bytes_done: int | None = None
    detail: str | None = None


class Task(BaseModel):
    id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    name: str
    status: TaskStatus = TaskStatus.PENDING
    result: Any = None
    error: str | None = None
    task_info: Dict[str, Any] = {}
    progress: TaskProgress | None = None
    meta: Dict[str, Any] | None = None


_SENTINEL = object()


class TaskQueueService:
    def __init__(self):
        self._queue: asyncio.Queue[Task | object] = asyncio.Queue()
        self._tasks: Dict[str, Task] = {}
        self._worker_tasks: list[asyncio.Task] = []
        self._concurrency: int = 1
        self._worker_seq: int = 0

    async def add_task(self, name: str, task_info: Dict[str, Any]) -> Task:
        task = Task(name=name, task_info=task_info)
        self._tasks[task.id] = task
        await self._queue.put(task)
        return task

    def get_task(self, task_id: str) -> Task | None:
        return self._tasks.get(task_id)

    def get_all_tasks(self) -> list[Task]:
        return list(self._tasks.values())

    async def update_progress(self, task_id: str, progress: TaskProgress | Dict[str, Any]):
        task = self._tasks.get(task_id)
        if not task:
            return
        if isinstance(progress, TaskProgress):
            task.progress = progress
        else:
            task.progress = TaskProgress(**progress)

    async def update_meta(self, task_id: str, meta: Dict[str, Any]):
        task = self._tasks.get(task_id)
        if not task:
            return
        task.meta = (task.meta or {}) | meta

    async def _execute_task(self, task: Task):
        task.status = TaskStatus.RUNNING

        try:
            # Local import to avoid circular dependency during module load.
            from domain.virtual_fs.service import VirtualFSService

            if task.name == "process_file":
                params = task.task_info
                result = await VirtualFSService.process_file(
                    path=params["path"],
                    processor_type=params["processor_type"],
                    config=params["config"],
                    save_to=params.get("save_to"),
                    overwrite=params.get("overwrite", False),
                )
                task.result = result
            elif task.name == "automation_task" or self._is_processor_task(task.name):
                from models.database import AutomationTask
                from domain.processors.service import get_processor

                params = task.task_info
                auto_task = await AutomationTask.get(id=params["task_id"])
                path = params["path"]

                processor_type = auto_task.processor_type if task.name == "automation_task" else task.name
                processor = get_processor(processor_type)
                if not processor:
                    raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")

                if processor_type != auto_task.processor_type:
                    processor_type = auto_task.processor_type
                    processor = get_processor(processor_type)
                    if not processor:
                        raise ValueError(f"Processor {processor_type} not found for task {auto_task.id}")

                requires_input_bytes = bool(getattr(processor, "requires_input_bytes", True))
                file_content = b""
                if requires_input_bytes:
                    file_content = await VirtualFSService.read_file(path)
                result = await processor.process(file_content, path, auto_task.processor_config)

                save_to = auto_task.processor_config.get("save_to")
                if save_to and getattr(processor, "produces_file", False):
                    await VirtualFSService.write_file(save_to, result)
                task.result = "Automation task completed"
            elif task.name == "offline_http_download":
                from domain.offline_downloads.service import OfflineDownloadService

                result_path = await OfflineDownloadService.run_http_download(task)
                task.result = {"path": result_path}
            elif task.name == "cross_mount_transfer":
                result = await VirtualFSService.run_cross_mount_transfer_task(task)
                task.result = result
            elif task.name == "send_email":
                from domain.email.service import EmailService
                await EmailService.send_from_task(task.id, task.task_info)
                task.result = "Email sent"
            else:
                raise ValueError(f"Unknown task name: {task.name}")

            task.status = TaskStatus.SUCCESS

        except Exception as e:
            task.status = TaskStatus.FAILED
            task.error = str(e)

    def _cleanup_workers(self):
        self._worker_tasks = [task for task in self._worker_tasks if not task.done()]

    def _is_processor_task(self, task_name: str) -> bool:
        try:
            from domain.processors.service import get_processor

            return get_processor(task_name) is not None
        except Exception:
            return False

    async def _ensure_worker_count(self):
        self._cleanup_workers()
        current = len(self._worker_tasks)
        if current < self._concurrency:
            for _ in range(self._concurrency - current):
                self._worker_seq += 1
                worker_id = self._worker_seq
                worker_task = asyncio.create_task(self._worker_loop(worker_id))
                self._worker_tasks.append(worker_task)
        elif current > self._concurrency:
            for _ in range(current - self._concurrency):
                await self._queue.put(_SENTINEL)

    async def _worker_loop(self, worker_id: int):
        current_task = asyncio.current_task()
        try:
            while True:
                job = await self._queue.get()
                if job is _SENTINEL:
                    self._queue.task_done()
                    break
                try:
                    await self._execute_task(job)
                except Exception:
                    # _execute_task already records failures on the Task itself;
                    # swallow anything unexpected so the worker loop stays alive.
                    pass
                finally:
                    self._queue.task_done()
        finally:
            if current_task in self._worker_tasks:
                self._worker_tasks.remove(current_task)  # type: ignore[arg-type]

    async def start_worker(self, concurrency: int | None = None):
        if concurrency is None:
            from domain.config.service import ConfigService

            stored_value = await ConfigService.get("TASK_QUEUE_CONCURRENCY", self._concurrency)
            try:
                concurrency = int(stored_value)
            except (TypeError, ValueError):
                concurrency = self._concurrency
        await self.set_concurrency(concurrency)

    async def set_concurrency(self, value: int):
        value = max(1, int(value))
        if value != self._concurrency:
            self._concurrency = value
            await self._ensure_worker_count()

    async def stop_worker(self):
        self._cleanup_workers()
        for _ in range(len(self._worker_tasks)):
            await self._queue.put(_SENTINEL)
        if self._worker_tasks:
            await asyncio.gather(*self._worker_tasks, return_exceptions=True)
        self._worker_tasks.clear()

    def get_concurrency(self) -> int:
        return self._concurrency

    def get_active_worker_count(self) -> int:
        self._cleanup_workers()
        return len(self._worker_tasks)


task_queue_service = TaskQueueService()
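A minimal lifecycle sketch for the queue (hypothetical driver: the real application presumably calls start_worker/stop_worker from its startup and shutdown hooks, and the send_email payload shape below is invented; `_queue.join()` also reaches into a private attribute purely for the demo):

    async def main():
        # Raise concurrency from the default 1 to 2, which spawns workers.
        await task_queue_service.set_concurrency(2)
        task = await task_queue_service.add_task("send_email", {"to": "user@example.com"})
        await task_queue_service._queue.join()  # wait until queued jobs finish
        print(task.status, task.error)          # FAILED tasks carry the error string
        await task_queue_service.stop_worker()

    asyncio.run(main())

Oversized worker pools are shrunk cooperatively: set_concurrency enqueues _SENTINEL markers, and each worker that dequeues one exits its loop.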
domain/tasks/types.py
@@ -1,5 +1,6 @@
-from pydantic import BaseModel
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field
 
 
 class AutomationTaskBase(BaseModel):
@@ -29,3 +30,11 @@ class AutomationTaskRead(AutomationTaskBase):
 
     class Config:
         from_attributes = True
+
+
+class TaskQueueSettings(BaseModel):
+    concurrency: int = Field(..., ge=1, description="Desired number of concurrent task workers")
+
+
+class TaskQueueSettingsResponse(TaskQueueSettings):
+    active_workers: int = Field(..., ge=0, description="Currently running worker count")
189  domain/virtual_fs/api.py  Normal file
@@ -0,0 +1,189 @@
from typing import Annotated

from fastapi import APIRouter, Depends, File, Query, Request, UploadFile

from api.response import success
from domain.audit import AuditAction, audit
from domain.auth.service import get_current_active_user
from domain.auth.types import User
from domain.virtual_fs.service import VirtualFSService
from domain.virtual_fs.types import MkdirRequest, MoveRequest

router = APIRouter(prefix="/api/fs", tags=["virtual-fs"])


@router.get("/file/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="获取文件")
async def get_file(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    return await VirtualFSService.serve_file(full_path, request.headers.get("Range"))


@router.get("/thumb/{full_path:path}")
@audit(action=AuditAction.READ, description="获取缩略图")
async def get_thumb(
    full_path: str,
    request: Request,
    w: int = Query(256, ge=8, le=1024),
    h: int = Query(256, ge=8, le=1024),
    fit: str = Query("cover"),
):
    return await VirtualFSService.get_thumbnail(full_path, w, h, fit)


@router.get("/stream/{full_path:path}")
@audit(action=AuditAction.DOWNLOAD, description="流式读取文件")
async def stream_endpoint(
    full_path: str,
    request: Request,
):
    return await VirtualFSService.stream_response(full_path, request.headers.get("Range"))


@router.get("/temp-link/{full_path:path}")
@audit(action=AuditAction.SHARE, description="创建临时链接")
async def get_temp_link(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    expires_in: int = Query(3600, description="有效时间(秒), 0或负数表示永久"),
):
    data = await VirtualFSService.create_temp_link(full_path, expires_in)
    return success(data)


@router.get("/public/{token}")
@audit(action=AuditAction.DOWNLOAD, description="访问临时链接文件")
async def access_public_file(
    token: str,
    request: Request,
):
    return await VirtualFSService.access_public_file(token, request.headers.get("Range"))


@router.get("/stat/{full_path:path}")
@audit(action=AuditAction.READ, description="查看文件信息")
async def get_file_stat(
    full_path: str,
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    stat = await VirtualFSService.stat(full_path)
    return success(stat)


@router.post("/file/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="上传文件")
async def put_file(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    file: UploadFile = File(...),
):
    data = await file.read()
    result = await VirtualFSService.write_uploaded_file(full_path, data)
    return success(result)


@router.post("/mkdir")
@audit(action=AuditAction.CREATE, description="创建目录", body_fields=["path"])
async def api_mkdir(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MkdirRequest,
):
    result = await VirtualFSService.mkdir(body.path)
    return success(result)


@router.post("/move")
@audit(action=AuditAction.UPDATE, description="移动路径", body_fields=["src", "dst"])
async def api_move(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
    result = await VirtualFSService.move(body.src, body.dst, overwrite)
    return success(result)


@router.post("/rename")
@audit(action=AuditAction.UPDATE, description="重命名路径", body_fields=["src", "dst"])
async def api_rename(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否允许覆盖已存在目标"),
):
    result = await VirtualFSService.rename(body.src, body.dst, overwrite)
    return success(result)


@router.post("/copy")
@audit(action=AuditAction.CREATE, description="复制路径", body_fields=["src", "dst"])
async def api_copy(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    body: MoveRequest,
    overwrite: bool = Query(False, description="是否覆盖已存在目标"),
):
    result = await VirtualFSService.copy(body.src, body.dst, overwrite)
    return success(result)


@router.post("/upload/{full_path:path}")
@audit(action=AuditAction.UPLOAD, description="流式上传文件")
async def upload_stream(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    file: UploadFile = File(...),
    overwrite: bool = Query(True, description="是否覆盖已存在文件"),
    chunk_size: int = Query(1024 * 1024, ge=8 * 1024, le=8 * 1024 * 1024, description="单次读取块大小"),
):
    result = await VirtualFSService.upload_stream_from_upload_file(full_path, file, chunk_size, overwrite)
    return success(result)


@router.get("/{full_path:path}")
@audit(action=AuditAction.READ, description="浏览目录")
async def browse_fs(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
    page_num: int = Query(1, alias="page", ge=1, description="页码"),
    page_size: int = Query(50, ge=1, le=500, description="每页条数"),
    sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
    sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
    data = await VirtualFSService.list_directory(full_path, page_num, page_size, sort_by, sort_order)
    return success(data)


@router.delete("/{full_path:path}")
@audit(action=AuditAction.DELETE, description="删除路径")
async def api_delete(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    full_path: str,
):
    result = await VirtualFSService.delete(full_path)
    return success(result)


@router.get("/")
@audit(action=AuditAction.READ, description="浏览根目录")
async def root_listing(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    page_num: int = Query(1, alias="page", ge=1, description="页码"),
    page_size: int = Query(50, ge=1, le=500, description="每页条数"),
    sort_by: str = Query("name", description="按字段排序: name, size, mtime"),
    sort_order: str = Query("asc", description="排序顺序: asc, desc"),
):
    data = await VirtualFSService.list_directory("/", page_num, page_size, sort_by, sort_order)
    return success(data)
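Assuming a local server and a valid token (placeholders), the streaming upload route above can be driven from a script; a sketch:

    import httpx

    BASE = "http://localhost:8000"                 # assumed dev address
    HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credential

    def upload(local_file: str, remote_path: str) -> dict:
        # POST /api/fs/upload/{full_path} sends the file as multipart form data.
        with open(local_file, "rb") as fh:
            resp = httpx.post(
                f"{BASE}/api/fs/upload/{remote_path.lstrip('/')}",
                params={"overwrite": "true"},
                files={"file": fh},
                headers=HEADERS,
            )
        resp.raise_for_status()
        return resp.json()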