mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-06-21 07:13:55 +08:00
Initial commit
This commit is contained in:
27
.dockerignore
Normal file
27
.dockerignore
Normal file
@@ -0,0 +1,27 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
.venv/
|
||||
.env/
|
||||
|
||||
# Database
|
||||
*.sqlite3
|
||||
*.sqlite3-shm
|
||||
*.sqlite3-wal
|
||||
*.db
|
||||
*.db.lock
|
||||
|
||||
# Node
|
||||
web/node_modules/
|
||||
web/dist/
|
||||
|
||||
# IDEs and editors
|
||||
.idea/
|
||||
.vscode/
|
||||
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
custom: https://foxel.cc/sponsor
|
||||
75
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
75
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
name: Bug Report / 缺陷报告
|
||||
description: Report reproducible defects with clear context / 请提供可复现的缺陷信息
|
||||
title: "[Bug] "
|
||||
labels:
|
||||
- bug
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for helping us improve Foxel! / 感谢你帮助改进 Foxel!
|
||||
Please confirm the checklist below before filing. / 在提交前请确认以下事项。
|
||||
- type: checkboxes
|
||||
id: validations
|
||||
attributes:
|
||||
label: Pre-flight Check / 提交前检查
|
||||
options:
|
||||
- label: I searched existing issues and docs / 我已搜索现有 Issue 与文档
|
||||
required: true
|
||||
- label: This is not a question or feature request / 这不是问题咨询或功能需求
|
||||
required: true
|
||||
- type: textarea
|
||||
id: summary
|
||||
attributes:
|
||||
label: Bug Summary / 缺陷摘要
|
||||
description: Briefly describe what is wrong / 简要说明出现了什么问题
|
||||
placeholder: e.g. Upload fails with 500 error / 例如:上传时报 500 错误
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to Reproduce / 复现步骤
|
||||
description: List numbered steps to trigger the bug / 列出触发问题的步骤
|
||||
placeholder: |
|
||||
1. ...
|
||||
2. ...
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: expected
|
||||
attributes:
|
||||
label: Expected Behavior / 预期行为
|
||||
description: What should happen instead? / 期望看到什么结果?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: actual
|
||||
attributes:
|
||||
label: Actual Behavior / 实际行为
|
||||
description: What actually happens? Include messages or screenshots / 实际发生了什么?可附报错或截图
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Version / 版本信息
|
||||
description: Git commit, tag, or build number / 提供 Git 提交、标签或构建号
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Environment / 运行环境
|
||||
description: OS, browser, API server config, etc. / 操作系统、浏览器、服务端配置等
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs & Attachments / 日志与附件
|
||||
description: Paste relevant logs, stack traces, screenshots / 粘贴相关日志、堆栈或截图
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
56
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
56
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Feature Request / 功能需求
|
||||
description: Suggest enhancements or new capabilities / 提出改进或新增能力
|
||||
title: "[Feature] "
|
||||
labels:
|
||||
- enhancement
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Tell us about your idea! / 欢迎分享你的想法!
|
||||
Please complete the sections below so we can evaluate it quickly. / 请完整填写以下信息,便于快速评估。
|
||||
- type: checkboxes
|
||||
id: prechecks
|
||||
attributes:
|
||||
label: Pre-flight Check / 提交前检查
|
||||
options:
|
||||
- label: I searched existing issues and roadmap / 我已搜索现有 Issue 与路线图
|
||||
required: true
|
||||
- label: This is not a bug report or question / 这不是缺陷或问题咨询
|
||||
required: true
|
||||
- type: textarea
|
||||
id: summary
|
||||
attributes:
|
||||
label: Feature Summary / 功能概述
|
||||
description: What do you want to build? / 希望新增什么能力?
|
||||
placeholder: e.g. Support sharing download links / 例如:支持分享下载链接
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation / 背景与价值
|
||||
description: Why is this feature important? Who benefits? / 为什么重要?受益者是谁?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: scope
|
||||
attributes:
|
||||
label: Proposed Solution / 建议方案
|
||||
description: Outline how the feature might work, including API or UI hints / 描述可能的实现方式,包含 API 或 UI 提示
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives / 可选方案
|
||||
description: List any alternatives considered / 如有考虑过其他方案请列出
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: extra
|
||||
attributes:
|
||||
label: Additional Context / 补充信息
|
||||
description: Diagrams, sketches, links, constraints, etc. / 可附上草图、链接或约束
|
||||
validations:
|
||||
required: false
|
||||
42
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
42
.github/ISSUE_TEMPLATE/question.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Question / 问题咨询
|
||||
description: Ask about usage, configuration, or clarification / 用于使用、配置或澄清问题
|
||||
title: "[Question] "
|
||||
labels:
|
||||
- question
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Need help? You're in the right place. / 需要帮助?请按以下提示填写。
|
||||
Check the docs before filing. / 提交前请先查阅文档。
|
||||
- type: checkboxes
|
||||
id: prechecks
|
||||
attributes:
|
||||
label: Pre-flight Check / 提交前检查
|
||||
options:
|
||||
- label: I searched existing issues and discussions / 我已搜索现有 Issue 和讨论
|
||||
required: true
|
||||
- label: I read the relevant documentation / 我已阅读相关文档
|
||||
required: true
|
||||
- type: textarea
|
||||
id: question
|
||||
attributes:
|
||||
label: Question Details / 问题详情
|
||||
description: What do you need help with? Be specific. / 具体说明需要帮助的内容
|
||||
placeholder: Describe the scenario, expectation, and blockers / 说明场景、期望结果与阻碍
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: tried
|
||||
attributes:
|
||||
label: What You Tried / 已尝试方案
|
||||
description: List commands, configs, or steps attempted / 列出尝试过的命令、配置或步骤
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context / 补充信息
|
||||
description: Environment details, logs, screenshots / 可补充运行环境、日志或截图
|
||||
validations:
|
||||
required: false
|
||||
16
.github/dependabot.yml
vendored
Normal file
16
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
|
||||
- package-ecosystem: "bun"
|
||||
directory: "/web"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
22
.github/release-drafter.yml
vendored
Normal file
22
.github/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name-template: 'v$RESOLVED_VERSION'
|
||||
tag-template: 'v$RESOLVED_VERSION'
|
||||
categories:
|
||||
- title: '🚀 Features'
|
||||
labels:
|
||||
- 'feat'
|
||||
- title: '🐛 Bug Fixes'
|
||||
labels:
|
||||
- 'fix'
|
||||
- title: '📦 Code Refactoring'
|
||||
labels:
|
||||
- 'refactor'
|
||||
- title: '📄 Documentation'
|
||||
labels:
|
||||
- 'docs'
|
||||
- title: '🧰 Maintenance'
|
||||
label: 'chore'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
template: |
|
||||
## Changes
|
||||
|
||||
$CHANGES
|
||||
51
.github/workflows/docker-clean.yml
vendored
Normal file
51
.github/workflows/docker-clean.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: Clean dangling Docker images
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
docker-clean:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Delete untagged GHCR versions
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
OWNER="${GITHUB_REPOSITORY_OWNER}"
|
||||
PACKAGE="$(echo "${GITHUB_REPOSITORY##*/}" | tr '[:upper:]' '[:lower:]')"
|
||||
|
||||
OWNER_TYPE="$(gh api "/users/${OWNER}" -q '.type')"
|
||||
if [[ "${OWNER_TYPE}" == "Organization" ]]; then
|
||||
SCOPE="orgs/${OWNER}"
|
||||
else
|
||||
SCOPE="users/${OWNER}"
|
||||
fi
|
||||
|
||||
BASE_PATH="/${SCOPE}/packages/container/${PACKAGE}"
|
||||
|
||||
if ! gh api "${BASE_PATH}" >/dev/null 2>&1; then
|
||||
echo "Package ghcr.io/${OWNER}/${PACKAGE} not found or accessible. Nothing to clean."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
mapfile -t VERSION_IDS < <(gh api --paginate "${BASE_PATH}/versions?per_page=100" \
|
||||
-q '.[] | select(.metadata.container.tags | length == 0) | .id')
|
||||
|
||||
if [[ ${#VERSION_IDS[@]} -eq 0 ]]; then
|
||||
echo "No untagged versions to delete."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Deleting ${#VERSION_IDS[@]} untagged versions from ghcr.io/${OWNER}/${PACKAGE}..."
|
||||
for id in "${VERSION_IDS[@]}"; do
|
||||
gh api -X DELETE "${BASE_PATH}/versions/${id}" >/dev/null
|
||||
echo "Deleted version ${id}"
|
||||
done
|
||||
|
||||
echo "Cleanup complete."
|
||||
53
.github/workflows/docker.yml
vendored
Normal file
53
.github/workflows/docker.yml
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
name: Build and Push Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
|
||||
- name: Lowercase repo name
|
||||
run: echo "REPO_LC=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker tags
|
||||
id: meta
|
||||
run: |
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
VERSION=${GITHUB_REF#refs/tags/}
|
||||
echo "DOCKER_TAGS=ghcr.io/${REPO_LC}:${VERSION},ghcr.io/${REPO_LC}:latest" >> $GITHUB_ENV
|
||||
else
|
||||
echo "DOCKER_TAGS=ghcr.io/${REPO_LC}:dev" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push Docker image (multi arch)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
17
.github/workflows/release-drafter.yml
vendored
Normal file
17
.github/workflows/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
name: Release Drafter
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
update_release_draft:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: release-drafter/release-drafter@v6
|
||||
with:
|
||||
config-name: release-drafter.yml
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
34
.gitignore
vendored
Normal file
34
.gitignore
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
.venv/
|
||||
.vscode/
|
||||
data/
|
||||
.env
|
||||
AGENTS.md
|
||||
|
||||
# Logs
|
||||
/web/logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
/web/node_modules
|
||||
/web/dist
|
||||
/web/dist-ssr
|
||||
/web/*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.14
|
||||
192
CONTRIBUTING.md
Normal file
192
CONTRIBUTING.md
Normal file
@@ -0,0 +1,192 @@
|
||||
<div align="right">
|
||||
<b>English</b> | <a href="./CONTRIBUTING_zh.md">简体中文</a>
|
||||
</div>
|
||||
|
||||
# Contributing to Foxel
|
||||
|
||||
We appreciate every minute you spend helping Foxel improve. This guide explains the contribution workflow so you can get started quickly.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [How to Contribute](#how-to-contribute)
|
||||
- [🐛 Report Bugs](#-report-bugs)
|
||||
- [✨ Suggest Features](#-suggest-features)
|
||||
- [🛠️ Contribute Code](#️-contribute-code)
|
||||
- [Development Environment](#development-environment)
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Backend (FastAPI)](#backend-fastapi)
|
||||
- [Frontend (React + Vite)](#frontend-react--vite)
|
||||
- [Contribution Guidelines](#contribution-guidelines)
|
||||
- [Storage Adapters](#storage-adapters)
|
||||
- [Frontend Apps](#frontend-apps)
|
||||
- [Submission Rules](#submission-rules)
|
||||
- [Git Branching](#git-branching)
|
||||
- [Commit Message Format](#commit-message-format)
|
||||
- [Pull Request Flow](#pull-request-flow)
|
||||
|
||||
---
|
||||
|
||||
## How to Contribute
|
||||
|
||||
### 🐛 Report Bugs
|
||||
|
||||
If you discover a bug, open a ticket via [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) and include:
|
||||
|
||||
- **A clear title** that summarises the problem.
|
||||
- **Reproduction steps** with enough detail to trigger the bug.
|
||||
- **Expected vs actual behaviour** to highlight the gap.
|
||||
- **Environment details** such as operating system, browser version, and the Foxel build you used.
|
||||
|
||||
### ✨ Suggest Features
|
||||
|
||||
To propose a new capability or an improvement, create an Issue and choose the "Feature Request" template. Document:
|
||||
|
||||
- **Problem statement** – what pain point will the feature solve?
|
||||
- **Proposed solution** – how you expect it to work.
|
||||
- **Supporting material** – screenshots, references, or related links if helpful.
|
||||
|
||||
### 🛠️ Contribute Code
|
||||
|
||||
Follow the development setup below before opening a pull request. Keep changes focused and small so they are easier to review.
|
||||
|
||||
## Development Environment
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Install the following tooling first:
|
||||
|
||||
- **Git** for version control
|
||||
- **Python** 3.13 or newer
|
||||
- **Bun** for frontend package management and scripts
|
||||
|
||||
### Backend (FastAPI)
|
||||
|
||||
1. **Clone the repository**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/DrizzleTime/foxel.git
|
||||
cd Foxel
|
||||
```
|
||||
|
||||
2. **Create and activate a virtual environment**
|
||||
|
||||
`uv` is recommended for performance and reproducibility:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
# On Windows: .venv\Scripts\activate
|
||||
```
|
||||
|
||||
3. **Install dependencies**
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
4. **Prepare local resources**
|
||||
|
||||
- Create the data directory:
|
||||
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
```
|
||||
|
||||
Ensure the application user can read and write to `data/db`.
|
||||
|
||||
- Create an `.env` file in the project root and provide the required secrets. Replace the sample values with your own random strings:
|
||||
|
||||
```dotenv
|
||||
SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
```
|
||||
|
||||
5. **Start the development server**
|
||||
|
||||
```bash
|
||||
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
The API is available at `http://localhost:8000`, and the interactive docs live at `http://localhost:8000/docs`.
|
||||
|
||||
### Frontend (React + Vite)
|
||||
|
||||
1. **Enter the frontend directory**
|
||||
|
||||
```bash
|
||||
cd web
|
||||
```
|
||||
|
||||
2. **Install dependencies**
|
||||
|
||||
```bash
|
||||
bun install
|
||||
```
|
||||
|
||||
3. **Run the dev server**
|
||||
|
||||
```bash
|
||||
bun run dev
|
||||
```
|
||||
|
||||
The Vite dev server runs at `http://localhost:5173` and proxies `/api` requests to the backend.
|
||||
|
||||
## Contribution Guidelines
|
||||
|
||||
### Storage Adapters
|
||||
|
||||
Storage adapters integrate new storage providers (for example S3, FTP, or Alist).
|
||||
|
||||
1. Create a new module under [`domain/adapters/providers/`](domain/adapters/providers/) (for example `my_new_adapter.py`).
|
||||
2. Implement a class that inherits from [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py) and provide concrete implementations for the abstract methods such as `list_dir`, `get_meta`, `upload`, and `download`.
|
||||
|
||||
### Frontend Apps
|
||||
|
||||
Frontend apps enable in-browser previews or editors for specific file types.
|
||||
|
||||
1. Add a new folder in [`web/src/apps/`](web/src/apps/) for your app and expose a React component.
|
||||
2. Implement the `FoxelApp` interface defined in [`web/src/apps/types.ts`](web/src/apps/types.ts).
|
||||
3. Register the app in [`web/src/apps/registry.ts`](web/src/apps/registry.ts) and declare the MIME types or extensions it supports.
|
||||
|
||||
## Submission Rules
|
||||
|
||||
### Git Branching
|
||||
|
||||
Start your work from the latest `main` branch and push feature changes on a dedicated branch.
|
||||
|
||||
### Commit Message Format
|
||||
|
||||
We follow the [Conventional Commits](https://www.conventionalcommits.org/) specification to drive release tooling.
|
||||
|
||||
```
|
||||
<type>(<scope>): <subject>
|
||||
<BLANK LINE>
|
||||
<body>
|
||||
<BLANK LINE>
|
||||
<footer>
|
||||
```
|
||||
|
||||
- **type**: e.g. `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`.
|
||||
- **scope** (optional): the area impacted by the change, such as `adapter`, `ui`, or `api`.
|
||||
- **subject**: a concise summary written in the imperative mood.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```
|
||||
feat(adapter): add support for Alist storage
|
||||
```
|
||||
|
||||
```
|
||||
fix(ui): correct display issue in file list view
|
||||
```
|
||||
|
||||
### Pull Request Flow
|
||||
|
||||
1. Fork the repository and clone it locally.
|
||||
2. Create and switch to your feature branch.
|
||||
3. Implement the change and run relevant checks.
|
||||
4. Push the branch to your fork.
|
||||
5. Open a pull request against `main` in the Foxel repository.
|
||||
6. Explain the change set, its motivation, and reference related Issues in the PR description.
|
||||
|
||||
Maintainers will review your pull request as soon as possible.
|
||||
202
CONTRIBUTING_zh.md
Normal file
202
CONTRIBUTING_zh.md
Normal file
@@ -0,0 +1,202 @@
|
||||
<div align="right">
|
||||
<a href="./CONTRIBUTING.md">English</a> | <b>简体中文</b>
|
||||
</div>
|
||||
|
||||
# Contributing to Foxel
|
||||
|
||||
🎉 首先,非常感谢您愿意花时间为 Foxel 做出贡献!
|
||||
|
||||
我们热烈欢迎各种形式的贡献。无论是报告 Bug、提出新功能建议、完善文档,还是直接提交代码,都将对项目产生积极的影响。
|
||||
|
||||
本指南将帮助您顺利地参与到项目中来。
|
||||
|
||||
## 目录
|
||||
|
||||
- [如何贡献](#如何贡献)
|
||||
- [🐛 报告 Bug](#-报告-bug)
|
||||
- [✨ 提交功能建议](#-提交功能建议)
|
||||
- [🛠️ 贡献代码](#️-贡献代码)
|
||||
- [开发环境搭建](#开发环境搭建)
|
||||
- [依赖准备](#依赖准备)
|
||||
- [后端 (FastAPI)](#后端-fastapi)
|
||||
- [前端 (React + Vite)](#前端-react--vite)
|
||||
- [代码贡献指南](#代码贡献指南)
|
||||
- [贡献存储适配器 (Adapter)](#贡献存储适配器-adapter)
|
||||
- [贡献前端应用 (App)](#贡献前端应用-app)
|
||||
- [提交规范](#提交规范)
|
||||
- [Git 分支管理](#git-分支管理)
|
||||
- [Commit Message 格式](#commit-message-格式)
|
||||
- [Pull Request 流程](#pull-request-流程)
|
||||
|
||||
---
|
||||
|
||||
## 如何贡献
|
||||
|
||||
### 🐛 报告 Bug
|
||||
|
||||
如果您在使用的过程中发现了 Bug,请通过 [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) 来报告。请在报告中提供以下信息:
|
||||
|
||||
- **清晰的标题**:简明扼要地描述问题。
|
||||
- **复现步骤**:详细说明如何一步步重现该 Bug。
|
||||
- **期望行为** vs **实际行为**:描述您预期的结果和实际发生的情况。
|
||||
- **环境信息**:例如操作系统、浏览器版本、Foxel 版本等。
|
||||
|
||||
### ✨ 提交功能建议
|
||||
|
||||
我们欢迎任何关于新功能或改进的建议。请通过 [GitHub Issues](https://github.com/DrizzleTime/Foxel/issues) 创建一个 "Feature Request",并详细阐述您的想法:
|
||||
|
||||
- **问题描述**:说明该功能要解决什么问题。
|
||||
- **方案设想**:描述您希望该功能如何工作。
|
||||
- **相关信息**:提供任何有助于理解您想法的截图、链接或参考。
|
||||
|
||||
### 🛠️ 贡献代码
|
||||
|
||||
如果您希望直接贡献代码,请参考下面的开发和提交流程。
|
||||
|
||||
## 开发环境搭建
|
||||
|
||||
### 依赖准备
|
||||
|
||||
- **Git**: 用于版本控制。
|
||||
- **Python**: >= 3.13
|
||||
- **Bun**: 用于前端包管理和脚本运行。
|
||||
|
||||
### 后端 (FastAPI)
|
||||
|
||||
后端 API 服务基于 Python 和 FastAPI 构建。
|
||||
|
||||
1. **克隆仓库**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/DrizzleTime/foxel.git
|
||||
cd Foxel
|
||||
```
|
||||
|
||||
2. **创建并激活 Python 虚拟环境**
|
||||
|
||||
我们推荐使用 `uv` 来管理虚拟环境,以获得最佳性能。
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
# On Windows: .venv\Scripts\activate
|
||||
```
|
||||
|
||||
3. **安装依赖**
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
4. **初始化环境**
|
||||
|
||||
在启动服务前,请进行以下准备:
|
||||
|
||||
- **创建数据目录**:
|
||||
在项目根目录执行 `mkdir -p data/db`。这将创建用于存放数据库等文件的目录。
|
||||
> [!IMPORTANT]
|
||||
> 请确保应用拥有对 `data/db` 目录的读写权限。
|
||||
|
||||
- **创建 `.env` 配置文件**:
|
||||
在项目根目录创建名为 `.env` 的文件,并填入以下内容。这些密钥用于保障应用安全,您可以按需修改。
|
||||
|
||||
```dotenv
|
||||
SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
```
|
||||
|
||||
5. **启动开发服务器**
|
||||
|
||||
```bash
|
||||
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
API 服务将在 `http://localhost:8000` 上运行,您可以通过 `http://localhost:8000/docs` 访问自动生成的 API 文档。
|
||||
|
||||
### 前端 (React + Vite)
|
||||
|
||||
前端应用使用 React, Vite, 和 TypeScript 构建。
|
||||
|
||||
1. **进入前端目录**
|
||||
|
||||
```bash
|
||||
cd web
|
||||
```
|
||||
|
||||
2. **安装依赖**
|
||||
|
||||
```bash
|
||||
bun install
|
||||
```
|
||||
|
||||
3. **启动开发服务器**
|
||||
|
||||
```bash
|
||||
bun run dev
|
||||
```
|
||||
|
||||
前端开发服务器将在 `http://localhost:5173` 运行。它已经配置了代理,会自动将 `/api` 请求转发到后端服务。
|
||||
|
||||
## 代码贡献指南
|
||||
|
||||
### 贡献存储适配器 (Adapter)
|
||||
|
||||
存储适配器是 Foxel 的核心扩展点,用于接入不同的存储后端 (如 S3, FTP, Alist 等)。
|
||||
|
||||
1. **创建适配器文件**: 在 [`domain/adapters/providers/`](domain/adapters/providers/) 目录下,创建一个新文件,例如 `my_new_adapter.py`。
|
||||
2. **实现适配器类**:
|
||||
- 创建一个类,继承自 [`domain.adapters.providers.base.BaseAdapter`](domain/adapters/providers/base.py)。
|
||||
- 实现 `BaseAdapter` 中定义的所有抽象方法,如 `list_dir`, `get_meta`, `upload`, `download` 等。请仔细阅读基类中的文档注释以理解每个方法的作用和参数。
|
||||
|
||||
### 贡献前端应用 (App)
|
||||
|
||||
前端应用允许用户在浏览器中直接预览或编辑特定类型的文件。
|
||||
|
||||
1. **创建应用组件**: 在 [`web/src/apps/`](web/src/apps/) 目录下,为您的应用创建一个新的文件夹,并在其中创建 React 组件。
|
||||
2. **定义应用类型**: 您的应用需要实现 [`web/src/apps/types.ts`](web/src/apps/types.ts) 中定义的 `FoxelApp` 接口。
|
||||
3. **注册应用**: 在 [`web/src/apps/registry.ts`](web/src/apps/registry.ts) 中,导入您的应用组件,并将其添加到 `APP_REGISTRY`。在注册时,您需要指定该应用可以处理的文件类型(通过 MIME Type 或文件扩展名)。
|
||||
|
||||
## 提交规范
|
||||
|
||||
### Git 分支管理
|
||||
|
||||
- 从最新的 `main` 分支创建您的特性分支。
|
||||
|
||||
### Commit Message 格式
|
||||
|
||||
我们遵循 [Conventional Commits](https://www.conventionalcommits.org/) 规范。这有助于自动化生成更新日志和版本管理。
|
||||
|
||||
Commit Message 格式如下:
|
||||
|
||||
```
|
||||
<type>(<scope>): <subject>
|
||||
<BLANK LINE>
|
||||
<body>
|
||||
<BLANK LINE>
|
||||
<footer>
|
||||
```
|
||||
|
||||
- **type**: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore` 等。
|
||||
- **scope**: (可选) 本次提交影响的范围,例如 `adapter`, `ui`, `api`。
|
||||
- **subject**: 简明扼要的描述。
|
||||
|
||||
**示例:**
|
||||
|
||||
```
|
||||
feat(adapter): Add support for Alist storage
|
||||
```
|
||||
|
||||
```
|
||||
fix(ui): Correct display issue in file list view
|
||||
```
|
||||
|
||||
### Pull Request 流程
|
||||
|
||||
1. Fork 仓库并克隆到本地。
|
||||
2. 创建并切换到您的特性分支。
|
||||
3. 完成代码编写和测试。
|
||||
4. 将您的分支推送到您的 Fork 仓库。
|
||||
5. 在 Foxel 主仓库创建一个 Pull Request,目标分支为 `main`。
|
||||
6. 在 PR 描述中清晰地说明您的更改内容、目的和任何相关的 Issue 编号。
|
||||
|
||||
项目维护者会尽快审查您的 PR。感谢您的耐心和贡献!
|
||||
45
Dockerfile
Normal file
45
Dockerfile
Normal file
@@ -0,0 +1,45 @@
|
||||
FROM oven/bun:1.2-slim AS frontend-builder
|
||||
|
||||
WORKDIR /app/web
|
||||
|
||||
COPY web/package.json web/bun.lock ./
|
||||
RUN bun install --frozen-lockfile
|
||||
|
||||
COPY web/ ./
|
||||
|
||||
RUN bun run build
|
||||
|
||||
FROM python:3.14-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ffmpeg curl ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install uv
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv pip install --system . gunicorn \
|
||||
&& rm -rf /root/.cache
|
||||
|
||||
RUN curl -L https://github.com/DrizzleTime/FoxelUpgrade/archive/refs/heads/main.tar.gz -o /tmp/migrate.tgz \
|
||||
&& mkdir -p /app/migrate \
|
||||
&& tar -xzf /tmp/migrate.tgz --strip-components=1 -C /app/migrate \
|
||||
&& rm -rf /tmp/migrate.tgz
|
||||
|
||||
COPY --from=frontend-builder /app/web/dist /app/web/dist
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p data/db data/mount && \
|
||||
chmod 777 data/db data/mount && \
|
||||
chmod +x setup/foxel_cli.py && \
|
||||
ln -sf /app/setup/foxel_cli.py /usr/local/bin/foxel && \
|
||||
rm -rf /var/log/apt /var/cache/apt/archives
|
||||
|
||||
EXPOSE 80
|
||||
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
CMD ["/entrypoint.sh"]
|
||||
9
LICENSE
Normal file
9
LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
MIT License
|
||||
|
||||
Copyright © 2025 Foxel
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
89
README.md
Normal file
89
README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
<div align="right">
|
||||
<b>English</b> | <a href="./README_zh.md">简体中文</a>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
# Foxel
|
||||
|
||||
**A highly extensible private cloud storage solution for individuals and teams, featuring AI-powered semantic search.**
|
||||
|
||||

|
||||

|
||||

|
||||
|
||||

|
||||
|
||||
---
|
||||
<blockquote>
|
||||
<em><strong>The ocean of data is boundless, let the eye of insight guide the voyage, yet its intricate connections lie deep, not fully discernible from the surface.</strong></em>
|
||||
</blockquote>
|
||||
<img src="https://foxel.cc/image/ad-min-en.png" alt="UI Screenshot">
|
||||
</div>
|
||||
|
||||
## 👀 Online Demo
|
||||
|
||||
> [https://demo.foxel.cc](https://demo.foxel.cc)
|
||||
>
|
||||
> Account/Password: `admin` / `admin`
|
||||
|
||||
## ✨ Core Features
|
||||
|
||||
- **Unified File Management**: Centralize management of files distributed across different storage backends.
|
||||
- **Pluggable Storage Backends**: Utilizes an extensible adapter pattern to easily integrate various storage types.
|
||||
- **Semantic Search**: Supports natural language search for content within unstructured data like images and documents.
|
||||
- **Built-in File Preview**: Preview images, videos, PDFs, Office documents, text, and code files directly without downloading.
|
||||
- **Permissions and Sharing**: Supports public or private sharing links for easy file distribution.
|
||||
- **Task Processing Center**: Supports asynchronous task processing, such as file indexing and data backups, without impacting the main application.
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
Using Docker Compose is the most recommended way to start Foxel.
|
||||
|
||||
1. **Create Data Directories**
|
||||
|
||||
Create a `data` folder for persistent data:
|
||||
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
mkdir -p data/mount
|
||||
chmod 777 data/db data/mount
|
||||
```
|
||||
|
||||
2. **Download Docker Compose File**
|
||||
|
||||
```bash
|
||||
curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
|
||||
```
|
||||
|
||||
After downloading, it is **strongly recommended** to modify the environment variables in the `compose.yaml` file to ensure security:
|
||||
|
||||
- Modify `SECRET_KEY` and `TEMP_LINK_SECRET_KEY`: Replace the default keys with randomly generated strong keys.
|
||||
|
||||
3. **Start the Services**
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
4. **Access the Application**
|
||||
|
||||
Once the services are running, open the page in your browser.
|
||||
|
||||
> On the first launch, please follow the setup guide to initialize the administrator account.
|
||||
|
||||
## 🤝 How to Contribute
|
||||
|
||||
We welcome contributions from the community! Whether it's submitting bugs, suggesting new features, or contributing code directly.
|
||||
|
||||
Before you start, please read our [`CONTRIBUTING.md`](CONTRIBUTING.md) file, which explains the development environment and submission process. A Simplified Chinese translation is available in [`CONTRIBUTING_zh.md`](CONTRIBUTING_zh.md).
|
||||
|
||||
## 🌐 Community
|
||||
|
||||
Join our community on [Telegram](https://t.me/+thDsBfyqJxZkNTU1) to discuss with developers and other users!
|
||||
|
||||
You can also join our WeChat group for more real-time communication and support. Please scan the QR code below to join:
|
||||
|
||||
<img src="https://foxel.cc/image/wechat.png" alt="WeChat Group QR Code" width="180">
|
||||
|
||||
> If the QR code is invalid, please add WeChat ID **drizzle2001**, and we will invite you to the group.
|
||||
89
README_zh.md
Normal file
89
README_zh.md
Normal file
@@ -0,0 +1,89 @@
|
||||
<div align="right">
|
||||
<a href="./README.md">English</a> | <b>简体中文</b>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
# Foxel
|
||||
|
||||
**一个面向个人和团队的、高度可扩展的私有云盘解决方案,支持 AI 语义搜索。**
|
||||
|
||||

|
||||

|
||||

|
||||
|
||||

|
||||
|
||||
---
|
||||
<blockquote>
|
||||
<em><strong>数据之洋浩瀚无涯,当以洞察之目引航,然其脉络深隐,非表象所能尽窥。</strong></em><br>
|
||||
</blockquote>
|
||||
<img src="https://foxel.cc/image/ad-min-zh.png" alt="UI Screenshot">
|
||||
</div>
|
||||
|
||||
## 👀 在线体验
|
||||
|
||||
> [https://demo.foxel.cc](https://demo.foxel.cc)
|
||||
>
|
||||
> 账号/密码:`admin` / `admin`
|
||||
|
||||
## ✨ 核心功能
|
||||
|
||||
- **统一文件管理**:集中管理分布于不同存储后端的文件。
|
||||
- **插件化存储后端**:采用可扩展的适配器模式,方便集成多种存储类型。
|
||||
- **语义搜索**:支持自然语言描述搜索图片、文档等非结构化数据内容。
|
||||
- **内置文件预览**:可直接预览图片、视频、PDF、Office 文档及文本、代码文件,无需下载。
|
||||
- **权限与分享**:支持公开或私密分享链接,便于文件共享。
|
||||
- **任务处理中心**:支持异步任务处理,如文件索引和数据备份,不影响主应用运行。
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
使用 Docker Compose 是启动 Foxel 最推荐的方式。
|
||||
|
||||
1. **创建数据目录**
|
||||
|
||||
新建 `data` 文件夹用于持久化数据:
|
||||
|
||||
```bash
|
||||
mkdir -p data/db
|
||||
mkdir -p data/mount
|
||||
chmod 777 data/db data/mount
|
||||
```
|
||||
|
||||
2. **下载 Docker Compose 文件**
|
||||
|
||||
```bash
|
||||
curl -L -O https://github.com/DrizzleTime/Foxel/raw/main/compose.yaml
|
||||
```
|
||||
|
||||
下载完成后,**强烈建议**修改 `compose.yaml` 文件中的环境变量以确保安全:
|
||||
|
||||
- 修改 `SECRET_KEY` 和 `TEMP_LINK_SECRET_KEY`:将默认的密钥替换为随机生成的强密钥
|
||||
|
||||
3. **启动服务**
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
4. **访问应用**
|
||||
|
||||
服务启动后,在浏览器中打开页面。
|
||||
|
||||
> 首次启动,请根据引导页面完成管理员账号的初始化设置。
|
||||
|
||||
## 🤝 如何贡献
|
||||
|
||||
我们非常欢迎来自社区的贡献!无论是提交 Bug、建议新功能还是直接贡献代码。
|
||||
|
||||
在开始之前,请先阅读我们的 [`CONTRIBUTING_zh.md`](CONTRIBUTING_zh.md) 文件,它会指导你如何设置开发环境以及提交流程。
|
||||
|
||||
## 🌐 社区
|
||||
|
||||
加入我们的交流社区:[Telegram 群组](https://t.me/+thDsBfyqJxZkNTU1),与开发者和用户一起讨论!
|
||||
|
||||
你也可以加入我们的微信群,获取更多实时交流与支持。请扫描下方二维码加入:
|
||||
|
||||
<img src="https://foxel.cc/image/wechat.png" alt="微信群二维码" width="180">
|
||||
|
||||
> 如果二维码失效,请添加微信号 **drizzle2001**,我们会邀请你加入群聊。
|
||||
16
api/response.py
Normal file
16
api/response.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def success(data: Any = None, msg: str = "ok", code: int = 0):
|
||||
"""标准成功响应包装。"""
|
||||
return {"code": code, "msg": msg, "data": data}
|
||||
|
||||
|
||||
def page(items: list[Any], total: int, page: int, page_size: int):
|
||||
"""统一分页数据结构。"""
|
||||
pages = (total + page_size - 1) // page_size if page_size else 0
|
||||
return {"items": items, "total": total, "page": page, "page_size": page_size, "pages": pages}
|
||||
|
||||
|
||||
def error(msg: str, code: int = 1, data: Optional[Any] = None):
|
||||
return {"code": code, "msg": msg, "data": data}
|
||||
46
api/routers.py
Normal file
46
api/routers.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from domain.adapters import api as adapters
|
||||
from domain.auth import api as auth
|
||||
from domain.backup import api as backup
|
||||
from domain.config import api as config
|
||||
from domain.email import api as email
|
||||
from domain.offline_downloads import api as offline_downloads
|
||||
from domain.plugins import api as plugins
|
||||
from domain.processors import api as processors
|
||||
from domain.share import api as share
|
||||
from domain.tasks import api as tasks
|
||||
from domain.ai import api as ai
|
||||
from domain.agent import api as agent
|
||||
from domain.virtual_fs import api as virtual_fs
|
||||
from domain.virtual_fs.mapping import s3_api, webdav_api
|
||||
from domain.virtual_fs.search import search_api
|
||||
from domain.audit import api as audit
|
||||
from domain.permission import api as permission
|
||||
from domain.user import api as user
|
||||
from domain.role import api as role
|
||||
|
||||
|
||||
def include_routers(app: FastAPI):
|
||||
app.include_router(adapters.router)
|
||||
app.include_router(search_api.router)
|
||||
app.include_router(virtual_fs.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(config.router)
|
||||
app.include_router(processors.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(share.router)
|
||||
app.include_router(share.public_router)
|
||||
app.include_router(backup.router)
|
||||
app.include_router(ai.router_vector_db)
|
||||
app.include_router(ai.router_ai)
|
||||
app.include_router(agent.router)
|
||||
app.include_router(plugins.router)
|
||||
app.include_router(webdav_api.router)
|
||||
app.include_router(s3_api.router)
|
||||
app.include_router(offline_downloads.router)
|
||||
app.include_router(email.router)
|
||||
app.include_router(audit.router)
|
||||
app.include_router(permission.router)
|
||||
app.include_router(user.router)
|
||||
app.include_router(role.router)
|
||||
22
compose.yaml
Normal file
22
compose.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
services:
|
||||
foxel:
|
||||
image: ghcr.io/drizzletime/foxel:latest
|
||||
#image: ghcr.nju.edu.cn/drizzletime/foxel:latest # 国内用户可以用此镜像命令
|
||||
container_name: foxel
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${FOXEL_HOST_PORT:-8088}:${FOXEL_PORT:-80}"
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- FOXEL_PORT=${FOXEL_PORT:-80}
|
||||
- SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
- TEMP_LINK_SECRET_KEY=EnsRhL9NFPxgFVc+7t96/y70DIOR+9SpntcIqQa90TU=
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
pull_policy: always
|
||||
networks:
|
||||
- foxel-network
|
||||
|
||||
networks:
|
||||
foxel-network:
|
||||
driver: bridge
|
||||
3
db/__init__.py
Normal file
3
db/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .session import init_db, close_db
|
||||
|
||||
__all__ = ["init_db", "close_db"]
|
||||
22
db/session.py
Normal file
22
db/session.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from tortoise import Tortoise
|
||||
|
||||
from domain.adapters import runtime_registry
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {"default": "sqlite://data/db/db.sqlite3"},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": ["models.database"],
|
||||
"default_connection": "default",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
async def init_db():
|
||||
await Tortoise.init(config=TORTOISE_ORM)
|
||||
await Tortoise.generate_schemas()
|
||||
await runtime_registry.refresh()
|
||||
|
||||
|
||||
async def close_db():
|
||||
await Tortoise.close_connections()
|
||||
7
domain/__init__.py
Normal file
7
domain/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
domain:业务域层
|
||||
|
||||
约定:跨包只从各子包 `__init__.py` 导入公开 API。
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
24
domain/adapters/__init__.py
Normal file
24
domain/adapters/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from .providers import BaseAdapter
|
||||
from .registry import (
|
||||
RuntimeRegistry,
|
||||
discover_adapters,
|
||||
get_config_schema,
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from .service import AdapterService
|
||||
from .types import AdapterCreate, AdapterOut
|
||||
|
||||
__all__ = [
|
||||
"BaseAdapter",
|
||||
"RuntimeRegistry",
|
||||
"discover_adapters",
|
||||
"get_config_schema",
|
||||
"get_config_schemas",
|
||||
"normalize_adapter_type",
|
||||
"runtime_registry",
|
||||
"AdapterService",
|
||||
"AdapterCreate",
|
||||
"AdapterOut",
|
||||
]
|
||||
92
domain/adapters/api.py
Normal file
92
domain/adapters/api.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import AdapterPermission
|
||||
from .service import AdapterService
|
||||
from .types import AdapterCreate
|
||||
|
||||
router = APIRouter(prefix="/api/adapters", tags=["adapters"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
@require_system_permission(AdapterPermission.CREATE)
|
||||
async def create_adapter(
|
||||
request: Request,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.create_adapter(data, current_user)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.get("")
|
||||
@audit(action=AuditAction.READ, description="获取适配器列表")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def list_adapters(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapters = await AdapterService.list_adapters()
|
||||
return success(adapters)
|
||||
|
||||
|
||||
@router.get("/available")
|
||||
@audit(action=AuditAction.READ, description="获取可用适配器类型")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def available_adapter_types(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
data = await AdapterService.available_adapter_types()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{adapter_id}")
|
||||
@audit(action=AuditAction.READ, description="获取适配器详情")
|
||||
@require_system_permission(AdapterPermission.LIST)
|
||||
async def get_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.get_adapter(adapter_id)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.put("/{adapter_id}")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新存储适配器",
|
||||
body_fields=["name", "type", "path", "sub_path", "enabled"],
|
||||
)
|
||||
@require_system_permission(AdapterPermission.EDIT)
|
||||
async def update_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
data: AdapterCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
adapter = await AdapterService.update_adapter(adapter_id, data, current_user)
|
||||
return success(adapter)
|
||||
|
||||
|
||||
@router.delete("/{adapter_id}")
|
||||
@audit(action=AuditAction.DELETE, description="删除存储适配器")
|
||||
@require_system_permission(AdapterPermission.DELETE)
|
||||
async def delete_adapter(
|
||||
request: Request,
|
||||
adapter_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
result = await AdapterService.delete_adapter(adapter_id, current_user)
|
||||
return success(result)
|
||||
3
domain/adapters/providers/__init__.py
Normal file
3
domain/adapters/providers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAdapter
|
||||
|
||||
__all__ = ["BaseAdapter"]
|
||||
515
domain/adapters/providers/alist.py
Normal file
515
domain/adapters/providers/alist.py
Normal file
@@ -0,0 +1,515 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import re
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Dict, List, Tuple
|
||||
from urllib.parse import quote, urljoin
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _normalize_fs_path(path: str) -> str:
|
||||
path = (path or "").replace("\\", "/").strip()
|
||||
if not path or path == "/":
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = re.sub(r"/{2,}", "/", path)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _join_fs_path(base: str, rel: str) -> str:
|
||||
base = _normalize_fs_path(base)
|
||||
rel = (rel or "").replace("\\", "/").lstrip("/")
|
||||
if not rel:
|
||||
return base
|
||||
if base == "/":
|
||||
return "/" + rel
|
||||
return f"{base}/{rel}"
|
||||
|
||||
|
||||
def _split_parent_and_name(path: str) -> Tuple[str, str]:
|
||||
path = _normalize_fs_path(path)
|
||||
if path == "/":
|
||||
return "/", ""
|
||||
parent, _, name = path.rpartition("/")
|
||||
if not parent:
|
||||
parent = "/"
|
||||
return parent, name
|
||||
|
||||
|
||||
def _parse_iso_to_epoch(value: str | None) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
if text.endswith("Z"):
|
||||
text = text[:-1] + "+00:00"
|
||||
m = re.match(r"^(.*?)(\.\d+)([+-]\d\d:\d\d)?$", text)
|
||||
if m:
|
||||
head, frac, tz = m.group(1), m.group(2), m.group(3) or ""
|
||||
digits = frac[1:]
|
||||
if len(digits) > 6:
|
||||
frac = "." + digits[:6]
|
||||
text = head + frac + tz
|
||||
dt = datetime.fromisoformat(text)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return int(dt.timestamp())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
class AListApiAdapterBase:
|
||||
def __init__(self, record: StorageAdapter, *, product_name: str):
|
||||
self.record = record
|
||||
self.product_name = product_name
|
||||
|
||||
cfg = record.config or {}
|
||||
self.base_url: str = str(cfg.get("base_url", "")).rstrip("/")
|
||||
if not self.base_url.startswith("http"):
|
||||
raise ValueError(f"{product_name} requires base_url http/https")
|
||||
self.username: str = str(cfg.get("username") or "")
|
||||
self.password: str = str(cfg.get("password") or "")
|
||||
if (self.username and not self.password) or (self.password and not self.username):
|
||||
raise ValueError(f"{product_name} requires both username and password")
|
||||
self.use_auth: bool = bool(self.username and self.password)
|
||||
|
||||
self.timeout: float = float(cfg.get("timeout", 30))
|
||||
self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
|
||||
self.enable_redirect_307: bool = bool(cfg.get("enable_direct_download_307"))
|
||||
|
||||
self._token: str | None = None
|
||||
self._login_lock = asyncio.Lock()
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = _normalize_fs_path(self.root_path)
|
||||
if sub_path:
|
||||
return _join_fs_path(base, sub_path)
|
||||
return base
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
if not self.use_auth:
|
||||
return ""
|
||||
if self._token:
|
||||
return self._token
|
||||
async with self._login_lock:
|
||||
if self._token:
|
||||
return self._token
|
||||
self._token = await self._login()
|
||||
return self._token
|
||||
|
||||
async def _login(self) -> str:
|
||||
url = self.base_url + "/api/auth/login"
|
||||
body = {"username": self.username, "password": self.password}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json=body)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} login: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
raise HTTPException(502, detail=f"{self.product_name} login failed: {payload.get('message')}")
|
||||
data = payload.get("data") or {}
|
||||
token = (data.get("token") if isinstance(data, dict) else None) or ""
|
||||
token = str(token).strip()
|
||||
if not token:
|
||||
raise HTTPException(502, detail=f"{self.product_name} login: missing token")
|
||||
return token
|
||||
|
||||
async def _api_json(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
*,
|
||||
json: Dict[str, Any] | None = None,
|
||||
headers: Dict[str, str] | None = None,
|
||||
retry: bool = True,
|
||||
files: Any = None,
|
||||
) -> Any:
|
||||
token = await self._ensure_token()
|
||||
url = self.base_url + endpoint
|
||||
req_headers: Dict[str, str] = {}
|
||||
if token:
|
||||
req_headers["Authorization"] = token
|
||||
if headers:
|
||||
req_headers.update(headers)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, json=json, headers=req_headers, files=files)
|
||||
if resp.status_code == 401 and retry and self.use_auth:
|
||||
self._token = None
|
||||
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} api: invalid response")
|
||||
|
||||
code = payload.get("code")
|
||||
if code in (0, 200):
|
||||
return payload.get("data")
|
||||
if code in (401, 403) and retry and self.use_auth:
|
||||
self._token = None
|
||||
return await self._api_json(method, endpoint, json=json, headers=headers, retry=False, files=files)
|
||||
if code == 404:
|
||||
raise FileNotFoundError(json.get("path") if json else "")
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} api error code={code} msg={msg}")
|
||||
|
||||
def _abs_url(self, url: str) -> str:
|
||||
u = (url or "").strip()
|
||||
if not u:
|
||||
return ""
|
||||
if u.startswith("http://") or u.startswith("https://"):
|
||||
return u
|
||||
return urljoin(self.base_url.rstrip("/") + "/", u.lstrip("/"))
|
||||
|
||||
async def _fs_list(self, path: str) -> Dict[str, Any]:
|
||||
body = {"path": path, "password": "", "page": 1, "per_page": 0, "refresh": False}
|
||||
data = await self._api_json("POST", "/api/fs/list", json=body)
|
||||
return data or {}
|
||||
|
||||
async def _fs_get(self, path: str) -> Dict[str, Any]:
|
||||
body = {"path": path, "password": "", "page": 1, "per_page": 0, "refresh": False}
|
||||
data = await self._api_json("POST", "/api/fs/get", json=body)
|
||||
return data or {}
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
path = _join_fs_path(root, rel)
|
||||
data = await self._fs_list(path)
|
||||
content = data.get("content") or []
|
||||
if not isinstance(content, list):
|
||||
raise HTTPException(502, detail=f"{self.product_name} list_dir: invalid content")
|
||||
|
||||
entries: List[Dict] = []
|
||||
for it in content:
|
||||
if not isinstance(it, dict):
|
||||
continue
|
||||
name = str(it.get("name") or "")
|
||||
if not name:
|
||||
continue
|
||||
is_dir = bool(it.get("is_dir"))
|
||||
size = int(it.get("size") or 0) if not is_dir else 0
|
||||
mtime = _parse_iso_to_epoch(it.get("modified"))
|
||||
entries.append(
|
||||
{
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
)
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item: Dict) -> Tuple:
|
||||
key = (not item.get("is_dir"),)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (str(item.get("name", "")).lower(),)
|
||||
elif f == "size":
|
||||
key += (int(item.get("size", 0)),)
|
||||
elif f == "mtime":
|
||||
key += (int(item.get("mtime", 0)),)
|
||||
else:
|
||||
key += (str(item.get("name", "")).lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(entries)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return entries[start:end], total
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_fs_path(root, rel)
|
||||
data = await self._fs_get(path)
|
||||
if not data:
|
||||
raise FileNotFoundError(rel)
|
||||
is_dir = bool(data.get("is_dir"))
|
||||
name = str(data.get("name") or (rel.rstrip("/").split("/")[-1] if rel else ""))
|
||||
size = int(data.get("size") or 0) if not is_dir else 0
|
||||
mtime = _parse_iso_to_epoch(data.get("modified"))
|
||||
info = {
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
"path": path,
|
||||
}
|
||||
return info
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
try:
|
||||
info = await self.stat_file(root, rel)
|
||||
return {"exists": True, "is_dir": bool(info.get("is_dir")), "path": info.get("path")}
|
||||
except FileNotFoundError:
|
||||
return {"exists": False, "is_dir": None, "path": _join_fs_path(root, rel)}
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
await self.stat_file(root, rel)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
data = await self._fs_get(_join_fs_path(root, rel))
|
||||
if not data:
|
||||
raise FileNotFoundError(rel)
|
||||
if bool(data.get("is_dir")):
|
||||
raise IsADirectoryError(rel)
|
||||
raw_url = self._abs_url(str(data.get("raw_url") or ""))
|
||||
if not raw_url:
|
||||
return None
|
||||
return Response(status_code=307, headers={"Location": raw_url})
|
||||
|
||||
async def _get_raw_url_and_meta(self, root: str, rel: str) -> Tuple[str, int, str]:
|
||||
data = await self._fs_get(_join_fs_path(root, rel))
|
||||
if not data:
|
||||
raise FileNotFoundError(rel)
|
||||
if bool(data.get("is_dir")):
|
||||
raise IsADirectoryError(rel)
|
||||
raw_url = self._abs_url(str(data.get("raw_url") or ""))
|
||||
if not raw_url:
|
||||
raise HTTPException(502, detail=f"{self.product_name} missing raw_url")
|
||||
size = int(data.get("size") or 0)
|
||||
name = str(data.get("name") or "")
|
||||
return raw_url, size, name
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
raw_url, _, _ = await self._get_raw_url_and_meta(root, rel)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(raw_url)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
raw_url, file_size, name = await self._get_raw_url_and_meta(root, rel)
|
||||
mime, _ = mimetypes.guess_type(name or rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
start = 0
|
||||
end = max(file_size - 1, 0)
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
if file_size >= 0:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if file_size and start >= file_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if file_size and end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
|
||||
async def agen():
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
req_headers = {"Range": f"bytes={start}-{end}"} if status == 206 else {}
|
||||
async with client.stream("GET", raw_url, headers=req_headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def _upload_file(self, full_path: str, file_path: Path) -> Any:
|
||||
token = await self._ensure_token()
|
||||
headers = {"File-Path": quote(full_path, safe="/")}
|
||||
if token:
|
||||
headers["Authorization"] = token
|
||||
with file_path.open("rb") as f:
|
||||
files = {"file": (file_path.name, f, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
|
||||
return payload.get("data")
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
tf.write(data)
|
||||
tmp_path = Path(tf.name)
|
||||
try:
|
||||
await self._upload_file(full_path, tmp_path)
|
||||
finally:
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
token = await self._ensure_token()
|
||||
headers = {"File-Path": quote(full_path, safe="/")}
|
||||
if token:
|
||||
headers["Authorization"] = token
|
||||
name = filename or Path(rel).name or "file"
|
||||
mime = content_type or "application/octet-stream"
|
||||
files = {"file": (name, file_obj, mime)}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.put(self.base_url + "/api/fs/form", headers=headers, files=files)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload: invalid response")
|
||||
code = payload.get("code")
|
||||
if code not in (0, 200):
|
||||
msg = payload.get("message") or payload.get("msg") or ""
|
||||
raise HTTPException(502, detail=f"{self.product_name} upload failed: {msg}")
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict) and file_size is not None and "size" not in data:
|
||||
data["size"] = file_size
|
||||
return data
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
tmp_path = Path(tf.name)
|
||||
size = 0
|
||||
try:
|
||||
with tmp_path.open("wb") as f:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
size += len(chunk)
|
||||
await self._upload_file(full_path, tmp_path)
|
||||
return size
|
||||
finally:
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_fs_path(root, rel)
|
||||
await self._api_json("POST", "/api/fs/mkdir", json={"path": path})
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_fs_path(root, rel)
|
||||
parent, name = _split_parent_and_name(path)
|
||||
if not name:
|
||||
return
|
||||
await self._api_json("POST", "/api/fs/remove", json={"dir": parent, "names": [name]})
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_path = _join_fs_path(root, src_rel)
|
||||
dst_path = _join_fs_path(root, dst_rel)
|
||||
src_dir, src_name = _split_parent_and_name(src_path)
|
||||
dst_dir, dst_name = _split_parent_and_name(dst_path)
|
||||
if not src_name or not dst_name:
|
||||
raise HTTPException(400, detail="Invalid move path")
|
||||
|
||||
if src_dir == dst_dir:
|
||||
if src_name == dst_name:
|
||||
return
|
||||
await self._api_json("POST", "/api/fs/rename", json={"path": src_path, "name": dst_name})
|
||||
return
|
||||
|
||||
await self._api_json("POST", "/api/fs/move", json={"src_dir": src_dir, "dst_dir": dst_dir, "names": [src_name]})
|
||||
if src_name != dst_name:
|
||||
await self._api_json("POST", "/api/fs/rename", json={"path": _join_fs_path(dst_dir, src_name), "name": dst_name})
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src_path = _join_fs_path(root, src_rel)
|
||||
dst_path = _join_fs_path(root, dst_rel)
|
||||
src_dir, src_name = _split_parent_and_name(src_path)
|
||||
dst_dir, dst_name = _split_parent_and_name(dst_path)
|
||||
if not src_name or not dst_name:
|
||||
raise HTTPException(400, detail="Invalid copy path")
|
||||
|
||||
src_info = await self._fs_get(src_path)
|
||||
if not src_info:
|
||||
raise FileNotFoundError(src_rel)
|
||||
|
||||
if src_name != dst_name and not bool(src_info.get("is_dir")):
|
||||
raw_url, _, _ = await self._get_raw_url_and_meta(root, src_rel)
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
async with client.stream("GET", raw_url) as resp:
|
||||
resp.raise_for_status()
|
||||
|
||||
async def gen():
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
await self.write_file_stream(root, dst_rel, gen())
|
||||
return
|
||||
|
||||
await self._api_json("POST", "/api/fs/copy", json={"src_dir": src_dir, "dst_dir": dst_dir, "names": [src_name]})
|
||||
if src_name != dst_name:
|
||||
await self._api_json("POST", "/api/fs/rename", json={"path": _join_fs_path(dst_dir, src_name), "name": dst_name})
|
||||
|
||||
|
||||
class AListAdapter(AListApiAdapterBase):
|
||||
def __init__(self, record: StorageAdapter):
|
||||
super().__init__(record, product_name="AList")
|
||||
|
||||
|
||||
class OpenListAdapter(AListApiAdapterBase):
|
||||
def __init__(self, record: StorageAdapter):
|
||||
super().__init__(record, product_name="OpenList")
|
||||
|
||||
|
||||
ADAPTER_TYPES = {"alist": AListAdapter, "openlist": OpenListAdapter}
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "base_url", "label": "基础地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:5244"},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": False, "placeholder": "留空则匿名访问"},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": False, "placeholder": "留空则匿名访问"},
|
||||
{"key": "root", "label": "根目录", "type": "string", "required": False, "default": "/"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 30},
|
||||
{"key": "enable_direct_download_307", "label": "启用 307 直链下载", "type": "boolean", "default": False},
|
||||
]
|
||||
23
domain/adapters/providers/base.py
Normal file
23
domain/adapters/providers/base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import List, Dict, Protocol, runtime_checkable, Tuple, AsyncIterator
|
||||
from models import StorageAdapter
|
||||
|
||||
# 约定:任意新适配器模块需定义:
|
||||
# ADAPTER_TYPE: str
|
||||
# CONFIG_SCHEMA: List[Dict]
|
||||
# ADAPTER_FACTORY: Callable[[StorageAdapter], BaseAdapter] (可省略, 会自动寻找 *Adapter 类)
|
||||
|
||||
@runtime_checkable
|
||||
class BaseAdapter(Protocol):
|
||||
record: StorageAdapter
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]: ...
|
||||
async def read_file(self, root: str, rel: str) -> bytes: ...
|
||||
async def write_file(self, root: str, rel: str, data: bytes): ...
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]): ...
|
||||
async def mkdir(self, root: str, rel: str): ...
|
||||
async def delete(self, root: str, rel: str): ...
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str): ...
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str): ...
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False): ...
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None): ...
|
||||
async def stat_file(self, root: str, rel: str): ...
|
||||
def get_effective_root(self, sub_path: str | None) -> str: ...
|
||||
471
domain/adapters/providers/dropbox.py
Normal file
471
domain/adapters/providers/dropbox.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import AsyncIterator, Dict, List, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
DROPBOX_OAUTH_URL = "https://api.dropboxapi.com/oauth2/token"
|
||||
DROPBOX_API_URL = "https://api.dropboxapi.com/2"
|
||||
DROPBOX_CONTENT_URL = "https://content.dropboxapi.com/2"
|
||||
|
||||
|
||||
def _normalize_dbx_path(path: str | None) -> str:
|
||||
path = (path or "").replace("\\", "/").strip()
|
||||
if not path or path == "/":
|
||||
return ""
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = re.sub(r"/{2,}", "/", path)
|
||||
if path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path
|
||||
|
||||
|
||||
def _join_dbx_path(base: str, rel: str) -> str:
|
||||
base = _normalize_dbx_path(base)
|
||||
rel = (rel or "").replace("\\", "/").strip("/")
|
||||
if not rel:
|
||||
return base
|
||||
if not base:
|
||||
return "/" + rel
|
||||
return f"{base}/{rel}"
|
||||
|
||||
|
||||
def _parse_iso_to_epoch(value: str | None) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
if text.endswith("Z"):
|
||||
text = text[:-1] + "+00:00"
|
||||
dt = datetime.fromisoformat(text)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return int(dt.timestamp())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
class DropboxAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config or {}
|
||||
|
||||
self.app_key: str = str(cfg.get("app_key") or "").strip()
|
||||
self.app_secret: str = str(cfg.get("app_secret") or "").strip()
|
||||
self.refresh_token: str = str(cfg.get("refresh_token") or "").strip()
|
||||
self.root_path: str = _normalize_dbx_path(str(cfg.get("root") or "/"))
|
||||
self.enable_redirect_307: bool = bool(cfg.get("enable_direct_download_307"))
|
||||
self.timeout: float = float(cfg.get("timeout", 60))
|
||||
|
||||
if not (self.app_key and self.app_secret and self.refresh_token):
|
||||
raise ValueError("Dropbox 适配器需要 app_key, app_secret, refresh_token")
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: datetime | None = None
|
||||
self._token_lock = asyncio.Lock()
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = _normalize_dbx_path(self.root_path)
|
||||
if sub_path:
|
||||
return _join_dbx_path(base, sub_path)
|
||||
return base
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
async with self._token_lock:
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
basic = base64.b64encode(f"{self.app_key}:{self.app_secret}".encode("utf-8")).decode("ascii")
|
||||
headers = {"Authorization": f"Basic {basic}"}
|
||||
data = {"grant_type": "refresh_token", "refresh_token": self.refresh_token}
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
resp = await client.post(DROPBOX_OAUTH_URL, data=data, headers=headers)
|
||||
resp.raise_for_status()
|
||||
|
||||
payload = resp.json()
|
||||
token = str(payload.get("access_token") or "").strip()
|
||||
if not token:
|
||||
raise HTTPException(502, detail="Dropbox oauth: missing access_token")
|
||||
expires_in = int(payload.get("expires_in") or 3600)
|
||||
self._access_token = token
|
||||
self._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=max(60, expires_in - 300))
|
||||
return token
|
||||
|
||||
async def _api_json(self, endpoint: str, body: Dict) -> httpx.Response:
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
return await client.post(f"{DROPBOX_API_URL}{endpoint}", json=body, headers=headers)
|
||||
|
||||
async def _content_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
api_arg: Dict,
|
||||
*,
|
||||
content: bytes | None = None,
|
||||
data_iter: AsyncIterator[bytes] | None = None,
|
||||
extra_headers: Dict[str, str] | None = None,
|
||||
) -> httpx.Response:
|
||||
token = await self._get_access_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Dropbox-API-Arg": json.dumps(api_arg, separators=(",", ":"), ensure_ascii=False),
|
||||
}
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
if data_iter is None:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
return await client.post(f"{DROPBOX_CONTENT_URL}{endpoint}", headers=headers, content=content or b"")
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
return await client.post(f"{DROPBOX_CONTENT_URL}{endpoint}", headers=headers, content=data_iter)
|
||||
|
||||
@staticmethod
|
||||
def _raise_dbx_error(resp: httpx.Response, *, rel: str):
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
summary = ""
|
||||
if isinstance(payload, dict):
|
||||
summary = str(payload.get("error_summary") or "")
|
||||
if "not_found" in summary:
|
||||
raise FileNotFoundError(rel)
|
||||
if "conflict" in summary or "already_exists" in summary:
|
||||
raise FileExistsError(rel)
|
||||
if "is_folder" in summary:
|
||||
raise IsADirectoryError(rel)
|
||||
if "not_folder" in summary:
|
||||
raise NotADirectoryError(rel)
|
||||
raise HTTPException(502, detail=f"Dropbox API error: {summary or resp.text}")
|
||||
|
||||
def _format_entry(self, entry: Dict) -> Dict:
|
||||
tag = entry.get(".tag")
|
||||
is_dir = tag == "folder"
|
||||
mtime = _parse_iso_to_epoch(entry.get("server_modified") if not is_dir else None)
|
||||
return {
|
||||
"name": entry.get("name") or "",
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(entry.get("size") or 0),
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
path = _join_dbx_path(root, rel)
|
||||
body = {"path": path, "recursive": False, "include_deleted": False, "limit": 2000}
|
||||
resp = await self._api_json("/files/list_folder", body)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
summary = str((payload or {}).get("error_summary") or "")
|
||||
if "not_found" in summary:
|
||||
return [], 0
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
|
||||
all_entries: List[Dict] = []
|
||||
all_entries.extend(payload.get("entries") or [])
|
||||
cursor = payload.get("cursor")
|
||||
has_more = bool(payload.get("has_more"))
|
||||
while has_more and cursor:
|
||||
resp2 = await self._api_json("/files/list_folder/continue", {"cursor": cursor})
|
||||
resp2.raise_for_status()
|
||||
p2 = resp2.json()
|
||||
all_entries.extend(p2.get("entries") or [])
|
||||
cursor = p2.get("cursor")
|
||||
has_more = bool(p2.get("has_more"))
|
||||
|
||||
items = [self._format_entry(e) for e in all_entries if isinstance(e, dict)]
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif f == "size":
|
||||
key += (item["size"],)
|
||||
elif f == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
items.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total = len(items)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return items[start:end], total
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/get_metadata", {"path": path, "include_deleted": False})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
meta = resp.json()
|
||||
if not isinstance(meta, dict):
|
||||
raise HTTPException(502, detail="Dropbox metadata: invalid response")
|
||||
return self._format_entry(meta)
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
await self.stat_file(root, rel)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._content_request("/files/download", {"path": path})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_dbx_path(root, rel)
|
||||
arg = {
|
||||
"path": path,
|
||||
"mode": "overwrite",
|
||||
"autorename": False,
|
||||
"mute": False,
|
||||
"strict_conflict": False,
|
||||
}
|
||||
resp = await self._content_request(
|
||||
"/files/upload",
|
||||
arg,
|
||||
content=data,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
path = _join_dbx_path(root, rel)
|
||||
|
||||
size = 0
|
||||
session_id: str | None = None
|
||||
offset = 0
|
||||
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
if session_id is None:
|
||||
resp = await self._content_request(
|
||||
"/files/upload_session_start",
|
||||
{"close": False},
|
||||
content=chunk,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
session_id = str(payload.get("session_id") or "")
|
||||
if not session_id:
|
||||
raise HTTPException(502, detail="Dropbox upload_session_start: missing session_id")
|
||||
offset += len(chunk)
|
||||
size += len(chunk)
|
||||
continue
|
||||
|
||||
arg = {"cursor": {"session_id": session_id, "offset": offset}, "close": False}
|
||||
resp = await self._content_request(
|
||||
"/files/upload_session_append_v2",
|
||||
arg,
|
||||
content=chunk,
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
offset += len(chunk)
|
||||
size += len(chunk)
|
||||
|
||||
if session_id is None:
|
||||
await self.write_file(root, rel, b"")
|
||||
return 0
|
||||
|
||||
finish_arg = {
|
||||
"cursor": {"session_id": session_id, "offset": offset},
|
||||
"commit": {
|
||||
"path": path,
|
||||
"mode": "overwrite",
|
||||
"autorename": False,
|
||||
"mute": False,
|
||||
"strict_conflict": False,
|
||||
},
|
||||
}
|
||||
resp = await self._content_request(
|
||||
"/files/upload_session_finish",
|
||||
finish_arg,
|
||||
content=b"",
|
||||
extra_headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/create_folder_v2", {"path": path, "autorename": False})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/delete_v2", {"path": path})
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
summary = str((payload or {}).get("error_summary") or "")
|
||||
if "not_found" in summary:
|
||||
return
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_dbx_path(root, src_rel)
|
||||
dst = _join_dbx_path(root, dst_rel)
|
||||
resp = await self._api_json(
|
||||
"/files/move_v2",
|
||||
{"from_path": src, "to_path": dst, "autorename": False, "allow_shared_folder": True},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=src_rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
return await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _join_dbx_path(root, src_rel)
|
||||
dst = _join_dbx_path(root, dst_rel)
|
||||
resp = await self._api_json(
|
||||
"/files/copy_v2",
|
||||
{"from_path": src, "to_path": dst, "autorename": False, "allow_shared_folder": True},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=dst_rel if overwrite else dst_rel)
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
|
||||
path = _join_dbx_path(root, rel)
|
||||
resp = await self._api_json("/files/get_temporary_link", {"path": path})
|
||||
if resp.status_code == 409:
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
link = (payload.get("link") if isinstance(payload, dict) else None) or ""
|
||||
link = str(link).strip()
|
||||
if not link:
|
||||
return None
|
||||
return Response(status_code=307, headers={"Location": link})
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
path = _join_dbx_path(root, rel)
|
||||
token = await self._get_access_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Dropbox-API-Arg": json.dumps({"path": path}, separators=(",", ":"), ensure_ascii=False),
|
||||
}
|
||||
if range_header:
|
||||
headers["Range"] = range_header
|
||||
|
||||
client = httpx.AsyncClient(timeout=None)
|
||||
stream_cm = client.stream("POST", f"{DROPBOX_CONTENT_URL}/files/download", headers=headers)
|
||||
try:
|
||||
resp = await stream_cm.__aenter__()
|
||||
except Exception:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
content = await resp.aread()
|
||||
_ = content
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
self._raise_dbx_error(resp, rel=rel)
|
||||
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
resp.raise_for_status()
|
||||
|
||||
content_type = resp.headers.get("Content-Type") or (mimetypes.guess_type(rel)[0] or "application/octet-stream")
|
||||
out_headers = {}
|
||||
for key in ("Accept-Ranges", "Content-Range", "Content-Length"):
|
||||
value = resp.headers.get(key)
|
||||
if value:
|
||||
out_headers[key] = value
|
||||
|
||||
async def iterator():
|
||||
try:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(iterator(), status_code=resp.status_code, headers=out_headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "dropbox"
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "app_key", "label": "App Key", "type": "string", "required": True},
|
||||
{"key": "app_secret", "label": "App Secret", "type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password", "required": True},
|
||||
{"key": "root", "label": "Root Path", "type": "string", "required": False, "default": "/", "placeholder": "/ or /Apps/Foxel"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 60},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec): return DropboxAdapter(rec)
|
||||
|
||||
435
domain/adapters/providers/foxel.py
Normal file
435
domain/adapters/providers/foxel.py
Normal file
@@ -0,0 +1,435 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Dict, List, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _normalize_fs_path(path: str) -> str:
|
||||
path = (path or "").replace("\\", "/").strip()
|
||||
if not path or path == "/":
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
path = re.sub(r"/{2,}", "/", path)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _join_fs_path(base: str, rel: str | None) -> str:
|
||||
base = _normalize_fs_path(base)
|
||||
rel_norm = (rel or "").replace("\\", "/").strip().lstrip("/")
|
||||
if not rel_norm:
|
||||
return base
|
||||
if base == "/":
|
||||
return "/" + rel_norm
|
||||
return f"{base}/{rel_norm}"
|
||||
|
||||
|
||||
def _unwrap_success(payload: Any, *, context: str) -> Any:
|
||||
if not isinstance(payload, dict):
|
||||
return payload
|
||||
if "data" not in payload:
|
||||
return payload
|
||||
code = payload.get("code")
|
||||
if code not in (None, 0, 200):
|
||||
msg = payload.get("msg") or payload.get("message") or ""
|
||||
raise HTTPException(502, detail=f"Foxel 上游错误({context}): {msg}")
|
||||
return payload.get("data")
|
||||
|
||||
|
||||
class FoxelAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config or {}
|
||||
|
||||
self.base_url: str = str(cfg.get("base_url", "")).rstrip("/")
|
||||
if not self.base_url.startswith("http"):
|
||||
raise ValueError("foxel requires base_url http/https")
|
||||
|
||||
self.username: str = str(cfg.get("username") or "")
|
||||
self.password: str = str(cfg.get("password") or "")
|
||||
if not self.username or not self.password:
|
||||
raise ValueError("foxel requires username and password")
|
||||
|
||||
self.timeout: float = float(cfg.get("timeout", 15))
|
||||
self.root_path: str = _normalize_fs_path(str(cfg.get("root") or "/"))
|
||||
|
||||
self._token: str | None = None
|
||||
self._login_lock = asyncio.Lock()
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = _normalize_fs_path(self.root_path)
|
||||
if sub_path:
|
||||
return _join_fs_path(base, sub_path)
|
||||
return base
|
||||
|
||||
async def _login(self) -> str:
|
||||
url = self.base_url + "/api/auth/login"
|
||||
body = {"username": self.username, "password": self.password}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, data=body)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(502, detail="Foxel 登录响应异常")
|
||||
token = payload.get("access_token")
|
||||
if not token:
|
||||
raise HTTPException(502, detail="Foxel 登录失败: 缺少 access_token")
|
||||
return str(token)
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
if self._token:
|
||||
return self._token
|
||||
async with self._login_lock:
|
||||
if self._token:
|
||||
return self._token
|
||||
self._token = await self._login()
|
||||
return self._token
|
||||
|
||||
async def _request_json(self, method: str, path: str, *, params: dict | None = None, json: Any = None) -> Any:
|
||||
url = self.base_url + path
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, headers=headers, params=params, json=json)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
raise HTTPException(502, detail="Foxel 上游请求失败")
|
||||
|
||||
@staticmethod
|
||||
def _encode_path(full_path: str) -> str:
|
||||
return quote(full_path.lstrip("/"), safe="/")
|
||||
|
||||
def _browse_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/"
|
||||
return "/api/fs/" + self._encode_path(full_path)
|
||||
|
||||
def _stat_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/stat/"
|
||||
return "/api/fs/stat/" + self._encode_path(full_path)
|
||||
|
||||
def _file_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/file/"
|
||||
return "/api/fs/file/" + self._encode_path(full_path)
|
||||
|
||||
def _stream_path(self, full_path: str) -> str:
|
||||
full_path = _normalize_fs_path(full_path)
|
||||
if full_path == "/":
|
||||
return "/api/fs/stream/"
|
||||
return "/api/fs/stream/" + self._encode_path(full_path)
|
||||
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
payload = await self._request_json(
|
||||
"GET",
|
||||
self._browse_path(full_path),
|
||||
params={
|
||||
"page": page_num,
|
||||
"page_size": page_size,
|
||||
"sort_by": sort_by,
|
||||
"sort_order": sort_order,
|
||||
},
|
||||
)
|
||||
data = _unwrap_success(payload, context="list_dir")
|
||||
if not isinstance(data, dict):
|
||||
raise HTTPException(502, detail="Foxel 浏览响应异常")
|
||||
entries = data.get("entries") or []
|
||||
pagination = data.get("pagination") or {}
|
||||
total = pagination.get("total")
|
||||
try:
|
||||
total_int = int(total) if total is not None else len(entries)
|
||||
except Exception:
|
||||
total_int = len(entries)
|
||||
if not isinstance(entries, list):
|
||||
entries = []
|
||||
return entries, total_int
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
payload = await self._request_json("GET", self._stat_path(full_path))
|
||||
data = _unwrap_success(payload, context="stat_file")
|
||||
if not isinstance(data, dict):
|
||||
raise HTTPException(502, detail="Foxel stat 响应异常")
|
||||
return data
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._stat_path(full_path)
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
return resp.status_code == 200
|
||||
return False
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
raise HTTPException(502, detail="Foxel 读取失败")
|
||||
|
||||
async def _upload_file_path(self, full_path: str, file_path: Path) -> None:
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
filename = Path(full_path).name or file_path.name
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
with file_path.open("rb") as f:
|
||||
files = {"file": (filename, f, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return
|
||||
raise HTTPException(502, detail="Foxel 上传失败")
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
filename = Path(rel).name or "file"
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
files = {"file": (filename, data, "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
raise HTTPException(502, detail="Foxel 写入失败")
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._file_path(full_path)
|
||||
name = filename or Path(rel).name or "file"
|
||||
mime = content_type or "application/octet-stream"
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
files = {"file": (name, file_obj, mime)}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, headers=headers, files=files)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return {"size": file_size or 0}
|
||||
raise HTTPException(502, detail="Foxel 上传失败")
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
suffix = Path(rel).suffix
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
tmp_path = Path(tf.name)
|
||||
|
||||
size = 0
|
||||
try:
|
||||
with tmp_path.open("wb") as f:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
size += len(chunk)
|
||||
await self._upload_file_path(full_path, tmp_path)
|
||||
return size
|
||||
finally:
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
payload = await self._request_json("POST", "/api/fs/mkdir", json={"path": full_path})
|
||||
_unwrap_success(payload, context="mkdir")
|
||||
return True
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
rel = (rel or "").strip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._browse_path(full_path)
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient(timeout=self.timeout, follow_redirects=True) as client:
|
||||
resp = await client.delete(url, headers=headers)
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
self._token = None
|
||||
continue
|
||||
if resp.status_code == 404:
|
||||
return
|
||||
resp.raise_for_status()
|
||||
return
|
||||
raise HTTPException(502, detail="Foxel 删除失败")
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
|
||||
dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
|
||||
payload = await self._request_json("POST", "/api/fs/move", json={"src": src_path, "dst": dst_path})
|
||||
_unwrap_success(payload, context="move")
|
||||
return True
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
|
||||
dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
|
||||
payload = await self._request_json("POST", "/api/fs/rename", json={"src": src_path, "dst": dst_path})
|
||||
_unwrap_success(payload, context="rename")
|
||||
return True
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src_path = _join_fs_path(root, (src_rel or "").lstrip("/"))
|
||||
dst_path = _join_fs_path(root, (dst_rel or "").lstrip("/"))
|
||||
payload = await self._request_json(
|
||||
"POST",
|
||||
"/api/fs/copy",
|
||||
json={"src": src_path, "dst": dst_path},
|
||||
params={"overwrite": overwrite},
|
||||
)
|
||||
_unwrap_success(payload, context="copy")
|
||||
return True
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
rel = (rel or "").lstrip("/")
|
||||
full_path = _join_fs_path(root, rel)
|
||||
url = self.base_url + self._stream_path(full_path)
|
||||
|
||||
headers = {}
|
||||
if range_header:
|
||||
headers["Range"] = range_header
|
||||
|
||||
for attempt in range(2):
|
||||
token = await self._ensure_token()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
client = httpx.AsyncClient(timeout=None, follow_redirects=True)
|
||||
stream_cm = client.stream("GET", url, headers=headers)
|
||||
try:
|
||||
resp = await stream_cm.__aenter__()
|
||||
except Exception:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
if resp.status_code == 401 and attempt == 0:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
self._token = None
|
||||
continue
|
||||
|
||||
if resp.status_code == 404:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
await resp.aread()
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
resp.raise_for_status()
|
||||
|
||||
content_type = resp.headers.get("Content-Type") or (
|
||||
mimetypes.guess_type(rel)[0] or "application/octet-stream"
|
||||
)
|
||||
out_headers = {}
|
||||
for key in ("Accept-Ranges", "Content-Range", "Content-Length"):
|
||||
value = resp.headers.get(key)
|
||||
if value:
|
||||
out_headers[key] = value
|
||||
|
||||
async def iterator():
|
||||
try:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
finally:
|
||||
await stream_cm.__aexit__(None, None, None)
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
iterator(),
|
||||
status_code=resp.status_code,
|
||||
headers=out_headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
|
||||
raise HTTPException(502, detail="Foxel 流式读取失败")
|
||||
|
||||
|
||||
ADAPTER_TYPE = "foxel"
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "base_url", "label": "节点地址", "type": "string", "required": True, "placeholder": "http://127.0.0.1:8000"},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": True},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": True},
|
||||
{"key": "root", "label": "远端根目录", "type": "string", "required": False, "default": "/", "placeholder": "/ 或 /drive"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 60},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter):
|
||||
return FoxelAdapter(rec)
|
||||
645
domain/adapters/providers/ftp.py
Normal file
645
domain/adapters/providers/ftp.py
Normal file
@@ -0,0 +1,645 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Tuple, AsyncIterator, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from ftplib import FTP, error_perm
|
||||
import mimetypes
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _join_remote(root: str, rel: str) -> str:
|
||||
root = (root or "/").rstrip("/") or "/"
|
||||
rel = (rel or "").lstrip("/")
|
||||
if not rel:
|
||||
return root
|
||||
return f"{root}/{rel}"
|
||||
|
||||
|
||||
def _parse_mlst_line(line: str) -> Dict[str, str]:
|
||||
out: Dict[str, str] = {}
|
||||
try:
|
||||
facts, _, name = line.partition(" ")
|
||||
for part in facts.split(";"):
|
||||
if not part or "=" not in part:
|
||||
continue
|
||||
k, v = part.split("=", 1)
|
||||
out[k.strip().lower()] = v.strip()
|
||||
if name:
|
||||
out["name"] = name.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
|
||||
def _parse_modify_to_epoch(mod: str) -> int:
|
||||
# Formats we may see: YYYYMMDDHHMMSS or YYYYMMDDHHMMSS(.sss)
|
||||
try:
|
||||
mod = mod.strip()
|
||||
mod = mod.split(".")[0]
|
||||
if len(mod) >= 14:
|
||||
y = int(mod[0:4])
|
||||
m = int(mod[4:6])
|
||||
d = int(mod[6:8])
|
||||
hh = int(mod[8:10])
|
||||
mm = int(mod[10:12])
|
||||
ss = int(mod[12:14])
|
||||
import datetime as _dt
|
||||
return int(_dt.datetime(y, m, d, hh, mm, ss, tzinfo=_dt.timezone.utc).timestamp())
|
||||
except Exception:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Range:
|
||||
start: int
|
||||
end: Optional[int] # inclusive
|
||||
|
||||
|
||||
class FTPAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.host: str = cfg.get("host")
|
||||
self.port: int = int(cfg.get("port", 21))
|
||||
self.username: Optional[str] = cfg.get("username")
|
||||
self.password: Optional[str] = cfg.get("password")
|
||||
self.passive: bool = bool(cfg.get("passive", True))
|
||||
self.timeout: int = int(cfg.get("timeout", 15))
|
||||
self.root_path: str = cfg.get("root", "/") or "/"
|
||||
|
||||
if not self.host:
|
||||
raise ValueError("FTP adapter requires 'host'")
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = self.root_path.rstrip("/") or "/"
|
||||
if sub_path:
|
||||
return _join_remote(base, sub_path)
|
||||
return base
|
||||
|
||||
def _connect(self) -> FTP:
|
||||
ftp = FTP()
|
||||
ftp.connect(self.host, self.port, timeout=self.timeout)
|
||||
if self.username:
|
||||
ftp.login(self.username, self.password or "")
|
||||
else:
|
||||
ftp.login()
|
||||
ftp.set_pasv(self.passive)
|
||||
return ftp
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
path = _join_remote(root, rel.strip('/'))
|
||||
|
||||
def _do_list() -> List[Dict]:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
ftp.cwd(path)
|
||||
except error_perm as e:
|
||||
# path may be file
|
||||
ftp.quit()
|
||||
raise NotADirectoryError(rel) from e
|
||||
|
||||
entries: List[Dict] = []
|
||||
# Try MLSD first
|
||||
try:
|
||||
for name, facts in ftp.mlsd():
|
||||
if name in (".", ".."):
|
||||
continue
|
||||
is_dir = (facts.get("type") == "dir")
|
||||
size = int(facts.get("size") or 0)
|
||||
mtime = _parse_modify_to_epoch(facts.get("modify") or "")
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
ftp.quit()
|
||||
return entries
|
||||
except Exception:
|
||||
# Fallback to NLST + probing
|
||||
pass
|
||||
|
||||
names = []
|
||||
try:
|
||||
names = ftp.nlst()
|
||||
except Exception:
|
||||
ftp.quit()
|
||||
return []
|
||||
|
||||
for name in names:
|
||||
if name in (".", ".."):
|
||||
continue
|
||||
is_dir = False
|
||||
size = 0
|
||||
mtime = 0
|
||||
try:
|
||||
# If we can CWD, it's a directory
|
||||
ftp.cwd(_join_remote(path, name))
|
||||
ftp.cwd(path)
|
||||
is_dir = True
|
||||
except Exception:
|
||||
is_dir = False
|
||||
try:
|
||||
size = ftp.size(_join_remote(path, name)) or 0
|
||||
except Exception:
|
||||
size = 0
|
||||
try:
|
||||
mdtm = ftp.sendcmd("MDTM " + _join_remote(path, name))
|
||||
# Example: '213 20241012XXXXXX'
|
||||
if mdtm.startswith("213 "):
|
||||
mtime = _parse_modify_to_epoch(mdtm.split(" ", 1)[1])
|
||||
except Exception:
|
||||
pass
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(size or 0),
|
||||
"mtime": int(mtime or 0),
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
ftp.quit()
|
||||
return entries
|
||||
|
||||
entries = await asyncio.to_thread(_do_list)
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif f == "size":
|
||||
key += (item.get("size", 0),)
|
||||
elif f == "mtime":
|
||||
key += (item.get("mtime", 0),)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(entries)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return entries[start:end], total
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_read() -> bytes:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
chunks: List[bytes] = []
|
||||
ftp.retrbinary("RETR " + path, lambda b: chunks.append(b))
|
||||
return b"".join(chunks)
|
||||
except error_perm as e:
|
||||
if str(e).startswith("550"):
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_read)
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(ftp: FTP, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _do_write():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(ftp, parent)
|
||||
from io import BytesIO
|
||||
bio = BytesIO(data)
|
||||
ftp.storbinary("STOR " + path, bio)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(ftp: FTP, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _do_upload():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(ftp, parent)
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
ftp.storbinary("STOR " + path, file_obj)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_upload)
|
||||
return {"size": file_size or 0}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
# KISS: 聚合后一次性写入
|
||||
buf = bytearray()
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
buf.extend(chunk)
|
||||
await self.write_file(root, rel, bytes(buf))
|
||||
return len(buf)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_mkdir():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_mkdir)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_delete():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
# Try file delete
|
||||
try:
|
||||
ftp.delete(path)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Recursively delete dir
|
||||
def _rm_tree(dir_path: str):
|
||||
try:
|
||||
ftp.cwd(dir_path)
|
||||
except Exception:
|
||||
return
|
||||
items = []
|
||||
try:
|
||||
for name, facts in ftp.mlsd():
|
||||
if name in (".", ".."):
|
||||
continue
|
||||
items.append((name, facts.get("type") == "dir"))
|
||||
except Exception:
|
||||
try:
|
||||
names = ftp.nlst()
|
||||
except Exception:
|
||||
names = []
|
||||
for n in names:
|
||||
if n in (".", ".."):
|
||||
continue
|
||||
# Best-effort dir check
|
||||
try:
|
||||
ftp.cwd(_join_remote(dir_path, n))
|
||||
ftp.cwd(dir_path)
|
||||
items.append((n, True))
|
||||
except Exception:
|
||||
items.append((n, False))
|
||||
for n, is_dir in items:
|
||||
child = _join_remote(dir_path, n)
|
||||
if is_dir:
|
||||
_rm_tree(child)
|
||||
else:
|
||||
try:
|
||||
ftp.delete(child)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
ftp.rmd(dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_rm_tree(path)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_delete)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
def _do_move():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
# Ensure dst parent exists
|
||||
parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
|
||||
parts = [p for p in parent.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
ftp.mkd(cur)
|
||||
except Exception:
|
||||
pass
|
||||
ftp.rename(src, dst)
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_move)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
# naive implementation: download then upload; recursively for dirs
|
||||
async def _is_dir(path: str) -> bool:
|
||||
def _probe() -> bool:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
try:
|
||||
ftp.cwd(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
return await asyncio.to_thread(_probe)
|
||||
|
||||
if await _is_dir(src):
|
||||
# list children, create dst dir, copy recursively
|
||||
await self.mkdir(root, dst_rel)
|
||||
|
||||
children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
|
||||
for ent in children:
|
||||
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
|
||||
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
|
||||
await self.copy(root, child_src, child_dst, overwrite)
|
||||
return
|
||||
|
||||
# file
|
||||
data = await self.read_file(root, src_rel)
|
||||
if not overwrite:
|
||||
# best-effort existence check
|
||||
try:
|
||||
await self.stat_file(root, dst_rel)
|
||||
raise FileExistsError(dst_rel)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
await self.write_file(root, dst_rel, data)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_stat():
|
||||
ftp = self._connect()
|
||||
try:
|
||||
# Try MLST
|
||||
try:
|
||||
resp: List[str] = []
|
||||
ftp.retrlines("MLST " + path, resp.append)
|
||||
# The last line usually contains facts
|
||||
facts = {}
|
||||
if resp:
|
||||
facts = _parse_mlst_line(resp[-1])
|
||||
name = rel.split("/")[-1]
|
||||
t = facts.get("type") or "file"
|
||||
is_dir = t == "dir"
|
||||
size = int(facts.get("size") or 0)
|
||||
mtime = _parse_modify_to_epoch(facts.get("modify") or "")
|
||||
return {
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
"path": path,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Probe directory
|
||||
try:
|
||||
ftp.cwd(path)
|
||||
return {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": True,
|
||||
"size": 0,
|
||||
"mtime": 0,
|
||||
"type": "dir",
|
||||
"path": path,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Treat as file
|
||||
try:
|
||||
size = ftp.size(path) or 0
|
||||
except Exception:
|
||||
size = 0
|
||||
try:
|
||||
mdtm = ftp.sendcmd("MDTM " + path)
|
||||
mtime = _parse_modify_to_epoch(mdtm.split(" ", 1)[1]) if mdtm.startswith("213 ") else 0
|
||||
except Exception:
|
||||
mtime = 0
|
||||
return {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": False,
|
||||
"size": int(size or 0),
|
||||
"mtime": int(mtime or 0),
|
||||
"type": "file",
|
||||
"path": path,
|
||||
}
|
||||
except error_perm as e:
|
||||
if str(e).startswith("550"):
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_stat)
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
path = _join_remote(root, rel)
|
||||
# Get size (best-effort)
|
||||
def _get_size() -> Optional[int]:
|
||||
ftp = self._connect()
|
||||
try:
|
||||
try:
|
||||
return int(ftp.size(path) or 0)
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_size = await asyncio.to_thread(_get_size)
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
rng: Optional[_Range] = None
|
||||
status = 200
|
||||
headers = {"Accept-Ranges": "bytes", "Content-Type": content_type}
|
||||
if range_header and range_header.startswith("bytes=") and total_size is not None:
|
||||
try:
|
||||
s, e = (range_header.removeprefix("bytes=").split("-", 1))
|
||||
start = int(s) if s.strip() else 0
|
||||
end = int(e) if e.strip() else (total_size - 1)
|
||||
if start >= total_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if end >= total_size:
|
||||
end = total_size - 1
|
||||
rng = _Range(start, end)
|
||||
status = 206
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{total_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
elif total_size is not None:
|
||||
headers["Content-Length"] = str(total_size)
|
||||
|
||||
queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)
|
||||
|
||||
class _Stop(Exception):
|
||||
pass
|
||||
|
||||
def _worker():
|
||||
ftp = self._connect()
|
||||
remaining = None
|
||||
if rng is not None:
|
||||
remaining = (rng.end - rng.start + 1) if rng.end is not None else None
|
||||
|
||||
def _cb(chunk: bytes):
|
||||
nonlocal remaining
|
||||
if not chunk:
|
||||
return
|
||||
try:
|
||||
if remaining is not None:
|
||||
if len(chunk) > remaining:
|
||||
part = chunk[:remaining]
|
||||
queue.put_nowait(part)
|
||||
remaining = 0
|
||||
raise _Stop()
|
||||
else:
|
||||
queue.put_nowait(chunk)
|
||||
remaining -= len(chunk)
|
||||
if remaining <= 0:
|
||||
raise _Stop()
|
||||
else:
|
||||
queue.put_nowait(chunk)
|
||||
except _Stop:
|
||||
raise
|
||||
except Exception:
|
||||
# queue full or event loop closed
|
||||
raise _Stop()
|
||||
|
||||
try:
|
||||
if rng is not None:
|
||||
ftp.retrbinary("RETR " + path, _cb, rest=rng.start)
|
||||
else:
|
||||
ftp.retrbinary("RETR " + path, _cb)
|
||||
queue.put_nowait(None)
|
||||
except _Stop:
|
||||
try:
|
||||
queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
except error_perm as e:
|
||||
try:
|
||||
queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
if str(e).startswith("550"):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
ftp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def agen():
|
||||
worker_fut = asyncio.to_thread(_worker)
|
||||
try:
|
||||
while True:
|
||||
chunk = await queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
try:
|
||||
await worker_fut
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "ftp"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "host", "label": "主机", "type": "string", "required": True, "placeholder": "ftp.example.com"},
|
||||
{"key": "port", "label": "端口", "type": "number", "required": False, "default": 21},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": False},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": False},
|
||||
{"key": "passive", "label": "被动模式", "type": "boolean", "required": False, "default": True},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 15},
|
||||
{"key": "root", "label": "根路径", "type": "string", "required": False, "default": "/"},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter):
|
||||
return FTPAdapter(rec)
|
||||
559
domain/adapters/providers/googledrive.py
Normal file
559
domain/adapters/providers/googledrive.py
Normal file
@@ -0,0 +1,559 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import httpx
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from fastapi import HTTPException
|
||||
from models import StorageAdapter
|
||||
|
||||
GOOGLE_OAUTH_URL = "https://oauth2.googleapis.com/token"
|
||||
GOOGLE_DRIVE_API_URL = "https://www.googleapis.com/drive/v3"
|
||||
|
||||
|
||||
class GoogleDriveAdapter:
|
||||
"""Google Drive 存储适配器"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.client_id = cfg.get("client_id")
|
||||
self.client_secret = cfg.get("client_secret")
|
||||
self.refresh_token = cfg.get("refresh_token")
|
||||
self.root_folder_id = cfg.get("root_folder_id", "root")
|
||||
self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))
|
||||
|
||||
if not all([self.client_id, self.client_secret, self.refresh_token]):
|
||||
raise ValueError(
|
||||
"Google Drive 适配器需要 client_id, client_secret, 和 refresh_token")
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: datetime | None = None
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
"""
|
||||
获取有效根路径。
|
||||
:param sub_path: 子路径。
|
||||
:return: 完整的有效路径。
|
||||
"""
|
||||
if sub_path:
|
||||
return f"{sub_path.strip('/')}".strip()
|
||||
return ""
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
"""
|
||||
获取或刷新 access token。
|
||||
:return: access token。
|
||||
"""
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"refresh_token": self.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(GOOGLE_OAUTH_URL, data=data)
|
||||
resp.raise_for_status()
|
||||
token_data = resp.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._token_expiry = datetime.now(
|
||||
timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300)
|
||||
return self._access_token
|
||||
|
||||
async def _request(self, method: str, endpoint: str, **kwargs):
|
||||
"""
|
||||
向 Google Drive API 发送请求。
|
||||
:param method: HTTP 方法。
|
||||
:param endpoint: API 端点。
|
||||
:param kwargs: 其他请求参数。
|
||||
:return: 响应对象。
|
||||
"""
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
if "headers" in kwargs:
|
||||
headers.update(kwargs.pop("headers"))
|
||||
|
||||
url = f"{GOOGLE_DRIVE_API_URL}{endpoint}"
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.request(method, url, headers=headers, **kwargs)
|
||||
if resp.status_code == 401:
|
||||
self._access_token = None
|
||||
token = await self._get_access_token()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
resp = await client.request(method, url, headers=headers, **kwargs)
|
||||
return resp
|
||||
|
||||
async def _get_folder_id_by_path(self, path: str) -> str:
|
||||
"""
|
||||
通过路径获取文件夹 ID。
|
||||
:param path: 路径。
|
||||
:return: 文件夹 ID。
|
||||
"""
|
||||
if not path or path == "/":
|
||||
return self.root_folder_id
|
||||
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
current_id = self.root_folder_id
|
||||
|
||||
for part in parts:
|
||||
query = f"name='{part}' and '{current_id}' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
|
||||
params = {"q": query, "fields": "files(id, name)"}
|
||||
resp = await self._request("GET", "/files", params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
files = data.get("files", [])
|
||||
if not files:
|
||||
raise FileNotFoundError(f"文件夹不存在: {part}")
|
||||
current_id = files[0]["id"]
|
||||
|
||||
return current_id
|
||||
|
||||
async def _get_file_id_by_path(self, path: str) -> str | None:
|
||||
"""
|
||||
通过路径获取文件 ID。
|
||||
:param path: 文件路径。
|
||||
:return: 文件 ID 或 None。
|
||||
"""
|
||||
if not path or path == "/":
|
||||
return self.root_folder_id
|
||||
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
parent_id = self.root_folder_id
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
is_last = i == len(parts) - 1
|
||||
mime_filter = "" if is_last else "and mimeType='application/vnd.google-apps.folder'"
|
||||
query = f"name='{part}' and '{parent_id}' in parents {mime_filter} and trashed=false"
|
||||
params = {"q": query, "fields": "files(id, name)"}
|
||||
resp = await self._request("GET", "/files", params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
files = data.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
parent_id = files[0]["id"]
|
||||
|
||||
return parent_id
|
||||
|
||||
def _format_item(self, item: Dict) -> Dict:
|
||||
"""
|
||||
将 Google Drive API 返回的 item 格式化为统一的格式。
|
||||
:param item: Google Drive API 返回的 item 字典。
|
||||
:return: 格式化后的字典。
|
||||
"""
|
||||
is_dir = item["mimeType"] == "application/vnd.google-apps.folder"
|
||||
mtime_str = item.get("modifiedTime", item.get("createdTime", ""))
|
||||
try:
|
||||
mtime = int(datetime.fromisoformat(mtime_str.replace("Z", "+00:00")).timestamp())
|
||||
except:
|
||||
mtime = 0
|
||||
|
||||
return {
|
||||
"name": item["name"],
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(item.get("size", 0)),
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
"""
|
||||
列出目录内容。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param page_num: 页码。
|
||||
:param page_size: 每页大小。
|
||||
:param sort_by: 排序字段
|
||||
:param sort_order: 排序顺序
|
||||
:return: 文件/目录列表和总数。
|
||||
"""
|
||||
try:
|
||||
folder_id = await self._get_folder_id_by_path(rel)
|
||||
except FileNotFoundError:
|
||||
return [], 0
|
||||
|
||||
query = f"'{folder_id}' in parents and trashed=false"
|
||||
params = {
|
||||
"q": query,
|
||||
"fields": "files(id, name, mimeType, size, modifiedTime, createdTime)",
|
||||
"pageSize": 1000,
|
||||
}
|
||||
|
||||
all_items = []
|
||||
page_token = None
|
||||
|
||||
while True:
|
||||
if page_token:
|
||||
params["pageToken"] = page_token
|
||||
|
||||
resp = await self._request("GET", "/files", params=params)
|
||||
if resp.status_code == 404:
|
||||
return [], 0
|
||||
resp.raise_for_status()
|
||||
|
||||
data = resp.json()
|
||||
all_items.extend(data.get("files", []))
|
||||
page_token = data.get("nextPageToken")
|
||||
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
formatted_items = [self._format_item(item) for item in all_items]
|
||||
|
||||
# 排序
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
formatted_items.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(formatted_items)
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
|
||||
return formatted_items[start_idx:end_idx], total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
"""
|
||||
读取文件内容。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 文件内容的字节流。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"alt": "media"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
"""
|
||||
写入文件。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param data: 文件内容的字节流。
|
||||
"""
|
||||
parent_path = "/".join(rel.strip("/").split("/")[:-1])
|
||||
file_name = rel.strip("/").split("/")[-1]
|
||||
parent_id = await self._get_folder_id_by_path(parent_path)
|
||||
|
||||
# 检查文件是否已存在
|
||||
existing_id = await self._get_file_id_by_path(rel)
|
||||
|
||||
if existing_id:
|
||||
# 更新现有文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
url = f"https://www.googleapis.com/upload/drive/v3/files/{existing_id}?uploadType=media"
|
||||
resp = await client.patch(url, headers=headers, content=data)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
# 创建新文件
|
||||
metadata = {
|
||||
"name": file_name,
|
||||
"parents": [parent_id]
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# 使用 multipart 上传
|
||||
import json
|
||||
boundary = "===============boundary==============="
|
||||
headers["Content-Type"] = f"multipart/related; boundary={boundary}"
|
||||
|
||||
body = (
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Type: application/json; charset=UTF-8\r\n\r\n"
|
||||
f"{json.dumps(metadata)}\r\n"
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
||||
).encode() + data + f"\r\n--{boundary}--".encode()
|
||||
|
||||
url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=multipart"
|
||||
resp = await client.post(url, headers=headers, content=body)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
"""
|
||||
以流式方式写入文件。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param data_iter: 文件内容的异步迭代器。
|
||||
:return: 文件大小。
|
||||
"""
|
||||
# 先收集所有数据
|
||||
chunks = []
|
||||
total_size = 0
|
||||
async for chunk in data_iter:
|
||||
chunks.append(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
data = b"".join(chunks)
|
||||
await self.write_file(root, rel, data)
|
||||
return total_size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
"""
|
||||
创建目录。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
"""
|
||||
parent_path = "/".join(rel.strip("/").split("/")[:-1])
|
||||
folder_name = rel.strip("/").split("/")[-1]
|
||||
parent_id = await self._get_folder_id_by_path(parent_path)
|
||||
|
||||
metadata = {
|
||||
"name": folder_name,
|
||||
"mimeType": "application/vnd.google-apps.folder",
|
||||
"parents": [parent_id]
|
||||
}
|
||||
|
||||
resp = await self._request("POST", "/files", json=metadata)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
"""
|
||||
删除文件或目录。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
return
|
||||
|
||||
resp = await self._request("DELETE", f"/files/{file_id}")
|
||||
if resp.status_code not in (204, 404):
|
||||
resp.raise_for_status()
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
"""
|
||||
移动或重命名文件/目录。
|
||||
:param root: 根路径。
|
||||
:param src_rel: 源相对路径。
|
||||
:param dst_rel: 目标相对路径。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(src_rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(src_rel)
|
||||
|
||||
# 获取当前父文件夹
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "parents"})
|
||||
resp.raise_for_status()
|
||||
current_parents = resp.json().get("parents", [])
|
||||
|
||||
# 获取目标父文件夹和新名称
|
||||
dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
|
||||
dst_name = dst_rel.strip("/").split("/")[-1]
|
||||
dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)
|
||||
|
||||
# 更新文件
|
||||
params = {
|
||||
"addParents": dst_parent_id,
|
||||
"removeParents": ",".join(current_parents) if current_parents else None,
|
||||
}
|
||||
metadata = {"name": dst_name}
|
||||
|
||||
resp = await self._request("PATCH", f"/files/{file_id}", params=params, json=metadata)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
"""
|
||||
重命名文件或目录。
|
||||
"""
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
"""
|
||||
复制文件或目录。
|
||||
:param root: 根路径。
|
||||
:param src_rel: 源相对路径。
|
||||
:param dst_rel: 目标相对路径。
|
||||
:param overwrite: 是否覆盖。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(src_rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(src_rel)
|
||||
|
||||
dst_parent_path = "/".join(dst_rel.strip("/").split("/")[:-1])
|
||||
dst_name = dst_rel.strip("/").split("/")[-1]
|
||||
dst_parent_id = await self._get_folder_id_by_path(dst_parent_path)
|
||||
|
||||
metadata = {
|
||||
"name": dst_name,
|
||||
"parents": [dst_parent_id]
|
||||
}
|
||||
|
||||
resp = await self._request("POST", f"/files/{file_id}/copy", json=metadata)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
"""
|
||||
流式传输文件(支持范围请求)。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param range_header: HTTP Range 头。
|
||||
:return: FastAPI StreamingResponse 对象。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
# 获取文件元数据
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "name, size, mimeType"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
item_data = resp.json()
|
||||
|
||||
file_size = int(item_data.get("size", 0))
|
||||
content_type = item_data.get("mimeType", "application/octet-stream")
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"Content-Disposition": f"inline; filename=\"{item_data.get('name')}\""
|
||||
}
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if start >= file_size:
|
||||
raise HTTPException(416, "Requested Range Not Satisfiable")
|
||||
if end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid Range header")
|
||||
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
else:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
async def file_iterator():
|
||||
nonlocal start, end
|
||||
token = await self._get_access_token()
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
req_headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Range': f'bytes={start}-{end}'
|
||||
}
|
||||
url = f"{GOOGLE_DRIVE_API_URL}/files/{file_id}?alt=media"
|
||||
async with client.stream("GET", url, headers=req_headers) as stream_resp:
|
||||
stream_resp.raise_for_status()
|
||||
async for chunk in stream_resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
"""
|
||||
获取文件或目录的元数据。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 格式化后的文件/目录信息。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "id, name, mimeType, size, modifiedTime, createdTime"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return self._format_item(resp.json())
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
"""
|
||||
获取直接下载响应 (307 重定向)。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 307 重定向响应或 None。
|
||||
"""
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
# 获取文件的下载链接
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "webContentLink"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
|
||||
item_data = resp.json()
|
||||
download_url = item_data.get("webContentLink")
|
||||
if not download_url:
|
||||
return None
|
||||
|
||||
return Response(status_code=307, headers={"Location": download_url})
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
"""
|
||||
获取文件的缩略图。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param size: 缩略图大小 (暂未使用,Google Drive 自动决定)。
|
||||
:return: 缩略图内容的字节流,或在不支持时返回 None。
|
||||
"""
|
||||
file_id = await self._get_file_id_by_path(rel)
|
||||
if not file_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self._request("GET", f"/files/{file_id}", params={"fields": "thumbnailLink"})
|
||||
if resp.status_code == 200:
|
||||
item_data = resp.json()
|
||||
thumbnail_link = item_data.get("thumbnailLink")
|
||||
if thumbnail_link:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
thumb_resp = await client.get(thumbnail_link)
|
||||
thumb_resp.raise_for_status()
|
||||
return thumb_resp.content
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
ADAPTER_TYPE = "googledrive"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
|
||||
{"key": "client_secret", "label": "Client Secret",
|
||||
"type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
|
||||
"required": True, "help_text": "可以通过 Google OAuth 2.0 Playground 获取"},
|
||||
{"key": "root_folder_id", "label": "根文件夹 ID (Root Folder ID)", "type": "string",
|
||||
"required": False, "placeholder": "默认为根目录 (root)", "default": "root"},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec): return GoogleDriveAdapter(rec)
|
||||
337
domain/adapters/providers/local.py
Normal file
337
domain/adapters/providers/local.py
Normal file
@@ -0,0 +1,337 @@
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _safe_join(root: str, rel: str) -> Path:
|
||||
root_path = Path(root).resolve()
|
||||
full = (root_path / rel).resolve()
|
||||
if not str(full).startswith(str(root_path)):
|
||||
raise ValueError("Path escape detected")
|
||||
return full
|
||||
|
||||
|
||||
DEFAULT_FILE_MODE = 0o666
|
||||
DEFAULT_DIR_MODE = 0o777
|
||||
|
||||
|
||||
def _apply_mode(path: Path, mode: int):
|
||||
try:
|
||||
os.chmod(path, mode)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class LocalAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
self.root = self.record.config.get("root")
|
||||
if not self.root:
|
||||
raise ValueError("Local adapter config requires 'root'")
|
||||
Path(self.root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
root = self.record.config.get("root")
|
||||
if sub_path:
|
||||
return str(Path(root) / sub_path)
|
||||
return root
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
rel = rel.strip('/')
|
||||
base = _safe_join(root, rel) if rel else Path(root)
|
||||
if not base.exists():
|
||||
return [], 0
|
||||
if not base.is_dir():
|
||||
raise NotADirectoryError(rel)
|
||||
|
||||
all_names = await asyncio.to_thread(os.listdir, base)
|
||||
|
||||
entries = []
|
||||
for name in all_names:
|
||||
fp = base / name
|
||||
try:
|
||||
st = await asyncio.to_thread(fp.stat)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
is_dir = fp.is_dir()
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else st.st_size,
|
||||
"mtime": int(st.st_mtime),
|
||||
"mode": stat.S_IMODE(st.st_mode),
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
|
||||
# 排序
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
# 基础排序键,目录优先
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else: # 默认按名称
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(entries)
|
||||
|
||||
# 分页
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
page_entries = entries[start_idx:end_idx]
|
||||
|
||||
return page_entries, total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
fp = _safe_join(root, rel)
|
||||
if not fp.exists() or not fp.is_file():
|
||||
raise FileNotFoundError(rel)
|
||||
return await asyncio.to_thread(fp.read_bytes)
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
fp = _safe_join(root, rel)
|
||||
pre_exists = fp.exists()
|
||||
await asyncio.to_thread(os.makedirs, fp.parent, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
await asyncio.to_thread(fp.write_bytes, data)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
fp = _safe_join(root, rel)
|
||||
pre_exists = fp.exists()
|
||||
await asyncio.to_thread(os.makedirs, fp.parent, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
|
||||
def _copy():
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with open(fp, "wb") as f:
|
||||
shutil.copyfileobj(file_obj, f)
|
||||
|
||||
await asyncio.to_thread(_copy)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
|
||||
size = file_size
|
||||
if size is None:
|
||||
try:
|
||||
size = fp.stat().st_size
|
||||
except Exception:
|
||||
size = 0
|
||||
return {"size": int(size or 0)}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
fp = _safe_join(root, rel)
|
||||
pre_exists = fp.exists()
|
||||
await asyncio.to_thread(os.makedirs, fp.parent, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
# 流式写入,避免一次性读入内存
|
||||
def _open():
|
||||
return open(fp, "wb")
|
||||
f = await asyncio.to_thread(_open)
|
||||
size = 0
|
||||
try:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
size += len(chunk)
|
||||
await asyncio.to_thread(f.write, chunk)
|
||||
finally:
|
||||
await asyncio.to_thread(f.close)
|
||||
if not pre_exists:
|
||||
await asyncio.to_thread(_apply_mode, fp, DEFAULT_FILE_MODE)
|
||||
return size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
await asyncio.to_thread(os.makedirs, fp, mode=DEFAULT_DIR_MODE, exist_ok=True)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
if not fp.exists():
|
||||
return
|
||||
if fp.is_dir():
|
||||
await asyncio.to_thread(shutil.rmtree, fp)
|
||||
else:
|
||||
await asyncio.to_thread(fp.unlink)
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
"""新增: 返回路径状态调试信息"""
|
||||
fp = _safe_join(root, rel)
|
||||
def _stat():
|
||||
if not fp.exists():
|
||||
return {"exists": False, "is_dir": None, "path": str(fp)}
|
||||
return {
|
||||
"exists": True,
|
||||
"is_dir": fp.is_dir(),
|
||||
"path": str(fp)
|
||||
}
|
||||
return await asyncio.to_thread(_stat)
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
"""新增: 判断路径是否存在"""
|
||||
fp = _safe_join(root, rel)
|
||||
return await asyncio.to_thread(fp.exists)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _safe_join(root, src_rel)
|
||||
dst = _safe_join(root, dst_rel)
|
||||
if str(src) == str(dst):
|
||||
return
|
||||
if not src.exists():
|
||||
raise FileNotFoundError(src_rel)
|
||||
await asyncio.to_thread(dst.parent.mkdir, parents=True, exist_ok=True)
|
||||
|
||||
def _do_move():
|
||||
try:
|
||||
os.replace(src, dst)
|
||||
except OSError:
|
||||
shutil.move(str(src), str(dst))
|
||||
await asyncio.to_thread(_do_move)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _safe_join(root, src_rel)
|
||||
dst = _safe_join(root, dst_rel)
|
||||
if str(src) == str(dst):
|
||||
return
|
||||
if not src.exists():
|
||||
raise FileNotFoundError(src_rel)
|
||||
await asyncio.to_thread(dst.parent.mkdir, parents=True, exist_ok=True)
|
||||
def _do_rename():
|
||||
try:
|
||||
os.rename(src, dst)
|
||||
except OSError:
|
||||
os.replace(src, dst)
|
||||
await asyncio.to_thread(_do_rename)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _safe_join(root, src_rel)
|
||||
dst = _safe_join(root, dst_rel)
|
||||
if not src.exists():
|
||||
raise FileNotFoundError(src_rel)
|
||||
if str(src) == str(dst):
|
||||
return
|
||||
await asyncio.to_thread(dst.parent.mkdir, parents=True, exist_ok=True)
|
||||
def _do():
|
||||
if dst.exists():
|
||||
if not overwrite:
|
||||
raise FileExistsError(dst_rel)
|
||||
if dst.is_dir():
|
||||
shutil.rmtree(dst)
|
||||
else:
|
||||
dst.unlink()
|
||||
if src.is_dir():
|
||||
shutil.copytree(src, dst)
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
await asyncio.to_thread(_do)
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
fp = _safe_join(root, rel)
|
||||
if not fp.exists() or not fp.is_file():
|
||||
raise HTTPException(404, detail="File not found")
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
file_size = (await asyncio.to_thread(fp.stat)).st_size
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if start >= file_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
else:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
async def iterator():
|
||||
# 使用线程池避免阻塞
|
||||
def _read_segment(offset: int, length: int):
|
||||
with open(fp, "rb") as f:
|
||||
f.seek(offset)
|
||||
return f.read(length)
|
||||
chunk_size = 256 * 1024
|
||||
remaining = end - start + 1
|
||||
offset = start
|
||||
while remaining > 0:
|
||||
size = min(chunk_size, remaining)
|
||||
data = await asyncio.to_thread(_read_segment, offset, size)
|
||||
if not data:
|
||||
break
|
||||
yield data
|
||||
remaining -= len(data)
|
||||
offset += len(data)
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
fp = _safe_join(root, rel)
|
||||
if not fp.exists():
|
||||
raise FileNotFoundError(rel)
|
||||
st = await asyncio.to_thread(fp.stat)
|
||||
info = {
|
||||
"name": fp.name,
|
||||
"is_dir": fp.is_dir(),
|
||||
"size": st.st_size,
|
||||
"mtime": int(st.st_mtime),
|
||||
"mode": stat.S_IMODE(st.st_mode),
|
||||
"type": "dir" if fp.is_dir() else "file",
|
||||
"path": str(fp),
|
||||
}
|
||||
# exif信息
|
||||
exif = None
|
||||
if not fp.is_dir():
|
||||
mime, _ = mimetypes.guess_type(fp.name)
|
||||
if mime and mime.startswith("image/"):
|
||||
try:
|
||||
from PIL import Image
|
||||
img = await asyncio.to_thread(Image.open, fp)
|
||||
exif_data = img._getexif()
|
||||
if exif_data:
|
||||
exif = {str(k): str(v) for k, v in exif_data.items()}
|
||||
except Exception:
|
||||
exif = None
|
||||
info["exif"] = exif
|
||||
return info
|
||||
|
||||
|
||||
ADAPTER_TYPE = "local"
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "root", "label": "根目录", "type": "string", "required": True, "placeholder": "/data/storage"},
|
||||
]
|
||||
ADAPTER_FACTORY = lambda rec: LocalAdapter(rec)
|
||||
461
domain/adapters/providers/onedrive.py
Normal file
461
domain/adapters/providers/onedrive.py
Normal file
@@ -0,0 +1,461 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import httpx
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from fastapi import HTTPException
|
||||
from models import StorageAdapter
|
||||
|
||||
MS_GRAPH_URL = "https://graph.microsoft.com/v1.0"
|
||||
MS_OAUTH_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
|
||||
class OneDriveAdapter:
|
||||
"""OneDrive 存储适配器"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.client_id = cfg.get("client_id")
|
||||
self.client_secret = cfg.get("client_secret")
|
||||
self.refresh_token = cfg.get("refresh_token")
|
||||
self.root = cfg.get("root", "/").strip("/")
|
||||
self.enable_redirect_307 = bool(cfg.get("enable_direct_download_307"))
|
||||
|
||||
if not all([self.client_id, self.client_secret, self.refresh_token]):
|
||||
raise ValueError(
|
||||
"OneDrive 适配器需要 client_id, client_secret, 和 refresh_token")
|
||||
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: datetime | None = None
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
"""
|
||||
获取有效根路径。
|
||||
:param sub_path: 子路径。
|
||||
:return: 完整的有效路径。
|
||||
"""
|
||||
if sub_path:
|
||||
return f"/{self.root.strip('/')}/{sub_path.strip('/')}".strip()
|
||||
return f"/{self.root.strip('/')}".strip()
|
||||
|
||||
def _get_api_path(self, rel_path: str) -> str:
|
||||
"""
|
||||
将用户可见的相对路径转换为 Graph API 路径段。
|
||||
:param rel_path: 相对路径。
|
||||
:return: Graph API 路径段。
|
||||
"""
|
||||
full_path = self.get_effective_root(rel_path).strip('/')
|
||||
if not full_path:
|
||||
return ""
|
||||
return f":/{full_path}"
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
"""
|
||||
获取或刷新 access token。
|
||||
:return: access token。
|
||||
"""
|
||||
if self._access_token and self._token_expiry and datetime.now(timezone.utc) < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"refresh_token": self.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(MS_OAUTH_URL, data=data)
|
||||
resp.raise_for_status()
|
||||
token_data = resp.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._token_expiry = datetime.now(
|
||||
timezone.utc) + timedelta(seconds=token_data["expires_in"] - 300)
|
||||
return self._access_token
|
||||
|
||||
async def _request(self, method: str, api_path_segment: str | None = None, *, full_url: str | None = None, **kwargs):
|
||||
"""
|
||||
向 Microsoft Graph API 发送请求。
|
||||
:param method: HTTP 方法。
|
||||
:param api_path_segment: API 路径段 (与 full_url 互斥)。
|
||||
:param full_url: 完整的请求 URL (与 api_path_segment 互斥)。
|
||||
:param kwargs: 其他请求参数。
|
||||
:return: 响应对象。
|
||||
"""
|
||||
if not ((api_path_segment is not None) ^ (full_url is not None)):
|
||||
raise ValueError("必须提供 api_path_segment 或 full_url 中的一个,且仅一个")
|
||||
|
||||
token = await self._get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
if "headers" in kwargs:
|
||||
headers.update(kwargs.pop("headers"))
|
||||
|
||||
url = full_url if full_url else f"{MS_GRAPH_URL}/me/drive/root{api_path_segment}"
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.request(method, url, headers=headers, **kwargs)
|
||||
if resp.status_code == 401:
|
||||
self._access_token = None
|
||||
token = await self._get_access_token()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
resp = await client.request(method, url, headers=headers, **kwargs)
|
||||
return resp
|
||||
|
||||
def _format_item(self, item: Dict) -> Dict:
|
||||
"""
|
||||
将 Graph API 返回的 item 格式化为统一的格式。
|
||||
:param item: Graph API 返回的 item 字典。
|
||||
:return: 格式化后的字典。
|
||||
"""
|
||||
is_dir = "folder" in item
|
||||
return {
|
||||
"name": item["name"],
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else item.get("size", 0),
|
||||
"mtime": int(datetime.fromisoformat(item["lastModifiedDateTime"].replace("Z", "+00:00")).timestamp()),
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
"""
|
||||
列出目录内容。
|
||||
由于 Graph API 不支持基于偏移($skip)的分页,此方法将获取所有项目,
|
||||
:param root: 根路径 (在此适配器中未使用,通过配置的 root 确定)。
|
||||
:param rel: 相对路径。
|
||||
:param page_num: 页码。
|
||||
:param page_size: 每页大小。
|
||||
:param sort_by: 排序字段
|
||||
:param sort_order: 排序顺序
|
||||
:return: 文件/目录列表和总数。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
children_path = f"{api_path}:/children" if api_path else "/children"
|
||||
all_items = []
|
||||
params = {"$top": 999}
|
||||
resp = await self._request("GET", api_path_segment=children_path, params=params)
|
||||
|
||||
while True:
|
||||
if resp.status_code == 404 and not all_items:
|
||||
return [], 0
|
||||
resp.raise_for_status()
|
||||
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception as e:
|
||||
raise IOError(f"解析 Graph API 响应失败: {e}") from e
|
||||
|
||||
all_items.extend(data.get("value", []))
|
||||
next_link = data.get("@odata.nextLink")
|
||||
|
||||
if not next_link:
|
||||
break
|
||||
|
||||
resp = await self._request("GET", full_url=next_link)
|
||||
|
||||
formatted_items = [self._format_item(item) for item in all_items]
|
||||
|
||||
# 排序
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
formatted_items.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(formatted_items)
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
|
||||
return formatted_items[start_idx:end_idx], total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
"""
|
||||
读取文件内容。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 文件内容的字节流。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise IsADirectoryError("不能将根目录作为文件读取")
|
||||
|
||||
resp = await self._request("GET", api_path_segment=f"{api_path}:/content")
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
"""
|
||||
写入文件。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param data: 文件内容的字节流。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise ValueError("不能直接写入根路径")
|
||||
resp = await self._request("PUT", api_path_segment=f"{api_path}:/content", content=data)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
"""
|
||||
以流式方式写入文件。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param data_iter: 文件内容的异步迭代器。
|
||||
:return: 文件大小。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise ValueError("不能直接写入根路径")
|
||||
|
||||
resp = await self._request("PUT", api_path_segment=f"{api_path}:/content", content=data_iter)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("size", 0)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
"""
|
||||
创建目录。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
"""
|
||||
parent_path_str, new_dir_name = rel.rstrip(
|
||||
'/').rsplit('/', 1) if '/' in rel.rstrip('/') else ('', rel)
|
||||
parent_api_path = self._get_api_path(parent_path_str)
|
||||
|
||||
children_path = f"{parent_api_path}:/children" if parent_api_path else "/children"
|
||||
|
||||
payload = {
|
||||
"name": new_dir_name,
|
||||
"folder": {},
|
||||
"@microsoft.graph.conflictBehavior": "fail" # 如果已存在则失败
|
||||
}
|
||||
resp = await self._request("POST", api_path_segment=children_path, json=payload)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
"""
|
||||
删除文件或目录。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise ValueError("不能删除根目录")
|
||||
|
||||
resp = await self._request("DELETE", api_path_segment=api_path)
|
||||
if resp.status_code not in (204, 404):
|
||||
resp.raise_for_status()
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
"""
|
||||
移动或重命名文件/目录。
|
||||
:param root: 根路径。
|
||||
:param src_rel: 源相对路径。
|
||||
:param dst_rel: 目标相对路径。
|
||||
"""
|
||||
src_api_path = self._get_api_path(src_rel)
|
||||
if not src_api_path:
|
||||
raise ValueError("不能移动根目录")
|
||||
|
||||
dst_parent_rel, dst_name = dst_rel.rstrip(
|
||||
'/').rsplit('/', 1) if '/' in dst_rel.rstrip('/') else ('', dst_rel)
|
||||
dst_parent_api_path = self._get_api_path(dst_parent_rel)
|
||||
|
||||
# 获取父项目的 ID
|
||||
parent_resp = await self._request("GET", api_path_segment=dst_parent_api_path)
|
||||
parent_resp.raise_for_status()
|
||||
parent_id = parent_resp.json()["id"]
|
||||
|
||||
payload = {
|
||||
"parentReference": {"id": parent_id},
|
||||
"name": dst_name
|
||||
}
|
||||
resp = await self._request("PATCH", api_path_segment=src_api_path, json=payload)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
"""
|
||||
重命名文件或目录。
|
||||
在 Graph API 中,移动和重命名是同一个 PATCH 操作。
|
||||
"""
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
"""
|
||||
复制文件或目录。
|
||||
:param root: 根路径。
|
||||
:param src_rel: 源相对路径。
|
||||
:param dst_rel: 目标相对路径。
|
||||
:param overwrite: 是否覆盖 (在此 API 中未直接使用)。
|
||||
"""
|
||||
src_api_path = self._get_api_path(src_rel)
|
||||
if not src_api_path:
|
||||
raise ValueError("不能复制根目录")
|
||||
|
||||
dst_parent_rel, dst_name = dst_rel.rstrip(
|
||||
'/').rsplit('/', 1) if '/' in dst_rel.rstrip('/') else ('', dst_rel)
|
||||
dst_parent_api_path = self._get_api_path(dst_parent_rel)
|
||||
|
||||
parent_resp = await self._request("GET", api_path_segment=dst_parent_api_path)
|
||||
parent_resp.raise_for_status()
|
||||
parent_id = parent_resp.json()["id"]
|
||||
|
||||
payload = {"parentReference": {"id": parent_id}, "name": dst_name}
|
||||
copy_path = f"{src_api_path}:/copy"
|
||||
resp = await self._request("POST", api_path_segment=copy_path, json=payload)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
"""
|
||||
流式传输文件(支持范围请求)。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param range_header: HTTP Range 头。
|
||||
:return: FastAPI StreamingResponse 对象。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise IsADirectoryError("不能对目录进行流式传输")
|
||||
|
||||
resp = await self._request("GET", api_path_segment=api_path)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
item_data = resp.json()
|
||||
|
||||
download_url = item_data.get("@microsoft.graph.downloadUrl")
|
||||
if not download_url:
|
||||
raise Exception("无法获取下载 URL")
|
||||
|
||||
file_size = item_data.get("size", 0)
|
||||
content_type = item_data.get("file", {}).get(
|
||||
"mimeType", "application/octet-stream")
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"Content-Disposition": f"inline; filename=\"{item_data.get('name')}\""
|
||||
}
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if start >= file_size:
|
||||
raise HTTPException(416, "Requested Range Not Satisfiable")
|
||||
if end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid Range header")
|
||||
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
else:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
async def file_iterator():
|
||||
nonlocal start, end
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
req_headers = {'Range': f'bytes={start}-{end}'}
|
||||
async with client.stream("GET", download_url, headers=req_headers) as stream_resp:
|
||||
stream_resp.raise_for_status()
|
||||
async for chunk in stream_resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(file_iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
async def get_direct_download_response(self, root: str, rel: str):
|
||||
if not self.enable_redirect_307:
|
||||
return None
|
||||
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
raise IsADirectoryError("不能对目录进行直链重定向")
|
||||
|
||||
resp = await self._request("GET", api_path_segment=api_path)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
|
||||
item_data = resp.json()
|
||||
download_url = item_data.get("@microsoft.graph.downloadUrl")
|
||||
if not download_url:
|
||||
return None
|
||||
|
||||
return Response(status_code=307, headers={"Location": download_url})
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
"""
|
||||
获取文件的缩略图。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:param size: 缩略图大小 (large, medium, small)。
|
||||
:return: 缩略图内容的字节流,或在不支持时返回 None。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
if not api_path:
|
||||
return None
|
||||
|
||||
thumb_path = f"{api_path}:/thumbnails/0/{size}"
|
||||
|
||||
try:
|
||||
resp = await self._request("GET", api_path_segment=thumb_path)
|
||||
if resp.status_code == 200:
|
||||
thumb_data = resp.json()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
thumb_resp = await client.get(thumb_data['url'])
|
||||
thumb_resp.raise_for_status()
|
||||
return thumb_resp.content
|
||||
elif resp.status_code == 404:
|
||||
return None
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
"""
|
||||
获取文件或目录的元数据。
|
||||
:param root: 根路径。
|
||||
:param rel: 相对路径。
|
||||
:return: 格式化后的文件/目录信息。
|
||||
"""
|
||||
api_path = self._get_api_path(rel)
|
||||
resp = await self._request("GET", api_path_segment=api_path)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return self._format_item(resp.json())
|
||||
|
||||
|
||||
ADAPTER_TYPE = "onedrive"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "client_id", "label": "Client ID", "type": "string", "required": True},
|
||||
{"key": "client_secret", "label": "Client Secret",
|
||||
"type": "password", "required": True},
|
||||
{"key": "refresh_token", "label": "Refresh Token", "type": "password",
|
||||
"required": True, "help_text": "可以通过运行 'python -m domain.adapters.providers.onedrive' 获取"},
|
||||
{"key": "root", "label": "根目录 (Root Path)", "type": "string",
|
||||
"required": False, "placeholder": "默认为根目录 /"},
|
||||
{"key": "enable_direct_download_307", "label": "Enable 307 redirect download", "type": "boolean", "default": False},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec): return OneDriveAdapter(rec)
|
||||
911
domain/adapters/providers/quark.py
Normal file
911
domain/adapters/providers/quark.py
Normal file
@@ -0,0 +1,911 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Optional, AsyncIterator, Any
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from models import StorageAdapter
|
||||
from .base import BaseAdapter
|
||||
|
||||
|
||||
# Quark 普通(UC)接口
|
||||
API_BASE = "https://drive.quark.cn/1/clouddrive"
|
||||
REFERER = "https://pan.quark.cn"
|
||||
PR = "ucpro"
|
||||
|
||||
|
||||
class QuarkAdapter:
|
||||
"""夸克网盘(Cookie 模式)
|
||||
|
||||
- 使用浏览器导出的 Cookie 进行鉴权
|
||||
- 通过 Quark/UC 的 clouddrive 接口实现:列目录、读写、分片上传、基础操作
|
||||
- 根 FID 固定为 "0";路径解析通过名称遍历
|
||||
"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config or {}
|
||||
self.cookie: str = cfg.get("cookie") or cfg.get("Cookie")
|
||||
self.root_fid: str = cfg.get("root_fid", "0")
|
||||
def _as_bool(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
|
||||
self.use_transcoding_address: bool = _as_bool(cfg.get("use_transcoding_address", False))
|
||||
self.only_list_video_file: bool = _as_bool(cfg.get("only_list_video_file", False))
|
||||
|
||||
if not self.cookie:
|
||||
raise ValueError("Quark 适配器需要 cookie 配置")
|
||||
|
||||
# 运行期缓存
|
||||
self._dir_fid_cache: Dict[str, str] = {f"{self.root_fid}:": self.root_fid}
|
||||
self._children_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
# UA 与超时
|
||||
self._ua = (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) quark-cloud-drive/2.5.20 Chrome/100.0.4896.160 "
|
||||
"Electron/18.3.5.4-b478491100 Safari/537.36 Channel/pckk_other_ch"
|
||||
)
|
||||
self._timeout = 30.0
|
||||
|
||||
# -----------------
|
||||
# 工具与通用请求
|
||||
# -----------------
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
return self.root_fid
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
pathname: str,
|
||||
*,
|
||||
json: Any | None = None,
|
||||
params: Dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
headers = {
|
||||
"Cookie": self._safe_cookie(self.cookie),
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Referer": REFERER,
|
||||
"User-Agent": self._ua,
|
||||
}
|
||||
query = {"pr": PR, "fr": "pc"}
|
||||
if params:
|
||||
query.update(params)
|
||||
url = f"{API_BASE}{pathname}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
resp = await client.request(method, url, headers=headers, params=query, json=json)
|
||||
# 更新运行期 cookie(若返回 __puus/__pus)
|
||||
try:
|
||||
for key in ("__puus", "__pus"):
|
||||
v = resp.cookies.get(key)
|
||||
if v:
|
||||
# 简单替换/追加到 self.cookie
|
||||
self._set_cookie_kv(key, v)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 解析业务状态
|
||||
data = None
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
return resp
|
||||
status = data.get("status")
|
||||
code = data.get("code")
|
||||
msg = data.get("message") or ""
|
||||
if (status is not None and status >= 400) or (code is not None and code != 0):
|
||||
raise HTTPException(502, detail=f"Quark error status={status} code={code} msg={msg}")
|
||||
return data
|
||||
|
||||
def _set_cookie_kv(self, key: str, value: str):
|
||||
# 将指定键值写入 self.cookie(粗略字符串处理)
|
||||
parts = [p.strip() for p in (self.cookie or "").replace("\r", "").replace("\n", "").split(";") if p.strip()]
|
||||
found = False
|
||||
for i, p in enumerate(parts):
|
||||
if p.startswith(key + "="):
|
||||
parts[i] = f"{key}={value}"
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
parts.append(f"{key}={value}")
|
||||
self.cookie = "; ".join(parts)
|
||||
|
||||
def _sanitize_cookie(self, cookie: str) -> str:
|
||||
if not cookie:
|
||||
return ""
|
||||
# 去除换行与前后空白
|
||||
cookie = cookie.replace("\r", "").replace("\n", "").strip()
|
||||
# 统一分号分隔并去除多余空格/空段
|
||||
parts = [p.strip() for p in cookie.split(";") if p.strip()]
|
||||
return "; ".join(parts)
|
||||
|
||||
def _safe_cookie(self, cookie: str) -> str:
|
||||
s = self._sanitize_cookie(cookie)
|
||||
# 仅保留可见 ASCII (0x20-0x7E)
|
||||
s = "".join(ch for ch in s if 32 <= ord(ch) <= 126)
|
||||
return s
|
||||
|
||||
# -----------------
|
||||
# 列表与路径解析
|
||||
# -----------------
|
||||
def _map_file_item(self, it: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Quark/UC 列表项:file=true 表示文件;false 表示目录
|
||||
is_dir = not bool(it.get("file", False))
|
||||
updated_at_ms = int(it.get("updated_at", 0) or 0)
|
||||
name = it.get("file_name") or it.get("filename") or it.get("name")
|
||||
return {
|
||||
"fid": it.get("fid"),
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(it.get("size", 0) or 0),
|
||||
"mtime": updated_at_ms // 1000 if updated_at_ms else 0,
|
||||
"type": "dir" if is_dir else "file",
|
||||
}
|
||||
|
||||
async def _list_children(self, parent_fid: str) -> List[Dict[str, Any]]:
|
||||
if parent_fid in self._children_cache:
|
||||
return self._children_cache[parent_fid]
|
||||
|
||||
files: List[Dict[str, Any]] = []
|
||||
page = 1
|
||||
size = 100
|
||||
total = None
|
||||
while True:
|
||||
qp = {"pdir_fid": parent_fid, "_size": str(size), "_page": str(page), "_fetch_total": "1"}
|
||||
data = await self._request("GET", "/file/sort", params=qp)
|
||||
d = (data or {}).get("data", {})
|
||||
meta = (data or {}).get("metadata", {})
|
||||
page_files = d.get("list", [])
|
||||
files.extend(page_files)
|
||||
if total is None:
|
||||
total = meta.get("_total") or meta.get("total") or 0
|
||||
if page * size >= int(total):
|
||||
break
|
||||
page += 1
|
||||
|
||||
mapped = [self._map_file_item(x) for x in files if (not self.only_list_video_file) or (not x.get("file")) or (x.get("category") == 1)]
|
||||
self._children_cache[parent_fid] = mapped
|
||||
return mapped
|
||||
|
||||
def _dir_cache_key(self, base_fid: str, rel: str) -> str:
|
||||
return f"{base_fid}:{rel.strip('/')}"
|
||||
|
||||
async def _resolve_dir_fid_from(self, base_fid: str, rel: str) -> str:
|
||||
key = rel.strip("/")
|
||||
cache_key = self._dir_cache_key(base_fid, key)
|
||||
if cache_key in self._dir_fid_cache:
|
||||
return self._dir_fid_cache[cache_key]
|
||||
if key == "":
|
||||
self._dir_fid_cache[cache_key] = base_fid
|
||||
return base_fid
|
||||
|
||||
parent_fid = base_fid
|
||||
path_so_far = []
|
||||
for seg in key.split("/"):
|
||||
if seg == "":
|
||||
continue
|
||||
path_so_far.append(seg)
|
||||
cache_key = self._dir_cache_key(base_fid, "/".join(path_so_far))
|
||||
cached = self._dir_fid_cache.get(cache_key)
|
||||
if cached:
|
||||
parent_fid = cached
|
||||
continue
|
||||
children = await self._list_children(parent_fid)
|
||||
found = next((c for c in children if c["is_dir"] and c["name"] == seg), None)
|
||||
if not found:
|
||||
raise FileNotFoundError(f"Directory not found: {seg}")
|
||||
parent_fid = found["fid"]
|
||||
self._dir_fid_cache[cache_key] = parent_fid
|
||||
|
||||
return parent_fid
|
||||
|
||||
async def _find_child(self, parent_fid: str, name: str) -> Optional[Dict[str, Any]]:
|
||||
children = await self._list_children(parent_fid)
|
||||
for it in children:
|
||||
if it["name"] == name:
|
||||
return it
|
||||
return None
|
||||
|
||||
def _invalidate_children_cache(self, parent_fid: str):
|
||||
if parent_fid in self._children_cache:
|
||||
try:
|
||||
del self._children_cache[parent_fid]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -----------------
|
||||
# 目录与文件列表
|
||||
# -----------------
|
||||
async def list_dir(
|
||||
self,
|
||||
root: str,
|
||||
rel: str,
|
||||
page_num: int = 1,
|
||||
page_size: int = 50,
|
||||
sort_by: str = "name",
|
||||
sort_order: str = "asc",
|
||||
) -> Tuple[List[Dict], int]:
|
||||
base_fid = root or self.root_fid
|
||||
fid = await self._resolve_dir_fid_from(base_fid, rel)
|
||||
items = await self._list_children(fid)
|
||||
|
||||
# 排序,目录优先
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sf = sort_by.lower()
|
||||
if sf == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sf == "size":
|
||||
key += (item["size"],)
|
||||
elif sf == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
items.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(items)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return items[start:end], total
|
||||
|
||||
# -----------------
|
||||
# 下载与流式下载
|
||||
# -----------------
|
||||
async def _get_download_url(self, fid: str) -> str:
|
||||
data = await self._request("POST", "/file/download", json={"fids": [fid]})
|
||||
arr = (data or {}).get("data", [])
|
||||
if not arr:
|
||||
raise HTTPException(502, detail="No download data returned by Quark")
|
||||
url = arr[0].get("download_url") or arr[0].get("DownloadUrl")
|
||||
if not url:
|
||||
raise HTTPException(502, detail="No download_url returned by Quark")
|
||||
return url
|
||||
|
||||
async def _get_transcoding_url(self, fid: str) -> Optional[str]:
|
||||
try:
|
||||
payload = {"fid": fid, "resolutions": "low,normal,high,super,2k,4k", "supports": "fmp4_av,m3u8,dolby_vision"}
|
||||
data = await self._request("POST", "/file/v2/play/project", json=payload)
|
||||
lst = (data or {}).get("data", {}).get("video_list", [])
|
||||
for item in lst:
|
||||
vi = item.get("video_info") or {}
|
||||
url = vi.get("url")
|
||||
if url:
|
||||
return url
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def get_video_transcoding_url(self, fid: str) -> Optional[str]:
|
||||
if not self.use_transcoding_address:
|
||||
return None
|
||||
return await self._get_transcoding_url(fid)
|
||||
|
||||
def _is_video_name(self, name: str) -> bool:
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
return bool(mime and mime.startswith("video/"))
|
||||
|
||||
def _download_headers(self) -> Dict[str, str]:
|
||||
return {"Cookie": self._safe_cookie(self.cookie), "User-Agent": self._ua, "Referer": REFERER}
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
url = await self._get_download_url(it["fid"])
|
||||
headers = self._download_headers()
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
|
||||
url = await self._get_download_url(it["fid"])
|
||||
headers = dict(self._download_headers())
|
||||
headers["Range"] = f"bytes={start}-" if end is None else f"bytes={start}-{end}"
|
||||
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
if resp.status_code == 416:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise IsADirectoryError("Path is a directory")
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it or it["is_dir"]:
|
||||
raise FileNotFoundError(rel)
|
||||
url = await self._get_download_url(it["fid"])
|
||||
if self.use_transcoding_address and self._is_video_name(name):
|
||||
tr = await self._get_transcoding_url(it["fid"])
|
||||
if tr:
|
||||
url = tr
|
||||
dl_headers = self._download_headers()
|
||||
|
||||
# 预获取大小/是否支持范围
|
||||
total_size: Optional[int] = None
|
||||
async with httpx.AsyncClient(timeout=self._timeout, follow_redirects=True) as client:
|
||||
try:
|
||||
head_resp = await client.head(url, headers=dl_headers)
|
||||
if head_resp.status_code == 200:
|
||||
cl = head_resp.headers.get("Content-Length")
|
||||
if cl and cl.isdigit():
|
||||
total_size = int(cl)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
# 解析 Range
|
||||
start = 0
|
||||
end: Optional[int] = None
|
||||
status_code = 200
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
status_code = 206
|
||||
part = range_header.split("=", 1)[1]
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
|
||||
if total_size is not None and end is None and status_code == 206:
|
||||
end = total_size - 1
|
||||
if end is not None and total_size is not None and end >= total_size:
|
||||
end = total_size - 1
|
||||
if total_size is not None and start >= total_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
|
||||
resp_headers: Dict[str, str] = {"Accept-Ranges": "bytes", "Content-Type": content_type}
|
||||
if status_code == 206 and total_size is not None and end is not None:
|
||||
resp_headers["Content-Range"] = f"bytes {start}-{end}/{total_size}"
|
||||
resp_headers["Content-Length"] = str(end - start + 1)
|
||||
elif total_size is not None:
|
||||
resp_headers["Content-Length"] = str(total_size)
|
||||
|
||||
async def iterator():
|
||||
headers = dict(dl_headers)
|
||||
if status_code == 206 and end is not None:
|
||||
headers["Range"] = f"bytes={start}-{end}"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
async with client.stream("GET", url, headers=headers) as resp:
|
||||
if resp.status_code in (404, 416):
|
||||
await resp.aclose()
|
||||
raise HTTPException(resp.status_code, detail="Upstream not available")
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status_code, headers=resp_headers, media_type=content_type)
|
||||
|
||||
# -----------------
|
||||
# 上传(大文件分片)
|
||||
# -----------------
|
||||
@staticmethod
|
||||
def _md5_hex(b: bytes) -> str:
|
||||
return hashlib.md5(b).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _sha1_hex(b: bytes) -> str:
|
||||
return hashlib.sha1(b).hexdigest()
|
||||
|
||||
def _guess_mime(self, name: str) -> str:
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
return mime or "application/octet-stream"
|
||||
|
||||
async def _upload_pre(self, filename: str, size: int, parent_fid: str) -> Dict[str, Any]:
|
||||
now_ms = int(time.time() * 1000)
|
||||
body = {
|
||||
"ccp_hash_update": True,
|
||||
"dir_name": "",
|
||||
"file_name": filename,
|
||||
"format_type": self._guess_mime(filename),
|
||||
"l_created_at": now_ms,
|
||||
"l_updated_at": now_ms,
|
||||
"pdir_fid": parent_fid,
|
||||
"size": size,
|
||||
}
|
||||
data = await self._request("POST", "/file/upload/pre", json=body)
|
||||
return data
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
async def gen():
|
||||
yield data
|
||||
return await self.write_file_stream(root, rel, gen())
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = filename or rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
|
||||
md5 = hashlib.md5()
|
||||
sha1 = hashlib.sha1()
|
||||
total = 0
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
while True:
|
||||
chunk = file_obj.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
md5.update(chunk)
|
||||
sha1.update(chunk)
|
||||
|
||||
md5_hex = md5.hexdigest()
|
||||
sha1_hex = sha1.hexdigest()
|
||||
|
||||
# 预上传,拿到上传信息
|
||||
pre_resp = await self._upload_pre(name, total, parent_fid)
|
||||
pre_data = pre_resp.get("data", {})
|
||||
|
||||
# hash 秒传
|
||||
hash_body = {"md5": md5_hex, "sha1": sha1_hex, "task_id": pre_data.get("task_id")}
|
||||
hash_resp = await self._request("POST", "/file/update/hash", json=hash_body)
|
||||
if (hash_resp.get("data") or {}).get("finish") is True:
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return {"size": total}
|
||||
|
||||
# 分片上传
|
||||
part_size = int((pre_resp.get("metadata") or {}).get("part_size") or 0)
|
||||
if part_size <= 0:
|
||||
raise HTTPException(502, detail="Invalid part_size from Quark")
|
||||
|
||||
bucket = pre_data.get("bucket")
|
||||
obj_key = pre_data.get("obj_key")
|
||||
upload_id = pre_data.get("upload_id")
|
||||
upload_url = pre_data.get("upload_url")
|
||||
if not (bucket and obj_key and upload_id and upload_url):
|
||||
raise HTTPException(502, detail="Upload pre missing fields")
|
||||
|
||||
try:
|
||||
upload_host = upload_url.split("://", 1)[1]
|
||||
except Exception:
|
||||
upload_host = upload_url
|
||||
base_url = f"https://{bucket}.{upload_host}/{obj_key}"
|
||||
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
etags: List[str] = []
|
||||
oss_ua = "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
part_number = 1
|
||||
left = total
|
||||
while left > 0:
|
||||
sz = min(part_size, left)
|
||||
data_bytes = file_obj.read(sz)
|
||||
if len(data_bytes) != sz:
|
||||
raise IOError("Failed to read part bytes")
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta = (
|
||||
"PUT\n\n"
|
||||
f"{self._guess_mime(name)}\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?partNumber={part_number}&uploadId={upload_id}"
|
||||
)
|
||||
auth_req_body = {"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta, "task_id": pre_data.get("task_id")}
|
||||
auth_resp = await self._request("POST", "/file/upload/auth", json=auth_req_body)
|
||||
auth_key = (auth_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key:
|
||||
raise HTTPException(502, detail="upload/auth missing auth_key")
|
||||
|
||||
put_headers = {
|
||||
"Authorization": auth_key,
|
||||
"Content-Type": self._guess_mime(name),
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
put_url = f"{base_url}?partNumber={part_number}&uploadId={upload_id}"
|
||||
put_resp = await client.put(put_url, headers=put_headers, content=data_bytes)
|
||||
if put_resp.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload part failed status={put_resp.status_code} text={put_resp.text}")
|
||||
etag = put_resp.headers.get("Etag", "")
|
||||
etags.append(etag)
|
||||
left -= sz
|
||||
part_number += 1
|
||||
|
||||
parts_xml = [f"<Part>\n<PartNumber>{i+1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n" for i in range(len(etags))]
|
||||
body_xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n" + "".join(parts_xml) + "</CompleteMultipartUpload>"
|
||||
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
|
||||
callback = pre_data.get("callback") or {}
|
||||
try:
|
||||
import json as _json
|
||||
callback_b64 = base64.b64encode(_json.dumps(callback).encode("utf-8")).decode("ascii")
|
||||
except Exception:
|
||||
callback_b64 = ""
|
||||
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta_commit = (
|
||||
"POST\n"
|
||||
f"{content_md5}\n"
|
||||
"application/xml\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-callback:{callback_b64}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?uploadId={upload_id}"
|
||||
)
|
||||
auth_commit_resp = await self._request("POST", "/file/upload/auth", json={"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta_commit, "task_id": pre_data.get("task_id")})
|
||||
auth_key_commit = (auth_commit_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key_commit:
|
||||
raise HTTPException(502, detail="upload/auth(commit) missing auth_key")
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
commit_headers = {
|
||||
"Authorization": auth_key_commit,
|
||||
"Content-MD5": content_md5,
|
||||
"Content-Type": "application/xml",
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-callback": callback_b64,
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
commit_url = f"{base_url}?uploadId={upload_id}"
|
||||
r = await client.post(commit_url, headers=commit_headers, content=body_xml.encode("utf-8"))
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload commit failed status={r.status_code} text={r.text}")
|
||||
|
||||
await self._request("POST", "/file/upload/finish", json={"obj_key": obj_key, "task_id": pre_data.get("task_id")})
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception:
|
||||
pass
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return {"size": total}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
if not rel or rel.endswith("/"):
|
||||
raise HTTPException(400, detail="Invalid file path")
|
||||
|
||||
parent = rel.rsplit("/", 1)[0] if "/" in rel else ""
|
||||
name = rel.rsplit("/", 1)[-1]
|
||||
base_fid = root or self.root_fid
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent)
|
||||
|
||||
# 将数据落盘到临时文件,同时计算 MD5/SHA1
|
||||
import tempfile
|
||||
|
||||
md5 = hashlib.md5()
|
||||
sha1 = hashlib.sha1()
|
||||
total = 0
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tmp_path = tf.name
|
||||
try:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
total += len(chunk)
|
||||
md5.update(chunk)
|
||||
sha1.update(chunk)
|
||||
tf.write(chunk)
|
||||
finally:
|
||||
tf.flush()
|
||||
|
||||
md5_hex = md5.hexdigest()
|
||||
sha1_hex = sha1.hexdigest()
|
||||
|
||||
# 预上传,拿到上传信息
|
||||
pre_resp = await self._upload_pre(name, total, parent_fid)
|
||||
pre_data = pre_resp.get("data", {})
|
||||
|
||||
# hash 秒传
|
||||
hash_body = {"md5": md5_hex, "sha1": sha1_hex, "task_id": pre_data.get("task_id")}
|
||||
hash_resp = await self._request("POST", "/file/update/hash", json=hash_body)
|
||||
if (hash_resp.get("data") or {}).get("finish") is True:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
# 刷新父目录缓存
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return total
|
||||
|
||||
# 分片上传
|
||||
part_size = int((pre_resp.get("metadata") or {}).get("part_size") or 0)
|
||||
if part_size <= 0:
|
||||
raise HTTPException(502, detail="Invalid part_size from Quark")
|
||||
|
||||
bucket = pre_data.get("bucket")
|
||||
obj_key = pre_data.get("obj_key")
|
||||
upload_id = pre_data.get("upload_id")
|
||||
upload_url = pre_data.get("upload_url")
|
||||
if not (bucket and obj_key and upload_id and upload_url):
|
||||
raise HTTPException(502, detail="Upload pre missing fields")
|
||||
|
||||
# 计算 host 与基础 URL
|
||||
try:
|
||||
upload_host = upload_url.split("://", 1)[1]
|
||||
except Exception:
|
||||
upload_host = upload_url
|
||||
base_url = f"https://{bucket}.{upload_host}/{obj_key}"
|
||||
|
||||
# 分片循环
|
||||
etags: List[str] = []
|
||||
oss_ua = "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit"
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
with open(tmp_path, "rb") as rf:
|
||||
part_number = 1
|
||||
left = total
|
||||
while left > 0:
|
||||
sz = min(part_size, left)
|
||||
data_bytes = rf.read(sz)
|
||||
if len(data_bytes) != sz:
|
||||
raise IOError("Failed to read part bytes")
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
# 申请签名
|
||||
auth_meta = (
|
||||
"PUT\n\n"
|
||||
f"{self._guess_mime(name)}\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?partNumber={part_number}&uploadId={upload_id}"
|
||||
)
|
||||
auth_req_body = {"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta, "task_id": pre_data.get("task_id")}
|
||||
auth_resp = await self._request("POST", "/file/upload/auth", json=auth_req_body)
|
||||
auth_key = (auth_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key:
|
||||
raise HTTPException(502, detail="upload/auth missing auth_key")
|
||||
|
||||
put_headers = {
|
||||
"Authorization": auth_key,
|
||||
"Content-Type": self._guess_mime(name),
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
put_url = f"{base_url}?partNumber={part_number}&uploadId={upload_id}"
|
||||
put_resp = await client.put(put_url, headers=put_headers, content=data_bytes)
|
||||
if put_resp.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload part failed status={put_resp.status_code} text={put_resp.text}")
|
||||
etag = put_resp.headers.get("Etag", "")
|
||||
etags.append(etag)
|
||||
left -= sz
|
||||
part_number += 1
|
||||
|
||||
# 组合 commit xml
|
||||
parts_xml = [f"<Part>\n<PartNumber>{i+1}</PartNumber>\n<ETag>{etags[i]}</ETag>\n</Part>\n" for i in range(len(etags))]
|
||||
body_xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<CompleteMultipartUpload>\n" + "".join(parts_xml) + "</CompleteMultipartUpload>"
|
||||
content_md5 = base64.b64encode(hashlib.md5(body_xml.encode("utf-8")).digest()).decode("ascii")
|
||||
callback = pre_data.get("callback") or {}
|
||||
try:
|
||||
import json as _json
|
||||
callback_b64 = base64.b64encode(_json.dumps(callback).encode("utf-8")).decode("ascii")
|
||||
except Exception:
|
||||
callback_b64 = ""
|
||||
|
||||
now_str = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
|
||||
auth_meta_commit = (
|
||||
"POST\n"
|
||||
f"{content_md5}\n"
|
||||
"application/xml\n"
|
||||
f"{now_str}\n"
|
||||
f"x-oss-callback:{callback_b64}\n"
|
||||
f"x-oss-date:{now_str}\n"
|
||||
f"x-oss-user-agent:{oss_ua}\n"
|
||||
f"/{bucket}/{obj_key}?uploadId={upload_id}"
|
||||
)
|
||||
auth_commit_resp = await self._request("POST", "/file/upload/auth", json={"auth_info": pre_data.get("auth_info"), "auth_meta": auth_meta_commit, "task_id": pre_data.get("task_id")})
|
||||
auth_key_commit = (auth_commit_resp.get("data") or {}).get("auth_key")
|
||||
if not auth_key_commit:
|
||||
raise HTTPException(502, detail="upload/auth(commit) missing auth_key")
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
|
||||
commit_headers = {
|
||||
"Authorization": auth_key_commit,
|
||||
"Content-MD5": content_md5,
|
||||
"Content-Type": "application/xml",
|
||||
"Referer": REFERER + "/",
|
||||
"x-oss-callback": callback_b64,
|
||||
"x-oss-date": now_str,
|
||||
"x-oss-user-agent": oss_ua,
|
||||
}
|
||||
commit_url = f"{base_url}?uploadId={upload_id}"
|
||||
r = await client.post(commit_url, headers=commit_headers, content=body_xml.encode("utf-8"))
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(502, detail=f"Upload commit failed status={r.status_code} text={r.text}")
|
||||
|
||||
# finish
|
||||
await self._request("POST", "/file/upload/finish", json={"obj_key": obj_key, "task_id": pre_data.get("task_id")})
|
||||
# 端合并存在轻微延迟,等待再刷新缓存
|
||||
try:
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
# 失效父目录缓存,确保后续列表可见
|
||||
self._invalidate_children_cache(parent_fid)
|
||||
return total
|
||||
|
||||
# -----------------
|
||||
# 基本文件操作
|
||||
# -----------------
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
if not rel or rel == "/":
|
||||
raise HTTPException(400, detail="Cannot create root")
|
||||
parent = rel.rstrip("/")
|
||||
parent_rel, name = (parent.rsplit("/", 1) if "/" in parent else ("", parent))
|
||||
if not name:
|
||||
raise HTTPException(400, detail="Invalid directory name")
|
||||
pdir = await self._resolve_dir_fid_from(root or self.root_fid, parent_rel)
|
||||
await self._request("POST", "/file", json={"dir_init_lock": False, "dir_path": "", "file_name": name, "pdir_fid": pdir})
|
||||
self._invalidate_children_cache(pdir)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
# 解析对象 fid + 父目录,用于失效缓存
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
parent_rel = rel.rstrip("/")
|
||||
target_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
parent_of_target = await self._resolve_dir_fid_from(base_fid, (parent_rel.rsplit("/", 1)[0] if "/" in parent_rel else ""))
|
||||
else:
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_of_target = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_of_target, name)
|
||||
if not it:
|
||||
return
|
||||
target_fid = it["fid"]
|
||||
await self._request("POST", "/file/delete", json={"action_type": 1, "exclude_fids": [], "filelist": [target_fid]})
|
||||
self._invalidate_children_cache(parent_of_target)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
# 支持跨目录与重命名:先移动到父目录,后重命名(若需要)
|
||||
src_parent_rel, src_name = (src_rel.rsplit("/", 1) if "/" in src_rel else ("", src_rel))
|
||||
dst_parent_rel, dst_name = (dst_rel.rsplit("/", 1) if "/" in dst_rel else ("", dst_rel))
|
||||
|
||||
base_fid = root or self.root_fid
|
||||
src_parent_fid = await self._resolve_dir_fid_from(base_fid, src_parent_rel)
|
||||
obj = await self._find_child(src_parent_fid, src_name)
|
||||
if not obj:
|
||||
raise FileNotFoundError(src_rel)
|
||||
dst_parent_fid = await self._resolve_dir_fid_from(base_fid, dst_parent_rel)
|
||||
|
||||
if src_parent_fid != dst_parent_fid:
|
||||
await self._request("POST", "/file/move", json={"action_type": 1, "exclude_fids": [], "filelist": [obj["fid"]], "to_pdir_fid": dst_parent_fid})
|
||||
self._invalidate_children_cache(src_parent_fid)
|
||||
self._invalidate_children_cache(dst_parent_fid)
|
||||
|
||||
if obj["name"] != dst_name:
|
||||
await self._request("POST", "/file/rename", json={"fid": obj["fid"], "file_name": dst_name})
|
||||
self._invalidate_children_cache(dst_parent_fid)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_parent_rel, src_name = (src_rel.rsplit("/", 1) if "/" in src_rel else ("", src_rel))
|
||||
base_fid = root or self.root_fid
|
||||
src_parent_fid = await self._resolve_dir_fid_from(base_fid, src_parent_rel)
|
||||
obj = await self._find_child(src_parent_fid, src_name)
|
||||
if not obj:
|
||||
raise FileNotFoundError(src_rel)
|
||||
dst_name = dst_rel.rsplit("/", 1)[-1]
|
||||
await self._request("POST", "/file/rename", json={"fid": obj["fid"], "file_name": dst_name})
|
||||
self._invalidate_children_cache(src_parent_fid)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
raise NotImplementedError("QuarkOpen does not support copy via open API")
|
||||
|
||||
# -----------------
|
||||
# STAT / EXISTS / 辅助
|
||||
# -----------------
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
# 通过父目录列表获取元数据
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
# 目录
|
||||
fid = await self._resolve_dir_fid_from(base_fid, rel.rstrip("/"))
|
||||
return {"name": rel.rstrip("/").split("/")[-1] if rel else "", "is_dir": True, "size": 0, "mtime": 0, "type": "dir", "fid": fid}
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it:
|
||||
raise FileNotFoundError(rel)
|
||||
return it
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
await self._resolve_dir_fid_from(base_fid, rel.rstrip("/"))
|
||||
return True
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
return it is not None
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
async def stat_path(self, root: str, rel: str):
|
||||
# 用于 move/copy 前的预检查调试
|
||||
try:
|
||||
base_fid = root or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
fid = await self._resolve_dir_fid_from(base_fid, rel.rstrip("/"))
|
||||
return {"exists": True, "is_dir": True, "path": rel, "fid": fid}
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base_fid, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if it:
|
||||
return {"exists": True, "is_dir": it["is_dir"], "path": rel, "fid": it["fid"]}
|
||||
return {"exists": False, "is_dir": None, "path": rel}
|
||||
except FileNotFoundError:
|
||||
return {"exists": False, "is_dir": None, "path": rel}
|
||||
|
||||
async def _resolve_target_fid(self, rel: str, *, base_fid: Optional[str] = None) -> str:
|
||||
base = base_fid or self.root_fid
|
||||
if rel == "" or rel.endswith("/"):
|
||||
return await self._resolve_dir_fid_from(base, rel.rstrip("/"))
|
||||
parent_rel, name = (rel.rsplit("/", 1) if "/" in rel else ("", rel))
|
||||
parent_fid = await self._resolve_dir_fid_from(base, parent_rel)
|
||||
it = await self._find_child(parent_fid, name)
|
||||
if not it:
|
||||
raise FileNotFoundError(rel)
|
||||
return it["fid"]
|
||||
|
||||
|
||||
ADAPTER_TYPE = "quark"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "cookie", "label": "Cookie", "type": "password", "required": True, "placeholder": "从 pan.quark.cn 复制"},
|
||||
{"key": "root_fid", "label": "根 FID", "type": "string", "required": False, "default": "0"},
|
||||
{"key": "use_transcoding_address", "label": "视频转码直链", "type": "boolean", "required": False, "default": False},
|
||||
{"key": "only_list_video_file", "label": "仅列出视频文件", "type": "boolean", "required": False, "default": False},
|
||||
]
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter) -> BaseAdapter:
|
||||
return QuarkAdapter(rec)
|
||||
347
domain/adapters/providers/s3.py
Normal file
347
domain/adapters/providers/s3.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
from urllib.parse import quote
|
||||
|
||||
import aioboto3
|
||||
from botocore.exceptions import ClientError
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
class S3Adapter:
|
||||
"""S3 兼容对象存储适配器"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.bucket_name = cfg.get("bucket_name")
|
||||
self.aws_access_key_id = cfg.get("access_key_id")
|
||||
self.aws_secret_access_key = cfg.get("secret_access_key")
|
||||
self.region_name = cfg.get("region_name")
|
||||
self.endpoint_url = cfg.get("endpoint_url")
|
||||
self.root = cfg.get("root", "").strip("/")
|
||||
|
||||
if not all([self.bucket_name, self.aws_access_key_id, self.aws_secret_access_key]):
|
||||
raise ValueError(
|
||||
"S3 适配器需要 bucket_name, access_key_id, 和 secret_access_key")
|
||||
|
||||
self.session = aioboto3.Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
region_name=self.region_name,
|
||||
)
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
"""获取 S3 中的有效根路径 (key prefix)"""
|
||||
if sub_path:
|
||||
return f"{self.root}/{sub_path.strip('/')}".strip("/")
|
||||
return self.root
|
||||
|
||||
def _get_s3_key(self, rel_path: str) -> str:
|
||||
"""将相对路径转换为 S3 key"""
|
||||
rel_path = rel_path.strip("/")
|
||||
if self.root:
|
||||
return f"{self.root}/{rel_path}"
|
||||
return rel_path
|
||||
|
||||
def _get_client(self):
|
||||
return self.session.client("s3", endpoint_url=self.endpoint_url)
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
prefix = self._get_s3_key(rel)
|
||||
if prefix and not prefix.endswith("/"):
|
||||
prefix += "/"
|
||||
|
||||
all_items = []
|
||||
|
||||
async with self._get_client() as s3:
|
||||
paginator = s3.get_paginator("list_objects_v2")
|
||||
async for result in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter="/"):
|
||||
# 添加子目录
|
||||
for common_prefix in result.get("CommonPrefixes", []):
|
||||
dir_name = common_prefix.get(
|
||||
"Prefix").removeprefix(prefix).strip("/")
|
||||
if dir_name:
|
||||
all_items.append({
|
||||
"name": dir_name,
|
||||
"is_dir": True,
|
||||
"size": 0,
|
||||
"mtime": 0,
|
||||
"type": "dir",
|
||||
})
|
||||
|
||||
# 添加文件
|
||||
for content in result.get("Contents", []):
|
||||
file_key = content.get("Key")
|
||||
if file_key == prefix: # 忽略目录本身
|
||||
continue
|
||||
file_name = file_key.removeprefix(prefix)
|
||||
if file_name:
|
||||
all_items.append({
|
||||
"name": file_name,
|
||||
"is_dir": False,
|
||||
"size": content.get("Size", 0),
|
||||
"mtime": int(content.get("LastModified", datetime.now()).timestamp()),
|
||||
"type": "file",
|
||||
})
|
||||
|
||||
# 在内存中排序和分页
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
all_items.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(all_items)
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
|
||||
return all_items[start_idx:end_idx], total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
try:
|
||||
resp = await s3.get_object(Bucket=self.bucket_name, Key=key)
|
||||
return await resp["Body"].read()
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
key = self._get_s3_key(rel)
|
||||
MIN_PART_SIZE = 5 * 1024 * 1024
|
||||
|
||||
async with self._get_client() as s3:
|
||||
mpu = await s3.create_multipart_upload(Bucket=self.bucket_name, Key=key)
|
||||
upload_id = mpu['UploadId']
|
||||
|
||||
parts = []
|
||||
part_number = 1
|
||||
total_size = 0
|
||||
buffer = bytearray()
|
||||
|
||||
try:
|
||||
async for chunk in data_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
buffer.extend(chunk)
|
||||
|
||||
while len(buffer) >= MIN_PART_SIZE:
|
||||
part_data = buffer[:MIN_PART_SIZE]
|
||||
del buffer[:MIN_PART_SIZE]
|
||||
|
||||
part = await s3.upload_part(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=part_data
|
||||
)
|
||||
|
||||
parts.append({'PartNumber': part_number, 'ETag': part['ETag']})
|
||||
total_size += len(part_data)
|
||||
part_number += 1
|
||||
|
||||
if buffer:
|
||||
part = await s3.upload_part(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
PartNumber=part_number,
|
||||
UploadId=upload_id,
|
||||
Body=bytes(buffer)
|
||||
)
|
||||
parts.append({'PartNumber': part_number, 'ETag': part['ETag']})
|
||||
total_size += len(buffer)
|
||||
|
||||
await s3.complete_multipart_upload(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
UploadId=upload_id,
|
||||
MultipartUpload={'Parts': parts}
|
||||
)
|
||||
except Exception as e:
|
||||
await s3.abort_multipart_upload(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
UploadId=upload_id
|
||||
)
|
||||
raise IOError(f"S3 stream upload failed: {e}") from e
|
||||
|
||||
return total_size
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
if not key.endswith("/"):
|
||||
key += "/"
|
||||
async with self._get_client() as s3:
|
||||
await s3.put_object(Bucket=self.bucket_name, Key=key, Body=b"")
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
is_dir_like = False
|
||||
try:
|
||||
head = await s3.head_object(Bucket=self.bucket_name, Key=key)
|
||||
if head['ContentLength'] == 0 and key.endswith('/'):
|
||||
is_dir_like = True
|
||||
except ClientError as e:
|
||||
if e.response['Error']['Code'] != '404':
|
||||
raise
|
||||
|
||||
# 如果是目录,删除目录下的所有对象
|
||||
if is_dir_like or not await self.stat_file(root, rel):
|
||||
dir_key = key if key.endswith('/') else key + '/'
|
||||
paginator = s3.get_paginator("list_objects_v2")
|
||||
objects_to_delete = []
|
||||
async for result in paginator.paginate(Bucket=self.bucket_name, Prefix=dir_key):
|
||||
for content in result.get("Contents", []):
|
||||
objects_to_delete.append({"Key": content["Key"]})
|
||||
if objects_to_delete:
|
||||
await s3.delete_objects(Bucket=self.bucket_name, Delete={"Objects": objects_to_delete})
|
||||
# 如果是文件,直接删除
|
||||
else:
|
||||
await s3.delete_object(Bucket=self.bucket_name, Key=key)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.copy(root, src_rel, dst_rel, overwrite=True)
|
||||
await self.delete(root, src_rel)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src_key = self._get_s3_key(src_rel)
|
||||
dst_key = self._get_s3_key(dst_rel)
|
||||
|
||||
async with self._get_client() as s3:
|
||||
if not overwrite:
|
||||
try:
|
||||
await s3.head_object(Bucket=self.bucket_name, Key=dst_key)
|
||||
raise FileExistsError(dst_rel)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "404":
|
||||
raise
|
||||
|
||||
copy_source = {"Bucket": self.bucket_name, "Key": src_key}
|
||||
await s3.copy_object(CopySource=copy_source, Bucket=self.bucket_name, Key=dst_key)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
try:
|
||||
head = await s3.head_object(Bucket=self.bucket_name, Key=key)
|
||||
return {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": False,
|
||||
"size": head["ContentLength"],
|
||||
"mtime": int(head["LastModified"].timestamp()),
|
||||
"type": "file",
|
||||
}
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "404":
|
||||
# 检查是否为一个 "目录"
|
||||
dir_key = key if key.endswith('/') else key + '/'
|
||||
resp = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=dir_key, MaxKeys=1)
|
||||
if resp.get('KeyCount', 0) > 0:
|
||||
return {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": True,
|
||||
"size": 0,
|
||||
"mtime": 0,
|
||||
"type": "dir",
|
||||
}
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
key = self._get_s3_key(rel)
|
||||
async with self._get_client() as s3:
|
||||
try:
|
||||
head = await s3.head_object(Bucket=self.bucket_name, Key=key)
|
||||
file_size = head["ContentLength"]
|
||||
content_type = head.get("ContentType", mimetypes.guess_type(key)[
|
||||
0] or "application/octet-stream")
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "404":
|
||||
raise HTTPException(
|
||||
status_code=404, detail="File not found")
|
||||
raise
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"Content-Length": str(file_size),
|
||||
"Content-Disposition": f"inline; filename=\"{quote(rel.split('/')[-1])}\""
|
||||
}
|
||||
|
||||
if range_header:
|
||||
range_val = range_header.strip().partition("=")[2]
|
||||
s, _, e = range_val.partition("-")
|
||||
try:
|
||||
start = int(s) if s else 0
|
||||
end = int(e) if e else file_size - 1
|
||||
if start >= file_size or end >= file_size or start > end:
|
||||
raise HTTPException(
|
||||
status_code=416, detail="Requested Range Not Satisfiable")
|
||||
status = 206
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid Range header")
|
||||
|
||||
range_arg = f"bytes={start}-{end}"
|
||||
|
||||
async def iterator():
|
||||
try:
|
||||
resp = await s3.get_object(Bucket=self.bucket_name, Key=key, Range=range_arg)
|
||||
body = resp["Body"]
|
||||
while chunk := await body.read(65536):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "s3"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "bucket_name", "label": "Bucket 名称",
|
||||
"type": "string", "required": True},
|
||||
{"key": "access_key_id", "label": "Access Key ID",
|
||||
"type": "string", "required": True},
|
||||
{"key": "secret_access_key", "label": "Secret Access Key",
|
||||
"type": "password", "required": True},
|
||||
{"key": "region_name", "label": "区域 (Region)", "type": "string",
|
||||
"required": False, "placeholder": "例如 us-east-1"},
|
||||
{"key": "endpoint_url", "label": "Endpoint URL", "type": "string",
|
||||
"required": False, "placeholder": "对于 S3 兼容存储, 例如 https://minio.example.com"},
|
||||
{"key": "root", "label": "根路径 (Root Path)", "type": "string",
|
||||
"required": False, "placeholder": "在 bucket 内的路径前缀"},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec): return S3Adapter(rec)
|
||||
473
domain/adapters/providers/sftp.py
Normal file
473
domain/adapters/providers/sftp.py
Normal file
@@ -0,0 +1,473 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import stat as statmod
|
||||
from typing import List, Dict, Tuple, AsyncIterator, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import paramiko
|
||||
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
def _join_remote(root: str, rel: str) -> str:
|
||||
root = (root or "/").rstrip("/") or "/"
|
||||
rel = (rel or "").lstrip("/")
|
||||
if not rel:
|
||||
return root
|
||||
return f"{root}/{rel}"
|
||||
|
||||
|
||||
class SFTPAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.host: str = cfg.get("host")
|
||||
self.port: int = int(cfg.get("port", 22))
|
||||
self.username: str | None = cfg.get("username")
|
||||
self.password: str | None = cfg.get("password")
|
||||
self.timeout: int = int(cfg.get("timeout", 15))
|
||||
self.root_path: str = cfg.get("root") # 必填
|
||||
self.allow_unknown_host: bool = bool(cfg.get("allow_unknown_host", True))
|
||||
|
||||
if not self.host:
|
||||
raise ValueError("SFTP adapter requires 'host'")
|
||||
if not self.username or not self.password:
|
||||
raise ValueError("SFTP adapter requires 'username' and 'password'")
|
||||
if not self.root_path:
|
||||
raise ValueError("SFTP adapter requires 'root'")
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base = self.root_path.rstrip("/") or "/"
|
||||
if sub_path:
|
||||
return _join_remote(base, sub_path)
|
||||
return base
|
||||
|
||||
def _connect(self) -> paramiko.SFTPClient:
|
||||
ssh = paramiko.SSHClient()
|
||||
if self.allow_unknown_host:
|
||||
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
ssh.connect(
|
||||
hostname=self.host,
|
||||
port=self.port,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
timeout=self.timeout,
|
||||
allow_agent=False,
|
||||
look_for_keys=False,
|
||||
)
|
||||
return ssh.open_sftp()
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_list() -> List[Dict]:
|
||||
sftp = self._connect()
|
||||
try:
|
||||
attrs = sftp.listdir_attr(path)
|
||||
entries: List[Dict] = []
|
||||
for a in attrs:
|
||||
name = a.filename
|
||||
is_dir = statmod.S_ISDIR(a.st_mode)
|
||||
entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(a.st_size or 0),
|
||||
"mtime": int(a.st_mtime or 0),
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
return entries
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
entries = await asyncio.to_thread(_do_list)
|
||||
|
||||
reverse = sort_order.lower() == "desc"
|
||||
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
f = sort_by.lower()
|
||||
if f == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif f == "size":
|
||||
key += (item.get("size", 0),)
|
||||
elif f == "mtime":
|
||||
key += (item.get("mtime", 0),)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
total = len(entries)
|
||||
start = (page_num - 1) * page_size
|
||||
end = start + page_size
|
||||
return entries[start:end], total
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_read() -> bytes:
|
||||
sftp = self._connect()
|
||||
try:
|
||||
with sftp.open(path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except IOError as e:
|
||||
if getattr(e, "errno", None) == 2:
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_read)
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
# likely exists
|
||||
pass
|
||||
|
||||
def _do_write():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(sftp, parent)
|
||||
with sftp.open(path, "wb") as f:
|
||||
f.write(data)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_write)
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
|
||||
parts = [p for p in dir_path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def _do_upload():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
|
||||
_ensure_dirs(sftp, parent)
|
||||
try:
|
||||
if callable(getattr(file_obj, "seek", None)):
|
||||
file_obj.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with sftp.open(path, "wb") as f:
|
||||
import shutil
|
||||
shutil.copyfileobj(file_obj, f)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_upload)
|
||||
return {"size": file_size or 0}
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
buf = bytearray()
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
buf.extend(chunk)
|
||||
await self.write_file(root, rel, bytes(buf))
|
||||
return len(buf)
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_mkdir():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
parts = [p for p in path.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_mkdir)
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_delete():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
# Try file remove first
|
||||
try:
|
||||
sftp.remove(path)
|
||||
return
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def _rm_tree(dp: str):
|
||||
try:
|
||||
for a in sftp.listdir_attr(dp):
|
||||
child = _join_remote(dp, a.filename)
|
||||
if statmod.S_ISDIR(a.st_mode):
|
||||
_rm_tree(child)
|
||||
else:
|
||||
try:
|
||||
sftp.remove(child)
|
||||
except Exception:
|
||||
pass
|
||||
sftp.rmdir(dp)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
_rm_tree(path)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_delete)
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
def _do_move():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
# ensure dst parent exists
|
||||
parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
|
||||
parts = [p for p in parent.strip("/").split("/") if p]
|
||||
cur = "/"
|
||||
for p in parts:
|
||||
cur = _join_remote(cur, p)
|
||||
try:
|
||||
sftp.mkdir(cur)
|
||||
except IOError:
|
||||
pass
|
||||
sftp.rename(src, dst)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_do_move)
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
await self.move(root, src_rel, dst_rel)
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src = _join_remote(root, src_rel)
|
||||
dst = _join_remote(root, dst_rel)
|
||||
|
||||
def _is_dir() -> bool:
|
||||
sftp = self._connect()
|
||||
try:
|
||||
st = sftp.stat(src)
|
||||
return statmod.S_ISDIR(st.st_mode)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if await asyncio.to_thread(_is_dir):
|
||||
await self.mkdir(root, dst_rel)
|
||||
|
||||
children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
|
||||
for ent in children:
|
||||
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
|
||||
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
|
||||
await self.copy(root, child_src, child_dst, overwrite)
|
||||
return
|
||||
|
||||
# file copy
|
||||
data = await self.read_file(root, src_rel)
|
||||
if not overwrite:
|
||||
try:
|
||||
await self.stat_file(root, dst_rel)
|
||||
raise FileExistsError(dst_rel)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
await self.write_file(root, dst_rel, data)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _do_stat():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
st = sftp.stat(path)
|
||||
is_dir = statmod.S_ISDIR(st.st_mode)
|
||||
info = {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else int(st.st_size or 0),
|
||||
"mtime": int(st.st_mtime or 0),
|
||||
"type": "dir" if is_dir else "file",
|
||||
"path": path,
|
||||
}
|
||||
return info
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except IOError as e:
|
||||
if getattr(e, "errno", None) == 2:
|
||||
raise FileNotFoundError(rel)
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await asyncio.to_thread(_do_stat)
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
try:
|
||||
await self.stat_file(root, rel)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
path = _join_remote(root, rel)
|
||||
|
||||
def _get_stat():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
st = sftp.stat(path)
|
||||
return int(st.st_size or 0)
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
file_size = await asyncio.to_thread(_get_stat)
|
||||
if file_size is None:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
try:
|
||||
s, e = (range_header.removeprefix("bytes=").split("-", 1))
|
||||
if s.strip():
|
||||
start = int(s)
|
||||
if e.strip():
|
||||
end = int(e)
|
||||
if start >= file_size:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
if end >= file_size:
|
||||
end = file_size - 1
|
||||
status = 206
|
||||
headers["Content-Length"] = str(end - start + 1)
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(400, detail="Invalid Range header")
|
||||
|
||||
queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)
|
||||
|
||||
def _worker():
|
||||
sftp = self._connect()
|
||||
try:
|
||||
with sftp.open(path, "rb") as f:
|
||||
f.seek(start)
|
||||
remaining = end - start + 1
|
||||
chunk_size = 64 * 1024
|
||||
while remaining > 0:
|
||||
to_read = chunk_size if remaining > chunk_size else remaining
|
||||
data = f.read(to_read)
|
||||
if not data:
|
||||
break
|
||||
try:
|
||||
queue.put_nowait(data)
|
||||
except Exception:
|
||||
break
|
||||
remaining -= len(data)
|
||||
try:
|
||||
queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
sftp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def agen():
|
||||
worker_fut = asyncio.to_thread(_worker)
|
||||
try:
|
||||
while True:
|
||||
chunk = await queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
try:
|
||||
await worker_fut
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
||||
|
||||
|
||||
ADAPTER_TYPE = "sftp"
|
||||
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "host", "label": "主机", "type": "string", "required": True, "placeholder": "sftp.example.com"},
|
||||
{"key": "port", "label": "端口", "type": "number", "required": False, "default": 22},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": True},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": True},
|
||||
{"key": "root", "label": "根路径", "type": "string", "required": True, "placeholder": "/data"},
|
||||
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 15},
|
||||
{"key": "allow_unknown_host", "label": "允许未知主机指纹", "type": "boolean", "required": False, "default": True},
|
||||
]
|
||||
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter):
|
||||
return SFTPAdapter(rec)
|
||||
578
domain/adapters/providers/telegram.py
Normal file
578
domain/adapters/providers/telegram.py
Normal file
@@ -0,0 +1,578 @@
|
||||
from typing import List, Dict, Tuple, AsyncIterator
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import struct
|
||||
from models import StorageAdapter
|
||||
from telethon import TelegramClient
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon.sessions import StringSession
|
||||
from telethon.tl import types
|
||||
import socks
|
||||
|
||||
_SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _get_session_lock(session_string: str) -> asyncio.Lock:
|
||||
lock = _SESSION_LOCKS.get(session_string)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_SESSION_LOCKS[session_string] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class _NamedFile:
|
||||
def __init__(self, file_obj, name: str):
|
||||
self._file = file_obj
|
||||
self.name = name
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
return self._file.read(*args, **kwargs)
|
||||
|
||||
def seek(self, *args, **kwargs):
|
||||
return self._file.seek(*args, **kwargs)
|
||||
|
||||
def tell(self):
|
||||
return self._file.tell()
|
||||
|
||||
def seekable(self):
|
||||
return self._file.seekable()
|
||||
|
||||
def close(self):
|
||||
return self._file.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._file, name)
|
||||
|
||||
# 适配器类型标识
|
||||
ADAPTER_TYPE = "telegram"
|
||||
|
||||
# 适配器配置项定义
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "api_id", "label": "API ID", "type": "string", "required": True, "help_text": "从 my.telegram.org 获取"},
|
||||
{"key": "api_hash", "label": "API Hash", "type": "password", "required": True, "help_text": "从 my.telegram.org 获取"},
|
||||
{"key": "session_string", "label": "Session String", "type": "password", "required": True, "help_text": "通过 generate_session.py 生成"},
|
||||
{"key": "chat_id", "label": "Chat ID", "type": "string", "required": True, "placeholder": "频道/群组的ID或用户名, 例如: -100123456789 或 'channel_username'"},
|
||||
{"key": "proxy_protocol", "label": "代理协议", "type": "string", "required": False, "placeholder": "例如: socks5, http"},
|
||||
{"key": "proxy_host", "label": "代理主机", "type": "string", "required": False, "placeholder": "例如: 127.0.0.1"},
|
||||
{"key": "proxy_port", "label": "代理端口", "type": "number", "required": False, "placeholder": "例如: 1080"},
|
||||
]
|
||||
|
||||
class TelegramAdapter:
|
||||
"""Telegram 存储适配器 (使用用户 Session)"""
|
||||
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.api_id = int(cfg.get("api_id"))
|
||||
self.api_hash = cfg.get("api_hash")
|
||||
self.session_string = cfg.get("session_string")
|
||||
self.chat_id_str = cfg.get("chat_id")
|
||||
|
||||
# 代理设置
|
||||
self.proxy_protocol = cfg.get("proxy_protocol")
|
||||
self.proxy_host = cfg.get("proxy_host")
|
||||
self.proxy_port = cfg.get("proxy_port")
|
||||
|
||||
self.proxy = None
|
||||
if self.proxy_protocol and self.proxy_host and self.proxy_port:
|
||||
proto_map = {
|
||||
"socks5": socks.SOCKS5,
|
||||
"http": socks.HTTP,
|
||||
}
|
||||
proxy_type = proto_map.get(self.proxy_protocol.lower())
|
||||
if proxy_type:
|
||||
self.proxy = (proxy_type, self.proxy_host, int(self.proxy_port))
|
||||
|
||||
try:
|
||||
self.chat_id = int(self.chat_id_str)
|
||||
except (ValueError, TypeError):
|
||||
self.chat_id = self.chat_id_str
|
||||
|
||||
if not all([self.api_id, self.api_hash, self.session_string, self.chat_id]):
|
||||
raise ValueError("Telegram 适配器需要 api_id, api_hash, session_string 和 chat_id")
|
||||
|
||||
@staticmethod
|
||||
def _parse_legacy_session_string(value: str) -> StringSession:
|
||||
"""
|
||||
兼容旧版 session_string 格式:
|
||||
- version(1B char) + base64(data)
|
||||
- data: dc_id(1B) + ip_len(2B) + ip(ASCII, ip_len bytes) + port(2B) + auth_key(256B)
|
||||
"""
|
||||
s = (value or "").strip()
|
||||
if not s:
|
||||
raise ValueError("session_string 为空")
|
||||
|
||||
body = s[1:] if s.startswith("1") else s
|
||||
raw = base64.urlsafe_b64decode(body)
|
||||
if len(raw) < 1 + 2 + 2 + 256:
|
||||
raise ValueError("legacy session 数据长度不足")
|
||||
|
||||
dc_id = raw[0]
|
||||
ip_len = struct.unpack(">H", raw[1:3])[0]
|
||||
expected_len = 1 + 2 + ip_len + 2 + 256
|
||||
if len(raw) != expected_len:
|
||||
raise ValueError("legacy session 数据长度不匹配")
|
||||
|
||||
ip_start = 3
|
||||
ip_end = ip_start + ip_len
|
||||
ip = raw[ip_start:ip_end].decode("utf-8")
|
||||
port = struct.unpack(">H", raw[ip_end : ip_end + 2])[0]
|
||||
key = raw[ip_end + 2 : ip_end + 2 + 256]
|
||||
|
||||
sess = StringSession()
|
||||
sess.set_dc(dc_id, ip, port)
|
||||
sess.auth_key = AuthKey(key)
|
||||
return sess
|
||||
|
||||
@staticmethod
|
||||
def _pick_photo_thumb(thumbs: list | None):
|
||||
if not thumbs:
|
||||
return None
|
||||
|
||||
cached = []
|
||||
others = []
|
||||
for t in thumbs:
|
||||
if isinstance(t, (types.PhotoCachedSize, types.PhotoStrippedSize)):
|
||||
cached.append(t)
|
||||
elif isinstance(t, (types.PhotoSize, types.PhotoSizeProgressive)):
|
||||
if not isinstance(t, types.PhotoSizeEmpty):
|
||||
others.append(t)
|
||||
|
||||
if cached:
|
||||
cached.sort(key=lambda x: len(getattr(x, "bytes", b"") or b""))
|
||||
return cached[-1]
|
||||
|
||||
if others:
|
||||
def _sz(x):
|
||||
if isinstance(x, types.PhotoSizeProgressive):
|
||||
return max(x.sizes or [0])
|
||||
return int(getattr(x, "size", 0) or 0)
|
||||
|
||||
others.sort(key=_sz)
|
||||
return others[-1]
|
||||
|
||||
return None
|
||||
|
||||
def _build_session(self) -> StringSession:
|
||||
s = (self.session_string or "").strip()
|
||||
if not s:
|
||||
raise ValueError("Telegram 适配器 session_string 为空")
|
||||
|
||||
try:
|
||||
return StringSession(s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 少数工具可能去掉了 version 前缀,这里做一次兼容
|
||||
if not s.startswith("1"):
|
||||
try:
|
||||
return StringSession("1" + s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return self._parse_legacy_session_string(s)
|
||||
except Exception as exc:
|
||||
raise ValueError("Telegram session_string 无效,请使用 Telethon StringSession 重新生成") from exc
|
||||
|
||||
def _get_client(self) -> TelegramClient:
|
||||
"""创建一个新的 TelegramClient 实例"""
|
||||
return TelegramClient(self._build_session(), self.api_id, self.api_hash, proxy=self.proxy)
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
return ""
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
if rel:
|
||||
return [], 0
|
||||
|
||||
client = self._get_client()
|
||||
entries = []
|
||||
try:
|
||||
await client.connect()
|
||||
messages = await client.get_messages(self.chat_id, limit=200)
|
||||
for message in messages:
|
||||
if not message:
|
||||
continue
|
||||
|
||||
media = message.document or message.video or message.photo
|
||||
if not media:
|
||||
continue
|
||||
|
||||
file_meta = message.file
|
||||
if not file_meta:
|
||||
continue
|
||||
|
||||
filename = file_meta.name
|
||||
if not filename:
|
||||
if message.text and '.' in message.text and len(message.text) < 256 and '\n' not in message.text:
|
||||
filename = message.text
|
||||
else:
|
||||
filename = f"unknown_{message.id}"
|
||||
|
||||
size = file_meta.size
|
||||
if size is None:
|
||||
# 兼容缺失 size 的情况
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
size = 0
|
||||
|
||||
entries.append({
|
||||
"name": f"{message.id}_{filename}",
|
||||
"is_dir": False,
|
||||
"size": size,
|
||||
"mtime": int(message.date.timestamp()),
|
||||
"type": "file",
|
||||
})
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
# 排序
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
entries.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(entries)
|
||||
|
||||
# 分页
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
page_entries = entries[start_idx:end_idx]
|
||||
|
||||
return page_entries, total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
raise FileNotFoundError(f"无效的文件路径格式: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
if not message or not (message.document or message.video or message.photo):
|
||||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||||
|
||||
file_bytes = await client.download_media(message, file=bytes)
|
||||
return file_bytes
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
"""将字节数据作为文件上传"""
|
||||
client = self._get_client()
|
||||
file_like = io.BytesIO(data)
|
||||
file_like.name = os.path.basename(rel) or "file"
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
sent = await client.send_file(self.chat_id, file_like, caption=file_like.name)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
if message:
|
||||
stored_name = file_like.name
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
return {"rel": actual_rel, "size": len(data)}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
|
||||
client = self._get_client()
|
||||
name = filename or os.path.basename(rel) or "file"
|
||||
file_like = _NamedFile(file_obj, name)
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
sent = await client.send_file(
|
||||
self.chat_id,
|
||||
file_like,
|
||||
caption=file_like.name,
|
||||
file_size=file_size,
|
||||
mime_type=content_type,
|
||||
)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
size = file_size or 0
|
||||
if message:
|
||||
stored_name = file_like.name
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
if file_meta and getattr(file_meta, "size", None):
|
||||
size = int(file_meta.size)
|
||||
return {"rel": actual_rel, "size": size}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
"""以流式方式上传文件"""
|
||||
client = self._get_client()
|
||||
filename = os.path.basename(rel) or "file"
|
||||
import tempfile
|
||||
suffix = os.path.splitext(filename)[1]
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
||||
temp_path = tf.name
|
||||
|
||||
total_size = 0
|
||||
try:
|
||||
with open(temp_path, "wb") as f:
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
await client.connect()
|
||||
sent = await client.send_file(self.chat_id, temp_path, caption=filename)
|
||||
message = sent[0] if isinstance(sent, list) and sent else sent
|
||||
actual_rel = rel
|
||||
if message:
|
||||
stored_name = filename
|
||||
file_meta = getattr(message, "file", None)
|
||||
if file_meta and getattr(file_meta, "name", None):
|
||||
stored_name = file_meta.name
|
||||
if getattr(message, "id", None) is not None:
|
||||
actual_rel = f"{message.id}_{stored_name}"
|
||||
|
||||
finally:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
return {"rel": actual_rel, "size": total_size}
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
raise NotImplementedError("Telegram 适配器不支持创建目录。")
|
||||
|
||||
async def get_thumbnail(self, root: str, rel: str, size: str = "medium"):
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
doc = message.document or message.video
|
||||
thumbs = None
|
||||
if doc and getattr(doc, "thumbs", None):
|
||||
thumbs = list(doc.thumbs or [])
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
thumbs = list(message.photo.sizes or [])
|
||||
|
||||
thumb = self._pick_photo_thumb(thumbs)
|
||||
if not thumb:
|
||||
return None
|
||||
|
||||
result = await client.download_media(message, bytes, thumb=thumb)
|
||||
if isinstance(result, (bytes, bytearray)):
|
||||
return bytes(result)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
"""删除一个文件 (即一条消息)"""
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
raise FileNotFoundError(f"无效的文件路径格式,无法解析消息ID: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.connect()
|
||||
result = await client.delete_messages(self.chat_id, [message_id])
|
||||
if not result or not result[0].pts:
|
||||
raise FileNotFoundError(f"在 {self.chat_id} 中删除消息 {message_id} 失败,可能消息不存在或无权限")
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
raise NotImplementedError("Telegram 适配器不支持移动。")
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
raise NotImplementedError("Telegram 适配器不支持重命名。")
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
raise NotImplementedError("Telegram 适配器不支持复制。")
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import HTTPException
|
||||
|
||||
try:
|
||||
message_id_str, _ = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
raise HTTPException(status_code=400, detail=f"无效的文件路径格式: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
lock = _get_session_lock(self.session_string)
|
||||
await lock.acquire()
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
media = message.document or message.video or message.photo
|
||||
if not message or not media:
|
||||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||||
|
||||
file_meta = message.file
|
||||
file_size = file_meta.size if file_meta and file_meta.size is not None else None
|
||||
if file_size is None:
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
file_size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
file_size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
file_size = 0
|
||||
|
||||
mime_type = None
|
||||
if file_meta and getattr(file_meta, "mime_type", None):
|
||||
mime_type = file_meta.mime_type
|
||||
if not mime_type:
|
||||
if hasattr(media, "mime_type") and media.mime_type:
|
||||
mime_type = media.mime_type
|
||||
elif message.photo:
|
||||
mime_type = "image/jpeg"
|
||||
else:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
status = 200
|
||||
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": mime_type,
|
||||
}
|
||||
|
||||
if range_header:
|
||||
try:
|
||||
range_val = range_header.strip().partition("=")[2]
|
||||
s, _, e = range_val.partition("-")
|
||||
start = int(s) if s else 0
|
||||
end = int(e) if e else file_size - 1
|
||||
if start >= file_size or end >= file_size or start > end:
|
||||
raise HTTPException(status_code=416, detail="Requested Range Not Satisfiable")
|
||||
status = 206
|
||||
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid Range header")
|
||||
|
||||
async def iterator():
|
||||
try:
|
||||
limit = end - start + 1
|
||||
downloaded = 0
|
||||
|
||||
async for chunk in client.iter_download(media, offset=start):
|
||||
if downloaded + len(chunk) > limit:
|
||||
yield chunk[:limit - downloaded]
|
||||
break
|
||||
yield chunk
|
||||
downloaded += len(chunk)
|
||||
if downloaded >= limit:
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
return StreamingResponse(iterator(), status_code=status, headers=headers)
|
||||
|
||||
except HTTPException:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
lock.release()
|
||||
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
try:
|
||||
message_id_str, filename = rel.split('_', 1)
|
||||
message_id = int(message_id_str)
|
||||
except (ValueError, IndexError):
|
||||
raise FileNotFoundError(f"无效的文件路径格式: {rel}")
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.connect()
|
||||
message = await client.get_messages(self.chat_id, ids=message_id)
|
||||
media = message.document or message.video or message.photo
|
||||
if not message or not media:
|
||||
raise FileNotFoundError(f"在频道 {self.chat_id} 中未找到消息ID为 {message_id} 的文件")
|
||||
|
||||
file_meta = message.file
|
||||
size = file_meta.size if file_meta and file_meta.size is not None else None
|
||||
if size is None:
|
||||
if hasattr(media, "size") and media.size is not None:
|
||||
size = media.size
|
||||
elif message.photo and getattr(message.photo, "sizes", None):
|
||||
photo_size = message.photo.sizes[-1]
|
||||
size = getattr(photo_size, "size", 0) or 0
|
||||
else:
|
||||
size = 0
|
||||
|
||||
return {
|
||||
"name": rel,
|
||||
"is_dir": False,
|
||||
"size": size,
|
||||
"mtime": int(message.date.timestamp()),
|
||||
"type": "file",
|
||||
}
|
||||
finally:
|
||||
if client.is_connected():
|
||||
await client.disconnect()
|
||||
|
||||
def ADAPTER_FACTORY(rec: StorageAdapter) -> TelegramAdapter:
|
||||
return TelegramAdapter(rec)
|
||||
492
domain/adapters/providers/webdav.py
Normal file
492
domain/adapters/providers/webdav.py
Normal file
@@ -0,0 +1,492 @@
|
||||
from typing import List, Dict, Optional, Tuple, AsyncIterator
|
||||
import httpx
|
||||
from urllib.parse import urljoin, quote
|
||||
from urllib.parse import urlparse, unquote
|
||||
import xml.etree.ElementTree as ET
|
||||
from models import StorageAdapter
|
||||
import mimetypes
|
||||
import logging
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
|
||||
NS = {"d": "DAV:"}
|
||||
|
||||
|
||||
class WebDAVAdapter:
|
||||
def __init__(self, record: StorageAdapter):
|
||||
self.record = record
|
||||
cfg = record.config
|
||||
self.base_url: str = cfg.get("base_url", "").rstrip('/') + '/'
|
||||
if not self.base_url.startswith("http"):
|
||||
raise ValueError("webdav requires base_url http/https")
|
||||
self.username = cfg.get("username")
|
||||
self.password = cfg.get("password")
|
||||
self.timeout = cfg.get("timeout", 15)
|
||||
|
||||
def get_effective_root(self, sub_path: str | None) -> str:
|
||||
base_url = self.record.config.get("base_url", "").rstrip('/') + '/'
|
||||
if sub_path:
|
||||
return base_url + sub_path.strip('/') + '/'
|
||||
return base_url
|
||||
|
||||
def _client(self):
|
||||
auth = (self.username, self.password) if self.username else None
|
||||
return httpx.AsyncClient(auth=auth, timeout=self.timeout, follow_redirects=True)
|
||||
|
||||
def _build_url(self, rel: str):
|
||||
rel = rel.strip('/')
|
||||
return self.base_url if not rel else urljoin(self.base_url, quote(rel) + ('/' if rel.endswith('/') else ''))
|
||||
|
||||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
|
||||
raw_url = self._build_url(rel)
|
||||
url = raw_url if raw_url.endswith('/') else raw_url + '/'
|
||||
depth = "1"
|
||||
body = """<?xml version="1.0" encoding="utf-8" ?>
|
||||
<d:propfind xmlns:d="DAV:">
|
||||
<d:prop>
|
||||
<d:displayname />
|
||||
<d:getcontentlength />
|
||||
<d:getlastmodified />
|
||||
<d:resourcetype />
|
||||
</d:prop>
|
||||
</d:propfind>"""
|
||||
async with self._client() as client:
|
||||
resp = await client.request("PROPFIND", url, data=body, headers={"Depth": depth})
|
||||
resp.raise_for_status()
|
||||
xml_text = resp.text
|
||||
root_el = ET.fromstring(xml_text)
|
||||
all_entries: List[Dict] = []
|
||||
parsed_req = urlparse(url)
|
||||
base_path = parsed_req.path
|
||||
if not base_path.endswith('/'):
|
||||
base_path += '/'
|
||||
seen = set()
|
||||
for resp_el in root_el.findall("d:response", NS):
|
||||
href_el = resp_el.find("d:href", NS)
|
||||
if href_el is None:
|
||||
continue
|
||||
href = (href_el.text or "")
|
||||
parsed_href = urlparse(href)
|
||||
href_path = parsed_href.path or ""
|
||||
if not href_path.startswith(base_path):
|
||||
continue
|
||||
rel_path = href_path[len(base_path):].strip('/')
|
||||
if rel_path == "":
|
||||
continue
|
||||
name = unquote(rel_path.split('/')[0]).rstrip('/')
|
||||
if not name or name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
propstat = resp_el.find("d:propstat", NS)
|
||||
if propstat is None:
|
||||
continue
|
||||
prop = propstat.find("d:prop", NS)
|
||||
if prop is None:
|
||||
continue
|
||||
size_el = prop.find("d:getcontentlength", NS)
|
||||
lm_el = prop.find("d:getlastmodified", NS)
|
||||
rt_el = prop.find("d:resourcetype", NS)
|
||||
is_dir = rt_el.find(
|
||||
"d:collection", NS) is not None if rt_el is not None else href_path.endswith('/')
|
||||
size = int(
|
||||
size_el.text) if size_el is not None and size_el.text and size_el.text.isdigit() else 0
|
||||
|
||||
from email.utils import parsedate_to_datetime
|
||||
mtime = 0
|
||||
if lm_el is not None and lm_el.text:
|
||||
try:
|
||||
mtime = int(parsedate_to_datetime(lm_el.text).timestamp())
|
||||
except Exception:
|
||||
mtime = 0
|
||||
|
||||
all_entries.append({
|
||||
"name": name,
|
||||
"is_dir": is_dir,
|
||||
"size": 0 if is_dir else size,
|
||||
"mtime": mtime,
|
||||
"type": "dir" if is_dir else "file",
|
||||
})
|
||||
|
||||
# 排序所有条目
|
||||
reverse = sort_order.lower() == "desc"
|
||||
def get_sort_key(item):
|
||||
key = (not item["is_dir"],)
|
||||
sort_field = sort_by.lower()
|
||||
if sort_field == "name":
|
||||
key += (item["name"].lower(),)
|
||||
elif sort_field == "size":
|
||||
key += (item["size"],)
|
||||
elif sort_field == "mtime":
|
||||
key += (item["mtime"],)
|
||||
else:
|
||||
key += (item["name"].lower(),)
|
||||
return key
|
||||
all_entries.sort(key=get_sort_key, reverse=reverse)
|
||||
|
||||
total_count = len(all_entries)
|
||||
|
||||
# 应用分页
|
||||
start_idx = (page_num - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
page_entries = all_entries[start_idx:end_idx]
|
||||
|
||||
return page_entries, total_count
|
||||
|
||||
async def read_file(self, root: str, rel: str) -> bytes:
|
||||
url = self._build_url(rel)
|
||||
async with self._client() as client:
|
||||
resp = await client.get(url)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def write_file(self, root: str, rel: str, data: bytes):
|
||||
url = self._build_url(rel)
|
||||
async with self._client() as client:
|
||||
resp = await client.put(url, content=data)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def mkdir(self, root: str, rel: str):
|
||||
url = self._build_url(rel.rstrip('/') + '/')
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MKCOL", url)
|
||||
if resp.status_code not in (201, 405):
|
||||
resp.raise_for_status()
|
||||
|
||||
async def delete(self, root: str, rel: str):
|
||||
url = self._build_url(rel)
|
||||
async with self._client() as client:
|
||||
resp = await client.delete(url)
|
||||
if resp.status_code not in (204, 200, 404):
|
||||
resp.raise_for_status()
|
||||
|
||||
async def move(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_url = self._build_url(src_rel)
|
||||
dst_url = self._build_url(dst_rel)
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
|
||||
resp.raise_for_status()
|
||||
|
||||
async def rename(self, root: str, src_rel: str, dst_rel: str):
|
||||
src_url = self._build_url(src_rel)
|
||||
dst_url = self._build_url(dst_rel)
|
||||
async with self._client() as client:
|
||||
resp = await client.request("MOVE", src_url, headers={"Destination": dst_url})
|
||||
resp.raise_for_status()
|
||||
|
||||
async def get_file_size(self, root: str, rel: str) -> int:
|
||||
"""获取文件大小"""
|
||||
url = self._build_url(rel)
|
||||
async with self._client() as client:
|
||||
# 使用HEAD请求获取文件信息
|
||||
resp = await client.head(url)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
|
||||
content_length = resp.headers.get('content-length')
|
||||
if content_length:
|
||||
return int(content_length)
|
||||
|
||||
# 如果HEAD不返回content-length,尝试PROPFIND
|
||||
body = """<?xml version="1.0" encoding="utf-8" ?>
|
||||
<d:propfind xmlns:d="DAV:">
|
||||
<d:prop>
|
||||
<d:getcontentlength />
|
||||
</d:prop>
|
||||
</d:propfind>"""
|
||||
resp = await client.request("PROPFIND", url, data=body, headers={"Depth": "0"})
|
||||
resp.raise_for_status()
|
||||
|
||||
root_el = ET.fromstring(resp.text)
|
||||
for resp_el in root_el.findall("d:response", NS):
|
||||
propstat = resp_el.find("d:propstat", NS)
|
||||
if propstat is None:
|
||||
continue
|
||||
prop = propstat.find("d:prop", NS)
|
||||
if prop is None:
|
||||
continue
|
||||
size_el = prop.find("d:getcontentlength", NS)
|
||||
if size_el is not None and size_el.text and size_el.text.isdigit():
|
||||
return int(size_el.text)
|
||||
|
||||
return 0
|
||||
|
||||
async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
|
||||
"""读取文件的指定范围"""
|
||||
url = self._build_url(rel)
|
||||
|
||||
# 构建Range头
|
||||
if end is None:
|
||||
range_header = f"bytes={start}-"
|
||||
else:
|
||||
range_header = f"bytes={start}-{end}"
|
||||
|
||||
async with self._client() as client:
|
||||
resp = await client.get(url, headers={"Range": range_header})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
if resp.status_code not in (200, 206): # 206是Partial Content
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def stream_file(self, root: str, rel: str, range_header: str | None):
|
||||
url = self._build_url(rel)
|
||||
mime, _ = mimetypes.guess_type(rel)
|
||||
content_type = mime or "application/octet-stream"
|
||||
logger = logging.getLogger(__name__)
|
||||
timeout = self.timeout
|
||||
auth = (self.username, self.password) if self.username else None
|
||||
|
||||
client_start = 0
|
||||
client_end = None
|
||||
status_code = 200
|
||||
if range_header and range_header.startswith("bytes="):
|
||||
status_code = 206
|
||||
part = range_header.removeprefix("bytes=")
|
||||
s, e = part.split("-", 1)
|
||||
if s.strip():
|
||||
client_start = int(s)
|
||||
if e.strip():
|
||||
client_end = int(e)
|
||||
|
||||
total_size = None
|
||||
accept_ranges = False
|
||||
async with httpx.AsyncClient(timeout=timeout, auth=auth, follow_redirects=True) as client:
|
||||
try:
|
||||
head_resp = await client.head(url)
|
||||
if head_resp.status_code == 404:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
if head_resp.status_code == 200:
|
||||
cl = head_resp.headers.get("Content-Length")
|
||||
if cl and cl.isdigit():
|
||||
total_size = int(cl)
|
||||
ar = head_resp.headers.get("Accept-Ranges", "").lower()
|
||||
accept_ranges = "bytes" in ar
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("HEAD failed %s err=%s", url, e)
|
||||
if total_size is None and (client_end is None):
|
||||
try:
|
||||
probe_req = client.build_request("GET", url, headers={"Range": "bytes=0-0"})
|
||||
probe_resp = await client.send(probe_req, stream=True)
|
||||
if probe_resp.status_code in (200, 206):
|
||||
cr = probe_resp.headers.get("Content-Range")
|
||||
if cr and "/" in cr:
|
||||
try:
|
||||
total_size = int(cr.rsplit("/", 1)[1])
|
||||
except Exception:
|
||||
pass
|
||||
await probe_resp.aclose()
|
||||
except Exception as e:
|
||||
logger.debug("Probe 0-0 failed %s err=%s", url, e)
|
||||
|
||||
if total_size is not None and client_end is None:
|
||||
client_end = total_size - 1
|
||||
if client_end is not None and client_end < client_start:
|
||||
raise HTTPException(416, detail="Requested Range Not Satisfiable")
|
||||
|
||||
# 若客户端未请求范围且上游不支持 Range,直接透传
|
||||
if status_code == 200 and (range_header is None) and not accept_ranges:
|
||||
async with httpx.AsyncClient(timeout=timeout, auth=auth, follow_redirects=True) as client:
|
||||
req = client.build_request("GET", url)
|
||||
resp = await client.send(req, stream=True)
|
||||
if resp.status_code == 404:
|
||||
await resp.aclose()
|
||||
raise HTTPException(404, detail="File not found")
|
||||
upstream_ct = resp.headers.get("Content-Type", content_type)
|
||||
|
||||
async def passthrough():
|
||||
try:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
finally:
|
||||
await resp.aclose()
|
||||
return StreamingResponse(passthrough(), status_code=resp.status_code,
|
||||
headers={"Accept-Ranges": "bytes",
|
||||
"X-VFS-Remote-Status": str(resp.status_code)},
|
||||
media_type=upstream_ct)
|
||||
|
||||
SEGMENT_SIZE = 5 * 1024 * 1024
|
||||
MAX_RETRY_PER_SEG = 3
|
||||
FIRST_BYTE_MAX_RETRY = 3
|
||||
|
||||
resp_headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Type": content_type,
|
||||
"X-VFS-Segmented": "1",
|
||||
}
|
||||
if status_code == 206 and total_size is not None:
|
||||
resp_headers["Content-Range"] = f"bytes {client_start}-{client_end}/{total_size}"
|
||||
|
||||
async def segmented_body():
|
||||
current = client_start
|
||||
first_byte_sent = False
|
||||
while True:
|
||||
if client_end is not None and current > client_end:
|
||||
break
|
||||
seg_start = current
|
||||
seg_end = (min(seg_start + SEGMENT_SIZE - 1, client_end)
|
||||
if client_end is not None else seg_start + SEGMENT_SIZE - 1)
|
||||
attempt = 0
|
||||
ok = False
|
||||
while attempt < MAX_RETRY_PER_SEG and not ok:
|
||||
attempt += 1
|
||||
headers_req = {"Range": f"bytes={seg_start}-{seg_end}"}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, auth=auth, follow_redirects=True) as cseg:
|
||||
req = cseg.build_request("GET", url, headers=headers_req)
|
||||
rseg = await cseg.send(req, stream=True)
|
||||
if rseg.status_code in (200, 206):
|
||||
async for chunk in rseg.aiter_bytes():
|
||||
if chunk:
|
||||
first_byte_sent = True
|
||||
yield chunk
|
||||
await rseg.aclose()
|
||||
ok = True
|
||||
elif rseg.status_code == 404:
|
||||
await rseg.aclose()
|
||||
if not first_byte_sent:
|
||||
raise HTTPException(404, detail="File not found")
|
||||
return
|
||||
else:
|
||||
await rseg.aclose()
|
||||
logger.warning("Segment unexpected status %s %s-%s %s", rel, seg_start, seg_end, rseg.status_code)
|
||||
if not ok:
|
||||
continue
|
||||
except (httpx.ReadError, httpx.HTTPError, httpx.StreamError) as e:
|
||||
if not first_byte_sent and attempt >= FIRST_BYTE_MAX_RETRY:
|
||||
raise HTTPException(502, detail=f"Upstream error before first byte err={e}")
|
||||
logger.warning("Segment error %s %s-%s attempt=%d err=%s", rel, seg_start, seg_end, attempt, e)
|
||||
except Exception as e:
|
||||
if not first_byte_sent:
|
||||
raise
|
||||
logger.error("Segment unexpected %s %s-%s attempt=%d err=%s", rel, seg_start, seg_end, attempt, e)
|
||||
if not ok:
|
||||
logger.error("Abort streaming %s at %s-%s", rel, seg_start, seg_end)
|
||||
break
|
||||
current = seg_end + 1
|
||||
if client_end is None:
|
||||
continue
|
||||
if current > client_end:
|
||||
break
|
||||
|
||||
return StreamingResponse(segmented_body(), status_code=status_code, headers=resp_headers, media_type=content_type)
|
||||
|
||||
async def stat_file(self, root: str, rel: str):
|
||||
url = self._build_url(rel)
|
||||
async with self._client() as client:
|
||||
# PROPFIND 获取属性
|
||||
body = """<?xml version="1.0" encoding="utf-8" ?>
|
||||
<d:propfind xmlns:d="DAV:">
|
||||
<d:prop>
|
||||
<d:getcontentlength />
|
||||
<d:getlastmodified />
|
||||
<d:resourcetype />
|
||||
</d:prop>
|
||||
</d:propfind>"""
|
||||
resp = await client.request("PROPFIND", url, data=body, headers={"Depth": "0"})
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(rel)
|
||||
resp.raise_for_status()
|
||||
root_el = ET.fromstring(resp.text)
|
||||
info = {
|
||||
"name": rel.split("/")[-1],
|
||||
"is_dir": False,
|
||||
"size": None,
|
||||
"mtime": None,
|
||||
"type": "file",
|
||||
"path": url,
|
||||
}
|
||||
for resp_el in root_el.findall("d:response", NS):
|
||||
propstat = resp_el.find("d:propstat", NS)
|
||||
if propstat is None:
|
||||
continue
|
||||
prop = propstat.find("d:prop", NS)
|
||||
if prop is None:
|
||||
continue
|
||||
size_el = prop.find("d:getcontentlength", NS)
|
||||
lm_el = prop.find("d:getlastmodified", NS)
|
||||
rt_el = prop.find("d:resourcetype", NS)
|
||||
is_dir = rt_el.find("d:collection", NS) is not None if rt_el is not None else False
|
||||
info["is_dir"] = is_dir
|
||||
info["type"] = "dir" if is_dir else "file"
|
||||
if size_el is not None and size_el.text and size_el.text.isdigit():
|
||||
info["size"] = int(size_el.text)
|
||||
elif info["size"] is None:
|
||||
info["size"] = 0
|
||||
if lm_el is not None and lm_el.text:
|
||||
from email.utils import parsedate_to_datetime
|
||||
try:
|
||||
info["mtime"] = int(parsedate_to_datetime(lm_el.text).timestamp())
|
||||
except Exception:
|
||||
info["mtime"] = 0
|
||||
elif info["mtime"] is None:
|
||||
info["mtime"] = 0
|
||||
# exif信息
|
||||
exif = None
|
||||
if not info["is_dir"]:
|
||||
mime, _ = mimetypes.guess_type(info["name"])
|
||||
if mime and mime.startswith("image/"):
|
||||
try:
|
||||
resp_img = await client.get(url)
|
||||
if resp_img.status_code == 200:
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
img = Image.open(BytesIO(resp_img.content))
|
||||
exif_data = img._getexif()
|
||||
if exif_data:
|
||||
exif = {str(k): str(v) for k, v in exif_data.items()}
|
||||
except Exception:
|
||||
exif = None
|
||||
info["exif"] = exif
|
||||
return info
|
||||
|
||||
async def exists(self, root: str, rel: str) -> bool:
|
||||
url = self._build_url(rel)
|
||||
async with self._client() as client:
|
||||
try:
|
||||
r = await client.head(url)
|
||||
return r.status_code in (200, 204)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
|
||||
url = self._build_url(rel)
|
||||
async def agen():
|
||||
async for chunk in data_iter:
|
||||
if chunk:
|
||||
yield chunk
|
||||
async with self._client() as client:
|
||||
resp = await client.put(url, content=agen())
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
||||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
|
||||
src_url = self._build_url(src_rel)
|
||||
dst_url = self._build_url(dst_rel)
|
||||
headers = {
|
||||
"Destination": dst_url,
|
||||
"Overwrite": "T" if overwrite else "F"
|
||||
}
|
||||
async with self._client() as client:
|
||||
resp = await client.request("COPY", src_url, headers=headers)
|
||||
if resp.status_code == 412:
|
||||
raise FileExistsError(dst_rel)
|
||||
if resp.status_code == 404:
|
||||
raise FileNotFoundError(src_rel)
|
||||
resp.raise_for_status()
|
||||
|
||||
ADAPTER_TYPE = "webdav"
|
||||
CONFIG_SCHEMA = [
|
||||
{"key": "base_url", "label": "基础地址", "type": "string",
|
||||
"required": True, "placeholder": "https://example.com/dav/"},
|
||||
{"key": "username", "label": "用户名", "type": "string", "required": False},
|
||||
{"key": "password", "label": "密码", "type": "password", "required": False},
|
||||
{"key": "timeout",
|
||||
"label": "超时(秒)", "type": "number", "required": False, "default": 15},
|
||||
]
|
||||
def ADAPTER_FACTORY(rec): return WebDAVAdapter(rec)
|
||||
157
domain/adapters/registry.py
Normal file
157
domain/adapters/registry.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import inspect
|
||||
import pkgutil
|
||||
from importlib import import_module
|
||||
from typing import Callable, Dict
|
||||
|
||||
from models import StorageAdapter
|
||||
from .providers.base import BaseAdapter
|
||||
|
||||
AdapterFactory = Callable[[StorageAdapter], BaseAdapter]
|
||||
|
||||
TYPE_MAP: Dict[str, AdapterFactory] = {}
|
||||
CONFIG_SCHEMAS: Dict[str, list] = {}
|
||||
|
||||
|
||||
def normalize_adapter_type(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = str(value).strip().lower()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def discover_adapters():
|
||||
"""扫描 domain.adapters.providers 包, 自动注册适配器类型、工厂与配置 schema。"""
|
||||
from . import providers as adapters_pkg
|
||||
|
||||
TYPE_MAP.clear()
|
||||
CONFIG_SCHEMAS.clear()
|
||||
for modinfo in pkgutil.iter_modules(adapters_pkg.__path__):
|
||||
if modinfo.name.startswith("_"):
|
||||
continue
|
||||
full_name = f"{adapters_pkg.__name__}.{modinfo.name}"
|
||||
try:
|
||||
module = import_module(full_name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
adapter_types = getattr(module, "ADAPTER_TYPES", None)
|
||||
if isinstance(adapter_types, dict):
|
||||
default_schema = getattr(module, "CONFIG_SCHEMA", None)
|
||||
schema_map = getattr(module, "CONFIG_SCHEMA_MAP", None)
|
||||
if not isinstance(schema_map, dict):
|
||||
schema_map = None
|
||||
|
||||
for adapter_type, factory in adapter_types.items():
|
||||
normalized_type = normalize_adapter_type(adapter_type)
|
||||
if not normalized_type:
|
||||
continue
|
||||
if not callable(factory):
|
||||
continue
|
||||
TYPE_MAP[normalized_type] = factory
|
||||
|
||||
schema = schema_map.get(normalized_type) if schema_map else default_schema
|
||||
if isinstance(schema, list):
|
||||
CONFIG_SCHEMAS[normalized_type] = schema
|
||||
continue
|
||||
|
||||
adapter_type = normalize_adapter_type(getattr(module, "ADAPTER_TYPE", None))
|
||||
schema = getattr(module, "CONFIG_SCHEMA", None)
|
||||
factory = getattr(module, "ADAPTER_FACTORY", None)
|
||||
|
||||
if not adapter_type:
|
||||
continue
|
||||
|
||||
if factory is None:
|
||||
for attr in module.__dict__.values():
|
||||
if inspect.isclass(attr) and attr.__name__.endswith("Adapter"):
|
||||
def _mk(cls=attr):
|
||||
return lambda rec: cls(rec)
|
||||
factory = _mk()
|
||||
break
|
||||
if not callable(factory):
|
||||
continue
|
||||
|
||||
TYPE_MAP[adapter_type] = factory
|
||||
if isinstance(schema, list):
|
||||
CONFIG_SCHEMAS[adapter_type] = schema
|
||||
|
||||
|
||||
def get_config_schemas() -> Dict[str, list]:
|
||||
return CONFIG_SCHEMAS
|
||||
|
||||
|
||||
def get_config_schema(adapter_type: str):
|
||||
return CONFIG_SCHEMAS.get(adapter_type)
|
||||
|
||||
|
||||
class RuntimeRegistry:
|
||||
def __init__(self):
|
||||
self._instances: Dict[int, BaseAdapter] = {}
|
||||
|
||||
async def refresh(self):
|
||||
discover_adapters()
|
||||
self._instances.clear()
|
||||
adapters = await StorageAdapter.filter(enabled=True)
|
||||
for rec in adapters:
|
||||
normalized_type = normalize_adapter_type(rec.type)
|
||||
if not normalized_type:
|
||||
continue
|
||||
if normalized_type != rec.type:
|
||||
rec.type = normalized_type
|
||||
try:
|
||||
await rec.save(update_fields=["type"])
|
||||
except Exception:
|
||||
continue
|
||||
factory = TYPE_MAP.get(normalized_type)
|
||||
if not factory:
|
||||
continue
|
||||
try:
|
||||
self._instances[rec.id] = factory(rec)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
def get(self, adapter_id: int) -> BaseAdapter | None:
|
||||
return self._instances.get(adapter_id)
|
||||
|
||||
def snapshot(self) -> Dict[int, BaseAdapter]:
|
||||
return dict(self._instances)
|
||||
|
||||
def remove(self, adapter_id: int):
|
||||
"""从缓存中移除一个适配器实例"""
|
||||
if adapter_id in self._instances:
|
||||
del self._instances[adapter_id]
|
||||
|
||||
async def upsert(self, rec: StorageAdapter):
|
||||
"""新增或更新一个适配器实例"""
|
||||
if not rec.enabled:
|
||||
self.remove(rec.id)
|
||||
return
|
||||
|
||||
normalized_type = normalize_adapter_type(rec.type)
|
||||
if not normalized_type:
|
||||
self.remove(rec.id)
|
||||
return
|
||||
if normalized_type != rec.type:
|
||||
rec.type = normalized_type
|
||||
try:
|
||||
await rec.save(update_fields=["type"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
factory = TYPE_MAP.get(normalized_type)
|
||||
if not factory:
|
||||
discover_adapters()
|
||||
factory = TYPE_MAP.get(normalized_type)
|
||||
if not factory:
|
||||
return
|
||||
|
||||
try:
|
||||
instance = factory(rec)
|
||||
self._instances[rec.id] = instance
|
||||
except Exception:
|
||||
self.remove(rec.id)
|
||||
pass
|
||||
|
||||
|
||||
runtime_registry = RuntimeRegistry()
|
||||
discover_adapters()
|
||||
116
domain/adapters/service.py
Normal file
116
domain/adapters/service.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.auth import User
|
||||
from .registry import (
|
||||
get_config_schemas,
|
||||
normalize_adapter_type,
|
||||
runtime_registry,
|
||||
)
|
||||
from .types import AdapterCreate, AdapterOut
|
||||
from models import StorageAdapter
|
||||
|
||||
|
||||
class AdapterService:
|
||||
@classmethod
|
||||
def _validate_and_normalize_config(cls, adapter_type: str, cfg):
|
||||
schemas = get_config_schemas()
|
||||
adapter_type = normalize_adapter_type(adapter_type)
|
||||
if not adapter_type:
|
||||
raise HTTPException(400, detail="不支持的适配器类型")
|
||||
if not isinstance(cfg, dict):
|
||||
raise HTTPException(400, detail="config 必须是对象")
|
||||
schema = schemas.get(adapter_type)
|
||||
if not schema:
|
||||
raise HTTPException(400, detail=f"不支持的适配器类型: {adapter_type}")
|
||||
out = {}
|
||||
missing = []
|
||||
for f in schema:
|
||||
k = f["key"]
|
||||
if k in cfg and cfg[k] not in (None, ""):
|
||||
out[k] = cfg[k]
|
||||
elif "default" in f:
|
||||
out[k] = f["default"]
|
||||
elif f.get("required"):
|
||||
missing.append(k)
|
||||
if missing:
|
||||
raise HTTPException(400, detail="缺少必填配置字段: " + ", ".join(missing))
|
||||
if adapter_type in ("alist", "openlist"):
|
||||
username = out.get("username")
|
||||
password = out.get("password")
|
||||
if (username and not password) or (password and not username):
|
||||
raise HTTPException(400, detail="用户名和密码必须同时填写或同时留空")
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
async def create_adapter(cls, data: AdapterCreate, current_user: Optional[User]):
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
exists = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if exists:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
adapter_fields = {
|
||||
"name": data.name,
|
||||
"type": data.type,
|
||||
"config": cls._validate_and_normalize_config(data.type, data.config or {}),
|
||||
"enabled": data.enabled,
|
||||
"path": norm_path,
|
||||
"sub_path": data.sub_path,
|
||||
}
|
||||
|
||||
rec = await StorageAdapter.create(**adapter_fields)
|
||||
await runtime_registry.upsert(rec)
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def list_adapters(cls):
|
||||
adapters = await StorageAdapter.all()
|
||||
return [AdapterOut.model_validate(a) for a in adapters]
|
||||
|
||||
@classmethod
|
||||
async def available_adapter_types(cls):
|
||||
data = []
|
||||
for adapter_type, fields in get_config_schemas().items():
|
||||
data.append({
|
||||
"type": adapter_type,
|
||||
"config_schema": fields,
|
||||
})
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def get_adapter(cls, adapter_id: int):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def update_adapter(cls, adapter_id: int, data: AdapterCreate, current_user: Optional[User]):
|
||||
rec = await StorageAdapter.get_or_none(id=adapter_id)
|
||||
if not rec:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
|
||||
norm_path = AdapterCreate.normalize_mount_path(data.path)
|
||||
existing = await StorageAdapter.get_or_none(path=norm_path)
|
||||
if existing and existing.id != adapter_id:
|
||||
raise HTTPException(400, detail="Mount path already exists")
|
||||
|
||||
rec.name = data.name
|
||||
rec.type = data.type
|
||||
rec.config = cls._validate_and_normalize_config(data.type, data.config or {})
|
||||
rec.enabled = data.enabled
|
||||
rec.path = norm_path
|
||||
rec.sub_path = data.sub_path
|
||||
await rec.save()
|
||||
|
||||
await runtime_registry.upsert(rec)
|
||||
return AdapterOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def delete_adapter(cls, adapter_id: int, current_user: Optional[User]):
|
||||
deleted = await StorageAdapter.filter(id=adapter_id).delete()
|
||||
if not deleted:
|
||||
raise HTTPException(404, detail="Not found")
|
||||
runtime_registry.remove(adapter_id)
|
||||
return {"deleted": True}
|
||||
50
domain/adapters/types.py
Normal file
50
domain/adapters/types.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class AdapterBase(BaseModel):
|
||||
name: str
|
||||
type: str = Field(pattern=r"^[a-z0-9_]+$")
|
||||
config: Dict = Field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
path: str = None
|
||||
sub_path: Optional[str] = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_type(cls, v: str):
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("type required")
|
||||
normalized = v.strip().lower()
|
||||
if not normalized:
|
||||
raise ValueError("type required")
|
||||
if not re.fullmatch(r"[a-z0-9_]+", normalized):
|
||||
raise ValueError("type must be lowercase alphanumeric or underscore")
|
||||
return normalized
|
||||
|
||||
|
||||
class AdapterCreate(AdapterBase):
|
||||
@staticmethod
|
||||
def normalize_mount_path(p: str) -> str:
|
||||
p = p.strip()
|
||||
if not p.startswith('/'):
|
||||
p = '/' + p
|
||||
p = p.rstrip('/')
|
||||
return p or '/'
|
||||
|
||||
@field_validator("path")
|
||||
def _v_mount(cls, v: str):
|
||||
if not v:
|
||||
raise ValueError("mount_path required")
|
||||
return cls.normalize_mount_path(v)
|
||||
|
||||
|
||||
class AdapterOut(AdapterBase):
|
||||
id: int
|
||||
path: str = None
|
||||
sub_path: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
9
domain/agent/__init__.py
Normal file
9
domain/agent/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .service import AgentService
|
||||
from .types import AgentChatContext, AgentChatRequest, PendingToolCall
|
||||
|
||||
__all__ = [
|
||||
"AgentService",
|
||||
"AgentChatContext",
|
||||
"AgentChatRequest",
|
||||
"PendingToolCall",
|
||||
]
|
||||
38
domain/agent/api.py
Normal file
38
domain/agent/api.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import AgentService
|
||||
from .types import AgentChatRequest
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话", body_fields=["auto_execute"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = await AgentService.chat(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
@audit(action=AuditAction.CREATE, description="Agent 对话(SSE)", body_fields=["auto_execute"])
|
||||
async def chat_stream(
|
||||
request: Request,
|
||||
payload: AgentChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
return StreamingResponse(
|
||||
AgentService.chat_stream(payload, current_user),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
472
domain/agent/service.py
Normal file
472
domain/agent/service.py
Normal file
@@ -0,0 +1,472 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from domain.ai import AIProviderService, MissingModelError, chat_completion, chat_completion_stream
|
||||
from domain.auth import User
|
||||
from .tools import get_tool, openai_tools, tool_result_to_content
|
||||
from .types import AgentChatRequest, PendingToolCall
|
||||
|
||||
|
||||
def _normalize_path(p: Optional[str]) -> Optional[str]:
|
||||
if not p:
|
||||
return None
|
||||
s = str(p).strip()
|
||||
if not s:
|
||||
return None
|
||||
s = s.replace("\\", "/")
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _build_system_prompt(current_path: Optional[str]) -> str:
|
||||
lines = [
|
||||
"你是 Foxel 的 AI 助手。",
|
||||
"你可以通过工具对文件/目录进行查询、读写、移动、复制、删除,以及运行处理器(processor)。",
|
||||
"",
|
||||
"可用工具:",
|
||||
"- time:获取服务器当前时间(精确到秒,英文星期),支持 year/month/day/hour/minute/second 偏移。",
|
||||
"- web_fetch:抓取网页(HTTP 请求),支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS,返回状态/标题/正文/链接等。",
|
||||
"- vfs_list_dir:浏览目录(列出 entries + pagination)。",
|
||||
"- vfs_stat:查看文件/目录信息。",
|
||||
"- vfs_read_text:读取文本文件内容(不支持二进制)。",
|
||||
"- vfs_search:搜索文件(vector/filename)。",
|
||||
"- vfs_write_text:写入文本文件内容(覆盖)。",
|
||||
"- vfs_mkdir:创建目录。",
|
||||
"- vfs_delete:删除文件或目录。",
|
||||
"- vfs_move:移动路径。",
|
||||
"- vfs_copy:复制路径。",
|
||||
"- vfs_rename:重命名路径。",
|
||||
"- processors_list:获取可用处理器列表(含 type/name/config_schema/produces_file/supports_directory)。",
|
||||
"- processors_run:运行处理器处理文件或目录(会返回 task_id 或 task_ids)。",
|
||||
"",
|
||||
"规则:",
|
||||
"1) 读操作(web_fetch/vfs_list_dir/vfs_stat/vfs_read_text/vfs_search)可直接调用工具。",
|
||||
"2) 写/改/删操作(vfs_write_text/vfs_mkdir/vfs_delete/vfs_move/vfs_copy/vfs_rename/processors_run)默认需要用户确认;只有在开启自动执行时才应直接执行。",
|
||||
"3) 用户未给出明确路径时先追问;若提供了“当前文件管理目录”,可以基于它把相对描述补全为绝对路径(以 / 开头)。",
|
||||
"4) 修改文件内容:先读取(vfs_read_text)→给出改动点→确认后再写入(vfs_write_text)。",
|
||||
"5) processors_run 返回任务 id 后,说明任务已提交,可在任务队列查看进度。",
|
||||
"6) 回答语言跟随用户;用户用英文则用英文,用户用中文则用中文。回答尽量简洁。",
|
||||
]
|
||||
if current_path:
|
||||
lines.append("")
|
||||
lines.append(f"当前文件管理目录:{current_path}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _ensure_tool_call_ids(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list):
|
||||
return message
|
||||
|
||||
changed = False
|
||||
for idx, call in enumerate(tool_calls):
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = call.get("id")
|
||||
if isinstance(call_id, str) and call_id.strip():
|
||||
continue
|
||||
call["id"] = f"call_{idx}"
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
message["tool_calls"] = tool_calls
|
||||
return message
|
||||
|
||||
|
||||
def _extract_pending(tool_call: Dict[str, Any], requires_confirmation: bool) -> PendingToolCall:
|
||||
call_id = str(tool_call.get("id") or "")
|
||||
fn = tool_call.get("function") or {}
|
||||
name = str((fn.get("name") if isinstance(fn, dict) else None) or "")
|
||||
raw_args = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
arguments: Dict[str, Any] = {}
|
||||
if isinstance(raw_args, str) and raw_args.strip():
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
if isinstance(parsed, dict):
|
||||
arguments = parsed
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return PendingToolCall(
|
||||
id=call_id,
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
requires_confirmation=requires_confirmation,
|
||||
)
|
||||
|
||||
|
||||
def _find_last_assistant_tool_calls(messages: List[Dict[str, Any]]) -> Tuple[int, Dict[str, Any]]:
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[idx]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if isinstance(tool_calls, list) and tool_calls:
|
||||
return idx, msg
|
||||
raise HTTPException(status_code=400, detail="没有可确认的待执行操作")
|
||||
|
||||
|
||||
def _existing_tool_result_ids(messages: List[Dict[str, Any]]) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str) and tool_call_id.strip():
|
||||
ids.add(tool_call_id)
|
||||
return ids
|
||||
|
||||
|
||||
async def _choose_chat_ability() -> str:
|
||||
tools_model = await AIProviderService.get_default_model("tools")
|
||||
return "tools" if tools_model else "chat"
|
||||
|
||||
|
||||
def _sse(event: str, data: Any) -> bytes:
|
||||
payload = json.dumps(data, ensure_ascii=False, separators=(",", ":"))
|
||||
return f"event: {event}\ndata: {payload}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
def _format_exc(exc: BaseException) -> str:
|
||||
text = str(exc)
|
||||
return text if text else exc.__class__.__name__
|
||||
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
async def chat(cls, req: AgentChatRequest, user: Optional[User]) -> Dict[str, Any]:
|
||||
history: List[Dict[str, Any]] = list(req.messages or [])
|
||||
current_path = _normalize_path(req.context.current_path if req.context else None)
|
||||
|
||||
system_prompt = _build_system_prompt(current_path)
|
||||
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
|
||||
|
||||
new_messages: List[Dict[str, Any]] = []
|
||||
pending: List[PendingToolCall] = []
|
||||
|
||||
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
|
||||
if approved_ids or rejected_ids:
|
||||
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
|
||||
last_call_msg = _ensure_tool_call_ids(last_call_msg)
|
||||
tool_calls = last_call_msg.get("tool_calls") or []
|
||||
call_map: Dict[str, Dict[str, Any]] = {
|
||||
str(c.get("id")): c
|
||||
for c in tool_calls
|
||||
if isinstance(c, dict) and isinstance(c.get("id"), str)
|
||||
}
|
||||
|
||||
existing_ids = _existing_tool_result_ids(internal_messages)
|
||||
for call_id in approved_ids | rejected_ids:
|
||||
if call_id in existing_ids:
|
||||
continue
|
||||
tool_call = call_map.get(call_id)
|
||||
if not tool_call:
|
||||
continue
|
||||
fn = tool_call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if call_id in rejected_ids:
|
||||
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
|
||||
tools_schema = openai_tools()
|
||||
ability = await _choose_chat_ability()
|
||||
max_loops = 4
|
||||
|
||||
for _ in range(max_loops):
|
||||
try:
|
||||
assistant = await chat_completion(
|
||||
internal_messages,
|
||||
ability=ability,
|
||||
tools=tools_schema,
|
||||
tool_choice="auto",
|
||||
timeout=60.0,
|
||||
)
|
||||
except MissingModelError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求失败: {exc}") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求异常: {exc}") from exc
|
||||
|
||||
assistant = _ensure_tool_call_ids(assistant)
|
||||
internal_messages.append(assistant)
|
||||
new_messages.append(assistant)
|
||||
|
||||
tool_calls = assistant.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
break
|
||||
|
||||
pending = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = str(call.get("id") or "")
|
||||
fn = call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
continue
|
||||
|
||||
if spec.requires_confirmation and not req.auto_execute:
|
||||
pending.append(_extract_pending(call, True))
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
|
||||
if pending:
|
||||
break
|
||||
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
async def chat_stream(cls, req: AgentChatRequest, user: Optional[User]):
|
||||
history: List[Dict[str, Any]] = list(req.messages or [])
|
||||
current_path = _normalize_path(req.context.current_path if req.context else None)
|
||||
|
||||
system_prompt = _build_system_prompt(current_path)
|
||||
internal_messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + history
|
||||
|
||||
new_messages: List[Dict[str, Any]] = []
|
||||
pending: List[PendingToolCall] = []
|
||||
|
||||
approved_ids = {i for i in (req.approved_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
rejected_ids = {i for i in (req.rejected_tool_call_ids or []) if isinstance(i, str) and i.strip()}
|
||||
|
||||
try:
|
||||
if approved_ids or rejected_ids:
|
||||
_, last_call_msg = _find_last_assistant_tool_calls(internal_messages)
|
||||
last_call_msg = _ensure_tool_call_ids(last_call_msg)
|
||||
tool_calls = last_call_msg.get("tool_calls") or []
|
||||
call_map: Dict[str, Dict[str, Any]] = {
|
||||
str(c.get("id")): c
|
||||
for c in tool_calls
|
||||
if isinstance(c, dict) and isinstance(c.get("id"), str)
|
||||
}
|
||||
|
||||
existing_ids = _existing_tool_result_ids(internal_messages)
|
||||
for call_id in approved_ids | rejected_ids:
|
||||
if call_id in existing_ids:
|
||||
continue
|
||||
tool_call = call_map.get(call_id)
|
||||
if not tool_call:
|
||||
continue
|
||||
fn = tool_call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if call_id in rejected_ids:
|
||||
content = tool_result_to_content({"canceled": True, "reason": "user_rejected"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
|
||||
|
||||
tools_schema = openai_tools()
|
||||
ability = await _choose_chat_ability()
|
||||
max_loops = 4
|
||||
|
||||
for _ in range(max_loops):
|
||||
assistant_event_id = uuid.uuid4().hex
|
||||
yield _sse("assistant_start", {"id": assistant_event_id})
|
||||
|
||||
assistant_message: Dict[str, Any] | None = None
|
||||
try:
|
||||
async for event in chat_completion_stream(
|
||||
internal_messages,
|
||||
ability=ability,
|
||||
tools=tools_schema,
|
||||
tool_choice="auto",
|
||||
timeout=60.0,
|
||||
):
|
||||
if event.get("type") == "delta":
|
||||
delta = event.get("delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
yield _sse("assistant_delta", {"id": assistant_event_id, "delta": delta})
|
||||
elif event.get("type") == "message":
|
||||
msg = event.get("message")
|
||||
if isinstance(msg, dict):
|
||||
assistant_message = msg
|
||||
except MissingModelError as exc:
|
||||
raise HTTPException(status_code=400, detail=_format_exc(exc)) from exc
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求失败: {_format_exc(exc)}") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"对话请求异常: {_format_exc(exc)}") from exc
|
||||
|
||||
if not assistant_message:
|
||||
assistant_message = {"role": "assistant", "content": ""}
|
||||
|
||||
assistant_message = _ensure_tool_call_ids(assistant_message)
|
||||
internal_messages.append(assistant_message)
|
||||
new_messages.append(assistant_message)
|
||||
yield _sse("assistant_end", {"id": assistant_event_id, "message": assistant_message})
|
||||
|
||||
tool_calls = assistant_message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
break
|
||||
|
||||
pending = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
call_id = str(call.get("id") or "")
|
||||
fn = call.get("function") or {}
|
||||
name = fn.get("name") if isinstance(fn, dict) else None
|
||||
args_raw = fn.get("arguments") if isinstance(fn, dict) else None
|
||||
args: Dict[str, Any] = {}
|
||||
if isinstance(args_raw, str) and args_raw.strip():
|
||||
try:
|
||||
parsed = json.loads(args_raw)
|
||||
if isinstance(parsed, dict):
|
||||
args = parsed
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
spec = get_tool(str(name or ""))
|
||||
if not spec:
|
||||
content = tool_result_to_content({"error": f"unknown_tool: {name}"})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": str(name or ""), "message": tool_msg})
|
||||
continue
|
||||
|
||||
if spec.requires_confirmation and not req.auto_execute:
|
||||
pending.append(_extract_pending(call, True))
|
||||
continue
|
||||
|
||||
yield _sse("tool_start", {"tool_call_id": call_id, "name": spec.name})
|
||||
try:
|
||||
result = await spec.handler(args)
|
||||
content = tool_result_to_content(result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
content = tool_result_to_content({"error": str(exc)})
|
||||
tool_msg = {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
internal_messages.append(tool_msg)
|
||||
new_messages.append(tool_msg)
|
||||
yield _sse("tool_end", {"tool_call_id": call_id, "name": spec.name, "message": tool_msg})
|
||||
|
||||
if pending:
|
||||
yield _sse("pending", {"pending_tool_calls": [p.model_dump() for p in pending]})
|
||||
break
|
||||
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail
|
||||
content = detail if isinstance(detail, str) else str(detail)
|
||||
if not content.strip():
|
||||
content = f"请求失败({exc.status_code})"
|
||||
new_messages.append({"role": "assistant", "content": content})
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
new_messages.append({"role": "assistant", "content": f"服务端异常: {_format_exc(exc)}"})
|
||||
payload: Dict[str, Any] = {"messages": new_messages}
|
||||
if pending:
|
||||
payload["pending_tool_calls"] = [p.model_dump() for p in pending]
|
||||
yield _sse("done", payload)
|
||||
return
|
||||
37
domain/agent/tools/__init__.py
Normal file
37
domain/agent/tools/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import ToolSpec, tool_result_to_content
|
||||
from .processors import TOOLS as PROCESSOR_TOOLS
|
||||
from .time import TOOLS as TIME_TOOLS
|
||||
from .vfs import TOOLS as VFS_TOOLS
|
||||
from .web_fetch import TOOLS as WEB_FETCH_TOOLS
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {}
|
||||
for group in (TIME_TOOLS, WEB_FETCH_TOOLS, PROCESSOR_TOOLS, VFS_TOOLS):
|
||||
TOOLS.update(group)
|
||||
|
||||
|
||||
def get_tool(name: str) -> Optional[ToolSpec]:
|
||||
return TOOLS.get(name)
|
||||
|
||||
|
||||
def openai_tools() -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
for spec in TOOLS.values():
|
||||
out.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"parameters": spec.parameters,
|
||||
},
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolSpec",
|
||||
"get_tool",
|
||||
"openai_tools",
|
||||
"tool_result_to_content",
|
||||
]
|
||||
149
domain/agent/tools/base.py
Normal file
149
domain/agent/tools/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolSpec:
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
requires_confirmation: bool
|
||||
handler: Callable[[Dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
def _stringify_value(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(value)
|
||||
|
||||
|
||||
def _list_to_view_items(items: List[Any]) -> List[Any]:
|
||||
normalized: List[Any] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
normalized.append({str(k): _stringify_value(v) for k, v in item.items()})
|
||||
else:
|
||||
normalized.append(_stringify_value(item))
|
||||
return normalized
|
||||
|
||||
|
||||
def _dict_to_kv_items(data: Dict[str, Any]) -> List[Dict[str, str]]:
|
||||
return [{"key": str(k), "value": _stringify_value(v)} for k, v in data.items()]
|
||||
|
||||
|
||||
def _first_list_field(data: Dict[str, Any]) -> tuple[Optional[str], Optional[List[Any]]]:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, list):
|
||||
return str(key), value
|
||||
return None, None
|
||||
|
||||
|
||||
def _build_view(data: Any) -> Dict[str, Any]:
|
||||
if data is None:
|
||||
return {"type": "kv", "items": []}
|
||||
if isinstance(data, str):
|
||||
return {"type": "text", "text": data}
|
||||
if isinstance(data, list):
|
||||
return {"type": "list", "items": _list_to_view_items(data)}
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content")
|
||||
if isinstance(content, str):
|
||||
meta = {k: _stringify_value(v) for k, v in data.items() if k != "content"}
|
||||
view: Dict[str, Any] = {"type": "text", "text": content}
|
||||
if meta:
|
||||
view["meta"] = meta
|
||||
return view
|
||||
list_key, list_val = _first_list_field(data)
|
||||
if list_key and isinstance(list_val, list):
|
||||
meta = {k: _stringify_value(v) for k, v in data.items() if k != list_key}
|
||||
view = {"type": "list", "title": list_key, "items": _list_to_view_items(list_val)}
|
||||
if meta:
|
||||
view["meta"] = meta
|
||||
return view
|
||||
return {"type": "kv", "items": _dict_to_kv_items(data)}
|
||||
return {"type": "text", "text": _stringify_value(data)}
|
||||
|
||||
|
||||
def _build_summary(view: Dict[str, Any]) -> str:
|
||||
view_type = str(view.get("type") or "")
|
||||
if view_type == "text":
|
||||
text = view.get("text")
|
||||
size = len(text) if isinstance(text, str) else 0
|
||||
return f"chars: {size}" if size else "text"
|
||||
if view_type == "list":
|
||||
items = view.get("items")
|
||||
count = len(items) if isinstance(items, list) else 0
|
||||
title = str(view.get("title") or "items")
|
||||
return f"{title}: {count}"
|
||||
if view_type == "kv":
|
||||
items = view.get("items")
|
||||
count = len(items) if isinstance(items, list) else 0
|
||||
return f"fields: {count}"
|
||||
if view_type == "error":
|
||||
return str(view.get("message") or "error")
|
||||
return ""
|
||||
|
||||
|
||||
def _build_error_payload(code: str, message: str, detail: Any = None) -> Dict[str, Any]:
|
||||
summary = "Canceled" if code == "canceled" else message or "error"
|
||||
view = {"type": "error", "message": summary}
|
||||
payload: Dict[str, Any] = {
|
||||
"ok": False,
|
||||
"summary": summary,
|
||||
"view": view,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
if detail is not None:
|
||||
payload["error"]["detail"] = detail
|
||||
return payload
|
||||
|
||||
|
||||
def _normalize_tool_result(result: Any) -> Dict[str, Any]:
|
||||
if isinstance(result, dict) and "ok" in result:
|
||||
payload = dict(result)
|
||||
if payload.get("ok") is False:
|
||||
error = payload.get("error")
|
||||
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
|
||||
payload.setdefault("summary", message or "error")
|
||||
payload.setdefault("view", {"type": "error", "message": payload["summary"]})
|
||||
return payload
|
||||
data = payload.get("data")
|
||||
if payload.get("view") is None:
|
||||
payload["view"] = _build_view(data)
|
||||
if not payload.get("summary"):
|
||||
payload["summary"] = _build_summary(payload["view"])
|
||||
return payload
|
||||
|
||||
if isinstance(result, dict) and result.get("canceled"):
|
||||
reason = _stringify_value(result.get("reason") or "canceled")
|
||||
return _build_error_payload("canceled", reason, detail=result)
|
||||
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
error = result.get("error")
|
||||
message = _stringify_value(error.get("message") if isinstance(error, dict) else error)
|
||||
return _build_error_payload("error", message, detail=error)
|
||||
|
||||
view = _build_view(result)
|
||||
summary = _build_summary(view)
|
||||
return {"ok": True, "summary": summary, "view": view, "data": result}
|
||||
|
||||
|
||||
def tool_result_to_content(result: Any) -> str:
|
||||
payload = _normalize_tool_result(result)
|
||||
try:
|
||||
return json.dumps(payload, ensure_ascii=False, default=str)
|
||||
except TypeError:
|
||||
return json.dumps({"ok": False, "summary": "error", "view": {"type": "error", "message": "error"}}, ensure_ascii=False)
|
||||
96
domain/agent/tools/processors.py
Normal file
96
domain/agent/tools/processors.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from domain.processors import ProcessDirectoryRequest, ProcessRequest, ProcessorService
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
async def _processors_list(_: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {"processors": ProcessorService.list_processors()}
|
||||
|
||||
|
||||
async def _processors_run(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = str(args.get("path") or "")
|
||||
processor_type = str(args.get("processor_type") or "")
|
||||
config = args.get("config")
|
||||
if not isinstance(config, dict):
|
||||
config = {}
|
||||
|
||||
save_to = args.get("save_to")
|
||||
save_to = str(save_to) if isinstance(save_to, str) and save_to.strip() else None
|
||||
|
||||
max_depth = args.get("max_depth")
|
||||
max_depth_value: Optional[int] = None
|
||||
if max_depth is not None:
|
||||
try:
|
||||
max_depth_value = int(max_depth)
|
||||
except (TypeError, ValueError):
|
||||
max_depth_value = None
|
||||
|
||||
suffix = args.get("suffix")
|
||||
suffix_value = str(suffix) if isinstance(suffix, str) and suffix.strip() else None
|
||||
|
||||
overwrite_value = args.get("overwrite")
|
||||
overwrite = bool(overwrite_value) if overwrite_value is not None else None
|
||||
|
||||
is_dir = await VirtualFSService.path_is_directory(path)
|
||||
if is_dir and (max_depth_value is not None or suffix_value is not None):
|
||||
req = ProcessDirectoryRequest(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config,
|
||||
overwrite=True if overwrite is None else overwrite,
|
||||
max_depth=max_depth_value,
|
||||
suffix=suffix_value,
|
||||
)
|
||||
result = await ProcessorService.process_directory(req)
|
||||
return {"mode": "directory", **result}
|
||||
|
||||
req = ProcessRequest(
|
||||
path=path,
|
||||
processor_type=processor_type,
|
||||
config=config,
|
||||
save_to=save_to,
|
||||
overwrite=False if overwrite is None else overwrite,
|
||||
)
|
||||
result = await ProcessorService.process_file(req)
|
||||
return {"mode": "file", **result}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"processors_list": ToolSpec(
|
||||
name="processors_list",
|
||||
description="获取可用处理器列表(type/name/config_schema 等)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_processors_list,
|
||||
),
|
||||
"processors_run": ToolSpec(
|
||||
name="processors_run",
|
||||
description=(
|
||||
"运行处理器处理文件或目录。"
|
||||
" 对目录可选 max_depth/suffix;对文件可选 overwrite/save_to。"
|
||||
" 返回任务 id(去任务队列查看进度)。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件或目录路径(绝对路径,如 /foo/bar)"},
|
||||
"processor_type": {"type": "string", "description": "处理器类型(例如 image_watermark)"},
|
||||
"config": {"type": "object", "description": "处理器配置,按 processors_list 返回的 config_schema 填写"},
|
||||
"overwrite": {"type": "boolean", "description": "是否覆盖原文件/目录内文件"},
|
||||
"save_to": {"type": "string", "description": "保存到指定路径(仅文件模式,且 overwrite=false 时使用)"},
|
||||
"max_depth": {"type": "integer", "description": "目录遍历深度(仅目录模式)"},
|
||||
"suffix": {"type": "string", "description": "目录批处理时的输出后缀(仅 produces_file 且 overwrite=false)"},
|
||||
},
|
||||
"required": ["path", "processor_type"],
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_processors_run,
|
||||
),
|
||||
}
|
||||
92
domain/agent/tools/time.py
Normal file
92
domain/agent/tools/time.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import calendar
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
def _parse_offset(args: Dict[str, Any], key: str) -> int:
|
||||
value = args.get(key)
|
||||
if value is None:
|
||||
return 0
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _add_months(dt: datetime, months: int) -> datetime:
|
||||
if months == 0:
|
||||
return dt
|
||||
total = dt.year * 12 + (dt.month - 1) + months
|
||||
year = total // 12
|
||||
month = total % 12 + 1
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
day = min(dt.day, last_day)
|
||||
return dt.replace(year=year, month=month, day=day)
|
||||
|
||||
|
||||
async def _time(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
now = datetime.now()
|
||||
year_offset = _parse_offset(args, "year")
|
||||
month_offset = _parse_offset(args, "month")
|
||||
day_offset = _parse_offset(args, "day")
|
||||
hour_offset = _parse_offset(args, "hour")
|
||||
minute_offset = _parse_offset(args, "minute")
|
||||
second_offset = _parse_offset(args, "second")
|
||||
|
||||
dt = _add_months(now, year_offset * 12 + month_offset)
|
||||
dt = dt + timedelta(days=day_offset, hours=hour_offset, minutes=minute_offset, seconds=second_offset)
|
||||
|
||||
weekday_names = [
|
||||
"Monday",
|
||||
"Tuesday",
|
||||
"Wednesday",
|
||||
"Thursday",
|
||||
"Friday",
|
||||
"Saturday",
|
||||
"Sunday",
|
||||
]
|
||||
weekday = weekday_names[dt.weekday()]
|
||||
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"ok": True,
|
||||
"summary": f"{dt_str} · {weekday}",
|
||||
"data": {
|
||||
"datetime": dt_str,
|
||||
"weekday": weekday,
|
||||
"offset": {
|
||||
"year": year_offset,
|
||||
"month": month_offset,
|
||||
"day": day_offset,
|
||||
"hour": hour_offset,
|
||||
"minute": minute_offset,
|
||||
"second": second_offset,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"time": ToolSpec(
|
||||
name="time",
|
||||
description=(
|
||||
"获取服务器当前时间(精确到秒,含英文星期)。"
|
||||
" 支持 year/month/day/hour/minute/second 偏移(可为负数)。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"year": {"type": "integer", "description": "年偏移(可为负数)"},
|
||||
"month": {"type": "integer", "description": "月偏移(可为负数)"},
|
||||
"day": {"type": "integer", "description": "日偏移(可为负数)"},
|
||||
"hour": {"type": "integer", "description": "时偏移(可为负数)"},
|
||||
"minute": {"type": "integer", "description": "分偏移(可为负数)"},
|
||||
"second": {"type": "integer", "description": "秒偏移(可为负数)"},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_time,
|
||||
),
|
||||
}
|
||||
287
domain/agent/tools/vfs.py
Normal file
287
domain/agent/tools/vfs.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from domain.virtual_fs.search import VirtualFSSearchService
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
def _normalize_vfs_path(value: Any) -> str:
|
||||
s = str(value or "").strip().replace("\\", "/")
|
||||
if not s:
|
||||
return ""
|
||||
if not s.startswith("/"):
|
||||
s = "/" + s
|
||||
s = s.rstrip("/") or "/"
|
||||
return s
|
||||
|
||||
|
||||
def _require_vfs_path(value: Any, field: str) -> str:
|
||||
path = _normalize_vfs_path(value)
|
||||
if not path:
|
||||
raise ValueError(f"missing_{field}")
|
||||
return path
|
||||
|
||||
|
||||
async def _vfs_list_dir(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _normalize_vfs_path(args.get("path") or "/") or "/"
|
||||
page = int(args.get("page") or 1)
|
||||
page_size = int(args.get("page_size") or 50)
|
||||
sort_by = str(args.get("sort_by") or "name")
|
||||
sort_order = str(args.get("sort_order") or "asc")
|
||||
return await VirtualFSService.list_directory(path, page, page_size, sort_by, sort_order)
|
||||
|
||||
|
||||
async def _vfs_stat(args: Dict[str, Any]) -> Any:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.stat(path)
|
||||
|
||||
|
||||
async def _vfs_read_text(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
encoding = str(args.get("encoding") or "utf-8")
|
||||
max_chars = int(args.get("max_chars") or 8000)
|
||||
|
||||
data = await VirtualFSService.read_file(path)
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
try:
|
||||
text = bytes(data).decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
return {"error": "binary_or_invalid_text", "path": path}
|
||||
elif isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
text = str(data)
|
||||
|
||||
original_len = len(text)
|
||||
truncated = original_len > max_chars
|
||||
if truncated:
|
||||
text = text[:max_chars]
|
||||
return {
|
||||
"path": path,
|
||||
"encoding": encoding,
|
||||
"content": text,
|
||||
"truncated": truncated,
|
||||
"length": original_len,
|
||||
}
|
||||
|
||||
|
||||
async def _vfs_write_text(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
if path == "/":
|
||||
raise ValueError("invalid_path")
|
||||
encoding = str(args.get("encoding") or "utf-8")
|
||||
content = str(args.get("content") or "")
|
||||
data = content.encode(encoding)
|
||||
await VirtualFSService.write_file(path, data)
|
||||
return {"written": True, "path": path, "encoding": encoding, "bytes": len(data)}
|
||||
|
||||
|
||||
async def _vfs_mkdir(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.mkdir(path)
|
||||
|
||||
|
||||
async def _vfs_delete(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = _require_vfs_path(args.get("path"), "path")
|
||||
return await VirtualFSService.delete(path)
|
||||
|
||||
|
||||
async def _vfs_move(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.move(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_copy(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.copy(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_rename(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
src = _require_vfs_path(args.get("src"), "src")
|
||||
dst = _require_vfs_path(args.get("dst"), "dst")
|
||||
if src == "/" or dst == "/":
|
||||
raise ValueError("invalid_path")
|
||||
overwrite = bool(args.get("overwrite") or False)
|
||||
return await VirtualFSService.rename(src, dst, overwrite)
|
||||
|
||||
|
||||
async def _vfs_search(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
q = str(args.get("q") or "").strip()
|
||||
if not q:
|
||||
raise ValueError("missing_q")
|
||||
mode = str(args.get("mode") or "vector")
|
||||
top_k = int(args.get("top_k") or 10)
|
||||
page = int(args.get("page") or 1)
|
||||
page_size = int(args.get("page_size") or 10)
|
||||
return await VirtualFSSearchService.search(q, top_k, mode, page, page_size)
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"vfs_list_dir": ToolSpec(
|
||||
name="vfs_list_dir",
|
||||
description="浏览目录(列出 entries + pagination)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
|
||||
"page": {"type": "integer", "description": "页码(从 1 开始)"},
|
||||
"page_size": {"type": "integer", "description": "每页条数"},
|
||||
"sort_by": {"type": "string", "description": "排序字段:name/size/mtime"},
|
||||
"sort_order": {"type": "string", "description": "排序顺序:asc/desc"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_list_dir,
|
||||
),
|
||||
"vfs_stat": ToolSpec(
|
||||
name="vfs_stat",
|
||||
description="查看文件/目录信息(size/mtime/is_dir/has_thumbnail/vector_index 等)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar.txt)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_stat,
|
||||
),
|
||||
"vfs_read_text": ToolSpec(
|
||||
name="vfs_read_text",
|
||||
description="读取文本文件内容(解码失败视为二进制,返回 error)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
|
||||
"encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
|
||||
"max_chars": {"type": "integer", "description": "最多返回的字符数(默认 8000)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_read_text,
|
||||
),
|
||||
"vfs_write_text": ToolSpec(
|
||||
name="vfs_write_text",
|
||||
description="写入文本文件内容(会覆盖目标文件)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "文件路径(绝对路径,如 /foo/bar.md)"},
|
||||
"content": {"type": "string", "description": "要写入的文本内容"},
|
||||
"encoding": {"type": "string", "description": "文本编码(默认 utf-8)"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_write_text,
|
||||
),
|
||||
"vfs_mkdir": ToolSpec(
|
||||
name="vfs_mkdir",
|
||||
description="创建目录。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "目录路径(绝对路径,如 /foo/bar)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_mkdir,
|
||||
),
|
||||
"vfs_delete": ToolSpec(
|
||||
name="vfs_delete",
|
||||
description="删除文件或目录(由底层适配器决定是否递归)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "路径(绝对路径,如 /foo/bar 或 /foo/bar.txt)"},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_delete,
|
||||
),
|
||||
"vfs_move": ToolSpec(
|
||||
name="vfs_move",
|
||||
description="移动路径(可能进入任务队列)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_move,
|
||||
),
|
||||
"vfs_copy": ToolSpec(
|
||||
name="vfs_copy",
|
||||
description="复制路径(可能进入任务队列)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_copy,
|
||||
),
|
||||
"vfs_rename": ToolSpec(
|
||||
name="vfs_rename",
|
||||
description="重命名路径(本质是同目录 move)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"src": {"type": "string", "description": "源路径(绝对路径)"},
|
||||
"dst": {"type": "string", "description": "目标路径(绝对路径)"},
|
||||
"overwrite": {"type": "boolean", "description": "是否允许覆盖已存在目标(默认 false)"},
|
||||
},
|
||||
"required": ["src", "dst"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=True,
|
||||
handler=_vfs_rename,
|
||||
),
|
||||
"vfs_search": ToolSpec(
|
||||
name="vfs_search",
|
||||
description="搜索文件(mode=vector 或 filename)。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"q": {"type": "string", "description": "搜索关键词"},
|
||||
"mode": {"type": "string", "description": "搜索模式:vector/filename(默认 vector)"},
|
||||
"top_k": {"type": "integer", "description": "返回数量(vector 模式使用,默认 10)"},
|
||||
"page": {"type": "integer", "description": "页码(filename 模式使用,默认 1)"},
|
||||
"page_size": {"type": "integer", "description": "分页大小(filename 模式使用,默认 10)"},
|
||||
},
|
||||
"required": ["q"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_vfs_search,
|
||||
),
|
||||
}
|
||||
182
domain/agent/tools/web_fetch.py
Normal file
182
domain/agent/tools/web_fetch.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import ToolSpec
|
||||
|
||||
|
||||
class _HtmlTextExtractor(HTMLParser):
|
||||
def __init__(self, base_url: str):
|
||||
super().__init__()
|
||||
self.base_url = base_url
|
||||
self.links: List[str] = []
|
||||
self._link_set: set[str] = set()
|
||||
self._title_parts: List[str] = []
|
||||
self._text_parts: List[str] = []
|
||||
self._in_title = False
|
||||
self._skip_text = False
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: List[tuple[str, str | None]]):
|
||||
tag = tag.lower()
|
||||
if tag == "title":
|
||||
self._in_title = True
|
||||
if tag in ("script", "style", "noscript"):
|
||||
self._skip_text = True
|
||||
if tag != "a":
|
||||
return
|
||||
href = ""
|
||||
for key, value in attrs:
|
||||
if key.lower() == "href":
|
||||
href = str(value or "").strip()
|
||||
break
|
||||
if not href or href.startswith("#"):
|
||||
return
|
||||
lower = href.lower()
|
||||
if lower.startswith(("javascript:", "mailto:", "tel:", "data:")):
|
||||
return
|
||||
resolved = urljoin(self.base_url, href)
|
||||
if resolved in self._link_set:
|
||||
return
|
||||
self._link_set.add(resolved)
|
||||
self.links.append(resolved)
|
||||
|
||||
def handle_endtag(self, tag: str):
|
||||
tag = tag.lower()
|
||||
if tag == "title":
|
||||
self._in_title = False
|
||||
if tag in ("script", "style", "noscript"):
|
||||
self._skip_text = False
|
||||
|
||||
def handle_data(self, data: str):
|
||||
if not data:
|
||||
return
|
||||
if self._in_title:
|
||||
self._title_parts.append(data)
|
||||
if self._skip_text:
|
||||
return
|
||||
if data.strip():
|
||||
self._text_parts.append(data)
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return " ".join(part.strip() for part in self._title_parts if part and part.strip()).strip()
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
if not self._text_parts:
|
||||
return ""
|
||||
text = " ".join(part.strip() for part in self._text_parts if part and part.strip())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
async def _web_fetch(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
url = str(args.get("url") or "").strip()
|
||||
if not url:
|
||||
raise ValueError("missing_url")
|
||||
|
||||
method = str(args.get("method") or "GET").upper()
|
||||
allowed_methods = {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
|
||||
if method not in allowed_methods:
|
||||
raise ValueError("invalid_method")
|
||||
|
||||
headers_raw = args.get("headers")
|
||||
headers = {str(k): str(v) for k, v in headers_raw.items() if v is not None} if isinstance(headers_raw, dict) else None
|
||||
params_raw = args.get("params")
|
||||
params = {str(k): str(v) for k, v in params_raw.items() if v is not None} if isinstance(params_raw, dict) else None
|
||||
json_body = args.get("json") if "json" in args else None
|
||||
body = args.get("body")
|
||||
|
||||
request_kwargs: Dict[str, Any] = {}
|
||||
if headers:
|
||||
request_kwargs["headers"] = headers
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if json_body is not None:
|
||||
request_kwargs["json"] = json_body
|
||||
elif body is not None:
|
||||
request_kwargs["content"] = str(body)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0, follow_redirects=True) as client:
|
||||
resp = await client.request(method, url, **request_kwargs)
|
||||
|
||||
content_type = resp.headers.get("content-type") or ""
|
||||
text = resp.text or ""
|
||||
is_html = "html" in content_type.lower()
|
||||
if not is_html:
|
||||
probe = text.lstrip()[:200].lower()
|
||||
if "<html" in probe or "<!doctype html" in probe:
|
||||
is_html = True
|
||||
|
||||
html = ""
|
||||
title = ""
|
||||
links: List[str] = []
|
||||
extracted_text = text
|
||||
|
||||
if is_html and text:
|
||||
html = text
|
||||
parser = _HtmlTextExtractor(str(resp.url))
|
||||
parser.feed(text)
|
||||
title = parser.title
|
||||
links = parser.links
|
||||
extracted_text = parser.text
|
||||
|
||||
data = {
|
||||
"url": url,
|
||||
"method": method,
|
||||
"final_url": str(resp.url),
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"title": title,
|
||||
"html": html,
|
||||
"text": extracted_text,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
summary_parts = [method, str(resp.status_code)]
|
||||
if title:
|
||||
summary_parts.append(title)
|
||||
summary_parts.append(f"{len(links)} links")
|
||||
summary = " · ".join(summary_parts)
|
||||
|
||||
view = {
|
||||
"type": "text",
|
||||
"text": extracted_text,
|
||||
"meta": {
|
||||
"url": url,
|
||||
"final_url": str(resp.url),
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"title": title,
|
||||
"method": method,
|
||||
"links": len(links),
|
||||
},
|
||||
}
|
||||
return {"ok": True, "summary": summary, "view": view, "data": data}
|
||||
|
||||
|
||||
TOOLS: Dict[str, ToolSpec] = {
|
||||
"web_fetch": ToolSpec(
|
||||
name="web_fetch",
|
||||
description=(
|
||||
"抓取网页内容,返回状态、标题、正文、HTML、链接等信息。"
|
||||
" 支持 GET/POST/PUT/PATCH/DELETE/HEAD/OPTIONS。"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "目标 URL"},
|
||||
"method": {"type": "string", "description": "请求方法(默认 GET)"},
|
||||
"headers": {"type": "object", "description": "请求头", "additionalProperties": {"type": "string"}},
|
||||
"params": {"type": "object", "description": "查询参数", "additionalProperties": {"type": "string"}},
|
||||
"json": {"type": "object", "description": "JSON 请求体"},
|
||||
"body": {"type": "string", "description": "原始请求体"},
|
||||
},
|
||||
"required": ["url"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
requires_confirmation=False,
|
||||
handler=_web_fetch,
|
||||
),
|
||||
}
|
||||
23
domain/agent/types.py
Normal file
23
domain/agent/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentChatContext(BaseModel):
|
||||
current_path: Optional[str] = None
|
||||
|
||||
|
||||
class AgentChatRequest(BaseModel):
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
auto_execute: bool = False
|
||||
approved_tool_call_ids: List[str] = Field(default_factory=list)
|
||||
rejected_tool_call_ids: List[str] = Field(default_factory=list)
|
||||
context: Optional[AgentChatContext] = None
|
||||
|
||||
|
||||
class PendingToolCall(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict)
|
||||
requires_confirmation: bool = True
|
||||
|
||||
67
domain/ai/__init__.py
Normal file
67
domain/ai/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from .inference import (
|
||||
MissingModelError,
|
||||
chat_completion,
|
||||
chat_completion_stream,
|
||||
describe_image_base64,
|
||||
get_text_embedding,
|
||||
provider_service,
|
||||
rerank_texts,
|
||||
)
|
||||
from .service import (
|
||||
AIProviderService,
|
||||
FILE_COLLECTION_NAME,
|
||||
VECTOR_COLLECTION_NAME,
|
||||
DEFAULT_VECTOR_DIMENSION,
|
||||
VectorDBConfigManager,
|
||||
VectorDBService,
|
||||
)
|
||||
from .types import (
|
||||
ABILITIES,
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
normalize_capabilities,
|
||||
)
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
MilvusLiteProvider,
|
||||
MilvusServerProvider,
|
||||
QdrantProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MissingModelError",
|
||||
"chat_completion",
|
||||
"chat_completion_stream",
|
||||
"describe_image_base64",
|
||||
"get_text_embedding",
|
||||
"provider_service",
|
||||
"rerank_texts",
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"VECTOR_COLLECTION_NAME",
|
||||
"FILE_COLLECTION_NAME",
|
||||
"BaseVectorProvider",
|
||||
"MilvusLiteProvider",
|
||||
"MilvusServerProvider",
|
||||
"QdrantProvider",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"ABILITIES",
|
||||
"normalize_capabilities",
|
||||
"AIDefaultsUpdate",
|
||||
"AIModelCreate",
|
||||
"AIModelUpdate",
|
||||
"AIProviderCreate",
|
||||
"AIProviderUpdate",
|
||||
"VectorDBConfigPayload",
|
||||
]
|
||||
304
domain/ai/api.py
Normal file
304
domain/ai/api.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from typing import Annotated, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import AIProviderService, VectorDBConfigManager, VectorDBService
|
||||
from .types import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
VectorDBConfigPayload,
|
||||
)
|
||||
from .vector_providers import get_provider_class, get_provider_entry, list_providers
|
||||
|
||||
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
|
||||
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取 AI 提供商列表")
|
||||
@router_ai.get("/providers")
|
||||
async def list_providers_endpoint(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
providers = await AIProviderService.list_providers()
|
||||
return success({"providers": providers})
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建 AI 提供商",
|
||||
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.post("/providers")
|
||||
async def create_provider(
|
||||
request: Request,
|
||||
payload: AIProviderCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
provider = await AIProviderService.create_provider(payload.dict())
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取 AI 提供商详情")
|
||||
@router_ai.get("/providers/{provider_id}")
|
||||
async def get_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
provider = await AIProviderService.get_provider(provider_id, with_models=True)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新 AI 提供商",
|
||||
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
|
||||
redact_fields=["api_key"],
|
||||
)
|
||||
@router_ai.put("/providers/{provider_id}")
|
||||
async def update_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIProviderUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
provider = await AIProviderService.update_provider(provider_id, data)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@audit(action=AuditAction.DELETE, description="删除 AI 提供商")
|
||||
@router_ai.delete("/providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await AIProviderService.delete_provider(provider_id)
|
||||
return success({"id": provider_id})
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="同步模型列表")
|
||||
@router_ai.post("/providers/{provider_id}/sync-models")
|
||||
async def sync_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
result = await AIProviderService.sync_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success(result)
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取远程模型列表")
|
||||
@router_ai.get("/providers/{provider_id}/remote-models")
|
||||
async def fetch_remote_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
models = await AIProviderService.fetch_remote_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取模型列表")
|
||||
@router_ai.get("/providers/{provider_id}/models")
|
||||
async def list_models(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
models = await AIProviderService.list_models(provider_id)
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建模型",
|
||||
body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
|
||||
)
|
||||
@router_ai.post("/providers/{provider_id}/models")
|
||||
async def create_model(
|
||||
request: Request,
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
model = await AIProviderService.create_model(provider_id, payload.dict())
|
||||
return success(model)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新模型",
|
||||
body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
|
||||
)
|
||||
@router_ai.put("/models/{model_id}")
|
||||
async def update_model(
|
||||
request: Request,
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
model = await AIProviderService.update_model(model_id, data)
|
||||
return success(model)
|
||||
|
||||
|
||||
@audit(action=AuditAction.DELETE, description="删除模型")
|
||||
@router_ai.delete("/models/{model_id}")
|
||||
async def delete_model(
|
||||
request: Request,
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await AIProviderService.delete_model(model_id)
|
||||
return success({"id": model_id})
|
||||
|
||||
|
||||
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
|
||||
if not entry:
|
||||
return None
|
||||
value = entry.get("embedding_dimensions")
|
||||
return int(value) if value is not None else None
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取默认模型")
|
||||
@router_ai.get("/defaults")
|
||||
async def get_defaults(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
defaults = await AIProviderService.get_default_models()
|
||||
return success(defaults)
|
||||
|
||||
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新默认模型",
|
||||
body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
|
||||
)
|
||||
@router_ai.put("/defaults")
|
||||
async def update_defaults(
|
||||
request: Request,
|
||||
payload: AIDefaultsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
previous = await AIProviderService.get_default_models()
|
||||
try:
|
||||
updated = await AIProviderService.set_default_models(payload.as_mapping())
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
prev_dim = _get_embedding_dimension(previous.get("embedding"))
|
||||
next_dim = _get_embedding_dimension(updated.get("embedding"))
|
||||
|
||||
if prev_dim and next_dim and prev_dim != next_dim:
|
||||
try:
|
||||
await VectorDBService().clear_all_data()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
|
||||
|
||||
return success(updated)
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="清空向量数据库")
|
||||
@router_vector_db.post("/clear-all", summary="清空向量数据库")
|
||||
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
return success(msg="向量数据库已清空")
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库统计")
|
||||
@router_vector_db.get("/stats", summary="获取向量数据库统计")
|
||||
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
|
||||
try:
|
||||
service = VectorDBService()
|
||||
data = await service.get_all_stats()
|
||||
return success(data=data)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
|
||||
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
|
||||
async def list_vector_providers(request: Request):
|
||||
return success(list_providers())
|
||||
|
||||
|
||||
@audit(action=AuditAction.READ, description="获取向量数据库配置")
|
||||
@router_vector_db.get("/config", summary="获取当前向量数据库配置")
|
||||
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
|
||||
service = VectorDBService()
|
||||
data = await service.current_provider()
|
||||
return success(data)
|
||||
|
||||
|
||||
@audit(action=AuditAction.UPDATE, description="更新向量数据库配置", body_fields=["type"])
|
||||
@router_vector_db.post("/config", summary="更新向量数据库配置")
|
||||
async def update_vector_db_config(
|
||||
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
|
||||
):
|
||||
entry = get_provider_entry(payload.type)
|
||||
if not entry:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
||||
|
||||
provider_cls = get_provider_class(payload.type)
|
||||
if not provider_cls:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
||||
|
||||
test_provider = provider_cls(payload.config)
|
||||
try:
|
||||
await test_provider.initialize()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
client = getattr(test_provider, "client", None)
|
||||
close_fn = getattr(client, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
||||
service = VectorDBService()
|
||||
await service.reload()
|
||||
config_data = await service.current_provider()
|
||||
stats = await service.get_all_stats()
|
||||
return success({"config": config_data, "stats": stats})
|
||||
|
||||
|
||||
__all__ = ["router_ai", "router_vector_db"]
|
||||
1163
domain/ai/inference.py
Normal file
1163
domain/ai/inference.py
Normal file
File diff suppressed because it is too large
Load Diff
501
domain/ai/service.py
Normal file
501
domain/ai/service.py
Normal file
@@ -0,0 +1,501 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config import ConfigService
|
||||
from models.database import AIDefaultModel, AIModel, AIProvider
|
||||
|
||||
from .types import ABILITIES, normalize_capabilities
|
||||
from .vector_providers import (
|
||||
BaseVectorProvider,
|
||||
get_provider_class,
|
||||
get_provider_entry,
|
||||
list_providers,
|
||||
)
|
||||
|
||||
DEFAULT_VECTOR_DIMENSION = 4096
|
||||
VECTOR_COLLECTION_NAME = "vector_collection"
|
||||
FILE_COLLECTION_NAME = "file_collection"
|
||||
|
||||
OPENAI_EMBEDDING_DIMS = {
|
||||
"text-embedding-3-large": 3072,
|
||||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-ada-002": 1536,
|
||||
}
|
||||
|
||||
|
||||
class VectorDBConfigManager:
|
||||
TYPE_KEY = "VECTOR_DB_TYPE"
|
||||
CONFIG_KEY = "VECTOR_DB_CONFIG"
|
||||
DEFAULT_TYPE = "milvus_lite"
|
||||
|
||||
@classmethod
|
||||
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
|
||||
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
|
||||
provider_type = str(raw_type or cls.DEFAULT_TYPE)
|
||||
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
config_dict: Dict[str, Any] = {}
|
||||
if isinstance(raw_config, str) and raw_config:
|
||||
try:
|
||||
config_dict = json.loads(raw_config)
|
||||
except json.JSONDecodeError:
|
||||
config_dict = {}
|
||||
elif isinstance(raw_config, dict):
|
||||
config_dict = raw_config
|
||||
return provider_type, config_dict
|
||||
|
||||
@classmethod
|
||||
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
|
||||
await ConfigService.set(cls.TYPE_KEY, provider_type)
|
||||
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
|
||||
|
||||
@classmethod
|
||||
async def get_type(cls) -> str:
|
||||
provider_type, _ = await cls.load_config()
|
||||
return provider_type
|
||||
|
||||
@classmethod
|
||||
async def get_config(cls) -> Dict[str, Any]:
|
||||
_, config = await cls.load_config()
|
||||
return config
|
||||
|
||||
|
||||
def _normalize_embedding_dim(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
casted = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return casted if casted > 0 else None
|
||||
|
||||
|
||||
def _apply_embedding_dim_to_metadata(
|
||||
data: Dict[str, Any],
|
||||
embedding_dim: Optional[int],
|
||||
base_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
source = base_metadata if isinstance(base_metadata, dict) else {}
|
||||
metadata: Dict[str, Any] = dict(source)
|
||||
override = data.get("metadata")
|
||||
if isinstance(override, dict) and override:
|
||||
metadata.update(override)
|
||||
if embedding_dim is None:
|
||||
metadata.pop("embedding_dimensions", None)
|
||||
else:
|
||||
metadata["embedding_dimensions"] = embedding_dim
|
||||
data["metadata"] = metadata or None
|
||||
return data
|
||||
|
||||
|
||||
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
|
||||
lower = model_id.lower()
|
||||
caps = set()
|
||||
|
||||
if any(keyword in lower for keyword in ["gpt", "chat", "turbo", "o1", "sonnet", "haiku", "thinking"]):
|
||||
caps.update({"chat", "tools"})
|
||||
|
||||
if any(keyword in lower for keyword in ["vision", "gpt-4o", "gpt-4.1", "o1", "vision-preview", "omni"]):
|
||||
caps.add("vision")
|
||||
|
||||
if any(keyword in lower for keyword in ["embed", "embedding"]):
|
||||
caps.add("embedding")
|
||||
|
||||
if "rerank" in lower or "re-rank" in lower:
|
||||
caps.add("rerank")
|
||||
|
||||
if any(keyword in lower for keyword in ["tts", "speech", "audio"]):
|
||||
caps.add("voice")
|
||||
|
||||
embedding_dim = OPENAI_EMBEDDING_DIMS.get(model_id)
|
||||
return normalize_capabilities(caps), embedding_dim
|
||||
|
||||
|
||||
def infer_gemini_capabilities(methods: Iterable[str]) -> List[str]:
|
||||
caps = set()
|
||||
for method in methods:
|
||||
m = method.lower()
|
||||
if m in {"generatecontent", "counttokens"}:
|
||||
caps.update({"chat", "tools", "vision"})
|
||||
if m == "embedcontent":
|
||||
caps.add("embedding")
|
||||
if m in {"generatespeech", "audiogeneration"}:
|
||||
caps.add("voice")
|
||||
if m == "rerank":
|
||||
caps.add("rerank")
|
||||
return normalize_capabilities(caps)
|
||||
|
||||
|
||||
def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"identifier": provider.identifier,
|
||||
"provider_type": provider.provider_type,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"has_api_key": bool(provider.api_key),
|
||||
"logo_url": provider.logo_url,
|
||||
"extra_config": provider.extra_config or {},
|
||||
"created_at": provider.created_at,
|
||||
"updated_at": provider.updated_at,
|
||||
}
|
||||
|
||||
|
||||
def model_to_dict(model: AIModel, provider: Optional[AIProvider] = None) -> Dict[str, Any]:
|
||||
provider_obj = provider or getattr(model, "provider", None)
|
||||
provider_data = serialize_provider(provider_obj) if provider_obj else None
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"name": model.name,
|
||||
"display_name": model.display_name,
|
||||
"description": model.description,
|
||||
"capabilities": normalize_capabilities(model.capabilities),
|
||||
"context_window": model.context_window,
|
||||
"embedding_dimensions": model.embedding_dimensions,
|
||||
"metadata": model.metadata or {},
|
||||
"created_at": model.created_at,
|
||||
"updated_at": model.updated_at,
|
||||
"provider": provider_data,
|
||||
}
|
||||
|
||||
|
||||
def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = None) -> Dict[str, Any]:
|
||||
data = serialize_provider(provider)
|
||||
if models is not None:
|
||||
data["models"] = [model_to_dict(m, provider=provider) for m in models]
|
||||
return data
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
@classmethod
|
||||
async def list_providers(cls) -> List[Dict[str, Any]]:
|
||||
providers = await AIProvider.all().order_by("id").prefetch_related("models")
|
||||
return [provider_to_dict(p, models=list(p.models)) for p in providers]
|
||||
|
||||
@classmethod
|
||||
async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
|
||||
if with_models:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
models = await provider.models.all()
|
||||
return provider_to_dict(provider, models=models)
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data.setdefault("extra_config", {})
|
||||
provider = await AIProvider.create(**data)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
for field, value in payload.items():
|
||||
setattr(provider, field, value)
|
||||
await provider.save()
|
||||
return provider_to_dict(provider)
|
||||
|
||||
@classmethod
|
||||
async def delete_provider(cls, provider_id: int) -> None:
|
||||
await AIProvider.filter(id=provider_id).delete()
|
||||
|
||||
@classmethod
|
||||
async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
|
||||
return [model_to_dict(m) for m in models]
|
||||
|
||||
@classmethod
|
||||
async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data["provider_id"] = provider_id
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
data = _apply_embedding_dim_to_metadata(data, embedding_dim)
|
||||
model = await AIModel.create(**data)
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
@classmethod
|
||||
async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model = await AIModel.get(id=model_id)
|
||||
data = payload.copy()
|
||||
if "capabilities" in data:
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = None
|
||||
if "embedding_dimensions" in data:
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
_apply_embedding_dim_to_metadata(data, embedding_dim, base_metadata=model.metadata)
|
||||
for field, value in data.items():
|
||||
setattr(model, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in payload and embedding_dim is None):
|
||||
model.embedding_dimensions = embedding_dim
|
||||
await model.save()
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
@classmethod
|
||||
async def delete_model(cls, model_id: int) -> None:
|
||||
await AIModel.filter(id=model_id).delete()
|
||||
|
||||
@classmethod
|
||||
async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return await cls._get_remote_models(provider)
|
||||
|
||||
@classmethod
|
||||
async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
if not provider.base_url:
|
||||
raise ValueError("Provider base_url is required for syncing models")
|
||||
|
||||
fmt = (provider.api_format or "").lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
|
||||
|
||||
if fmt == "openai":
|
||||
return await cls._fetch_openai_models(provider)
|
||||
return await cls._fetch_gemini_models(provider)
|
||||
|
||||
@classmethod
|
||||
async def sync_models(cls, provider_id: int) -> Dict[str, int]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
remote_models = await cls._get_remote_models(provider)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
for entry in remote_models:
|
||||
defaults = entry.copy()
|
||||
model_id = defaults.pop("name")
|
||||
defaults["capabilities"] = normalize_capabilities(defaults.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(defaults.pop("embedding_dimensions", None))
|
||||
defaults = _apply_embedding_dim_to_metadata(defaults, embedding_dim)
|
||||
obj, is_created = await AIModel.get_or_create(
|
||||
provider_id=provider.id,
|
||||
name=model_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
if is_created:
|
||||
created += 1
|
||||
continue
|
||||
for field, value in defaults.items():
|
||||
setattr(obj, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in entry and embedding_dim is None):
|
||||
obj.embedding_dimensions = embedding_dim
|
||||
await obj.save()
|
||||
updated += 1
|
||||
|
||||
return {"created": created, "updated": updated}
|
||||
|
||||
@classmethod
|
||||
async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
|
||||
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
|
||||
for item in defaults:
|
||||
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
|
||||
async with in_transaction() as connection:
|
||||
for ability, model_id in normalized.items():
|
||||
record = await AIDefaultModel.get_or_none(ability=ability)
|
||||
if model_id:
|
||||
try:
|
||||
model = await AIModel.get(id=model_id)
|
||||
except DoesNotExist:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
if record:
|
||||
record.model_id = model_id
|
||||
await record.save(using_db=connection)
|
||||
else:
|
||||
await AIDefaultModel.create(ability=ability, model_id=model_id)
|
||||
elif record:
|
||||
await record.delete(using_db=connection)
|
||||
return await cls.get_default_models()
|
||||
|
||||
@classmethod
|
||||
async def get_default_model(cls, ability: str) -> Optional[AIModel]:
|
||||
ability_key = ability.lower()
|
||||
if ability_key not in ABILITIES:
|
||||
return None
|
||||
record = await AIDefaultModel.get_or_none(ability=ability_key)
|
||||
if not record:
|
||||
return None
|
||||
model = await AIModel.get_or_none(id=record.model_id)
|
||||
if model:
|
||||
await model.fetch_related("provider")
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
url = f"{base_url}/models"
|
||||
headers = {}
|
||||
if provider.api_key:
|
||||
headers["Authorization"] = f"Bearer {provider.api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("data", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
capabilities, embedding_dim = infer_openai_capabilities(model_id)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("display_name"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("context_window"),
|
||||
"embedding_dimensions": embedding_dim,
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
@classmethod
|
||||
async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
suffix = "/models"
|
||||
if provider.api_key:
|
||||
suffix += f"?key={provider.api_key}"
|
||||
url = f"{base_url}{suffix}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("models", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("name")
|
||||
if not model_id:
|
||||
continue
|
||||
methods = item.get("supportedGenerationMethods") or []
|
||||
capabilities = infer_gemini_capabilities(methods)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("displayName"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("inputTokenLimit"),
|
||||
"embedding_dimensions": item.get("embeddingDimensions"),
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
class VectorDBService:
|
||||
_instance: Optional["VectorDBService"] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_provider"):
|
||||
self._provider: Optional[BaseVectorProvider] = None
|
||||
self._provider_type: Optional[str] = None
|
||||
self._provider_config: Dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_provider(self) -> BaseVectorProvider:
|
||||
if self._provider is None:
|
||||
await self.reload()
|
||||
assert self._provider is not None
|
||||
return self._provider
|
||||
|
||||
async def reload(self) -> BaseVectorProvider:
|
||||
async with self._lock:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
normalized_config = dict(provider_config or {})
|
||||
if (
|
||||
self._provider
|
||||
and self._provider_type == provider_type
|
||||
and self._provider_config == normalized_config
|
||||
):
|
||||
return self._provider
|
||||
|
||||
entry = get_provider_entry(provider_type)
|
||||
if not entry:
|
||||
raise RuntimeError(f"Unknown vector database provider: {provider_type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")
|
||||
|
||||
provider_cls = get_provider_class(provider_type)
|
||||
if not provider_cls:
|
||||
raise RuntimeError(f"Provider class not found for '{provider_type}'")
|
||||
|
||||
provider = provider_cls(provider_config)
|
||||
await provider.initialize()
|
||||
|
||||
self._provider = provider
|
||||
self._provider_type = provider_type
|
||||
self._provider_config = normalized_config
|
||||
return provider
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.ensure_collection(collection_name, vector, dim)
|
||||
|
||||
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.upsert_vector(collection_name, data)
|
||||
|
||||
async def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.delete_vector(collection_name, path)
|
||||
|
||||
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_vectors(collection_name, query_embedding, top_k)
|
||||
|
||||
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
|
||||
provider = await self._ensure_provider()
|
||||
return provider.search_by_path(collection_name, query_path, top_k)
|
||||
|
||||
async def get_all_stats(self) -> Dict[str, Any]:
|
||||
provider = await self._ensure_provider()
|
||||
return provider.get_all_stats()
|
||||
|
||||
async def clear_all_data(self) -> None:
|
||||
provider = await self._ensure_provider()
|
||||
provider.clear_all_data()
|
||||
|
||||
async def current_provider(self) -> Dict[str, Any]:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
entry = get_provider_entry(provider_type) or {}
|
||||
return {
|
||||
"type": provider_type,
|
||||
"config": provider_config,
|
||||
"label": entry.get("label"),
|
||||
"enabled": entry.get("enabled", True),
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIProviderService",
|
||||
"VectorDBService",
|
||||
"VectorDBConfigManager",
|
||||
"DEFAULT_VECTOR_DIMENSION",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
"normalize_capabilities",
|
||||
"ABILITIES",
|
||||
]
|
||||
121
domain/ai/types.py
Normal file
121
domain/ai/types.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
|
||||
|
||||
|
||||
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
if not items:
|
||||
return []
|
||||
normalized: List[str] = []
|
||||
for cap in items:
|
||||
key = str(cap).strip().lower()
|
||||
if key in ABILITIES and key not in normalized:
|
||||
normalized.append(key)
|
||||
return normalized
|
||||
|
||||
|
||||
class AIProviderBase(BaseModel):
|
||||
name: str
|
||||
identifier: str = Field(..., pattern=r"^[a-z0-9_\-\.]+$")
|
||||
provider_type: Optional[str] = None
|
||||
api_format: str
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def normalize_format(cls, value: str) -> str:
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
|
||||
class AIProviderCreate(AIProviderBase):
|
||||
pass
|
||||
|
||||
|
||||
class AIProviderUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
provider_type: Optional[str] = None
|
||||
api_format: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
extra_config: Optional[dict] = None
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def normalize_format(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
fmt = value.lower()
|
||||
if fmt not in {"openai", "gemini", "anthropic", "ollama"}:
|
||||
raise ValueError("api_format must be 'openai', 'gemini', 'anthropic', or 'ollama'")
|
||||
return fmt
|
||||
|
||||
|
||||
class AIModelBase(BaseModel):
|
||||
name: str
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
capabilities: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
embedding_dimensions: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@field_validator("capabilities")
|
||||
@classmethod
|
||||
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if items is None:
|
||||
return None
|
||||
normalized = normalize_capabilities(items)
|
||||
invalid = set(items) - set(normalized)
|
||||
if invalid:
|
||||
raise ValueError(f"Unsupported capabilities: {', '.join(invalid)}")
|
||||
return normalized
|
||||
|
||||
|
||||
class AIModelCreate(AIModelBase):
|
||||
pass
|
||||
|
||||
|
||||
class AIModelUpdate(BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
capabilities: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
embedding_dimensions: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@field_validator("capabilities")
|
||||
@classmethod
|
||||
def validate_capabilities(cls, items: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if items is None:
|
||||
return None
|
||||
normalized = normalize_capabilities(items)
|
||||
invalid = set(items) - set(normalized)
|
||||
if invalid:
|
||||
raise ValueError(f"Unsupported capabilities: {', '.join(invalid)}")
|
||||
return normalized
|
||||
|
||||
|
||||
class AIDefaultsUpdate(BaseModel):
|
||||
chat: Optional[int] = None
|
||||
vision: Optional[int] = None
|
||||
embedding: Optional[int] = None
|
||||
rerank: Optional[int] = None
|
||||
voice: Optional[int] = None
|
||||
tools: Optional[int] = None
|
||||
|
||||
def as_mapping(self) -> Dict[str, Optional[int]]:
|
||||
return {ability: getattr(self, ability) for ability in ABILITIES}
|
||||
|
||||
|
||||
class VectorDBConfigPayload(BaseModel):
|
||||
type: str = Field(..., description="向量数据库提供者类型")
|
||||
config: Dict[str, Any] = Field(default_factory=dict, description="提供者配置参数")
|
||||
65
domain/ai/vector_providers/__init__.py
Normal file
65
domain/ai/vector_providers/__init__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from .base import BaseVectorProvider
|
||||
from .milvus_lite import MilvusLiteProvider
|
||||
from .milvus_server import MilvusServerProvider
|
||||
from .qdrant import QdrantProvider
|
||||
|
||||
_PROVIDER_REGISTRY: Dict[str, Dict[str, object]] = {
|
||||
MilvusLiteProvider.type: {
|
||||
"class": MilvusLiteProvider,
|
||||
"label": MilvusLiteProvider.label,
|
||||
"description": MilvusLiteProvider.description,
|
||||
"enabled": MilvusLiteProvider.enabled,
|
||||
"config_schema": MilvusLiteProvider.config_schema,
|
||||
},
|
||||
MilvusServerProvider.type: {
|
||||
"class": MilvusServerProvider,
|
||||
"label": MilvusServerProvider.label,
|
||||
"description": MilvusServerProvider.description,
|
||||
"enabled": MilvusServerProvider.enabled,
|
||||
"config_schema": MilvusServerProvider.config_schema,
|
||||
},
|
||||
QdrantProvider.type: {
|
||||
"class": QdrantProvider,
|
||||
"label": QdrantProvider.label,
|
||||
"description": QdrantProvider.description,
|
||||
"enabled": QdrantProvider.enabled,
|
||||
"config_schema": QdrantProvider.config_schema,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def list_providers() -> List[Dict[str, object]]:
|
||||
return [
|
||||
{
|
||||
"type": type_key,
|
||||
"label": meta["label"],
|
||||
"description": meta.get("description"),
|
||||
"enabled": meta.get("enabled", True),
|
||||
"config_schema": meta.get("config_schema", []),
|
||||
}
|
||||
for type_key, meta in _PROVIDER_REGISTRY.items()
|
||||
]
|
||||
|
||||
|
||||
def get_provider_entry(provider_type: str) -> Dict[str, object] | None:
|
||||
return _PROVIDER_REGISTRY.get(provider_type)
|
||||
|
||||
|
||||
def get_provider_class(provider_type: str) -> Type[BaseVectorProvider] | None:
|
||||
entry = get_provider_entry(provider_type)
|
||||
if not entry:
|
||||
return None
|
||||
return entry.get("class") # type: ignore[return-value]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseVectorProvider",
|
||||
"MilvusLiteProvider",
|
||||
"MilvusServerProvider",
|
||||
"QdrantProvider",
|
||||
"list_providers",
|
||||
"get_provider_entry",
|
||||
"get_provider_class",
|
||||
]
|
||||
39
domain/ai/vector_providers/base.py
Normal file
39
domain/ai/vector_providers/base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class BaseVectorProvider:
|
||||
"""向量数据库提供者基础类,所有实际实现需继承该类"""
|
||||
|
||||
type: str = ""
|
||||
label: str = ""
|
||||
description: str | None = None
|
||||
enabled: bool = True
|
||||
config_schema: List[Dict[str, Any]] = []
|
||||
|
||||
def __init__(self, config: Dict[str, Any] | None = None):
|
||||
self.config = config or {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""执行初始化逻辑,例如建立连接"""
|
||||
raise NotImplementedError
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
raise NotImplementedError
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def clear_all_data(self) -> None:
|
||||
raise NotImplementedError
|
||||
276
domain/ai/vector_providers/milvus_lite.py
Normal file
276
domain/ai/vector_providers/milvus_lite.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||
|
||||
from .base import BaseVectorProvider
|
||||
|
||||
|
||||
class MilvusLiteProvider(BaseVectorProvider):
|
||||
type = "milvus_lite"
|
||||
label = "Milvus Lite"
|
||||
description = "Embedded Milvus Lite (local file storage)."
|
||||
enabled = True
|
||||
config_schema: List[Dict[str, Any]] = [
|
||||
{
|
||||
"key": "db_path",
|
||||
"label": "Database file path",
|
||||
"type": "text",
|
||||
"default": "data/db/milvus.db",
|
||||
"required": False,
|
||||
}
|
||||
]
|
||||
|
||||
def __init__(self, config: Dict[str, Any] | None = None):
|
||||
super().__init__(config)
|
||||
self.db_path = Path(self.config.get("db_path") or "data/db/milvus.db")
|
||||
self.client: MilvusClient | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.client = MilvusClient(str(self.db_path))
|
||||
except Exception as exc: # pragma: no cover - depends on local environment
|
||||
raise RuntimeError(f"Failed to open Milvus Lite at {self.db_path}: {exc}") from exc
|
||||
|
||||
def _get_client(self) -> MilvusClient:
|
||||
if not self.client:
|
||||
raise RuntimeError("Milvus Lite client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
raw: Dict[str, Any] | None = None
|
||||
if hasattr(hit, "entity"):
|
||||
raw_entity = getattr(hit, "entity")
|
||||
if hasattr(raw_entity, "to_dict"):
|
||||
raw = dict(raw_entity.to_dict())
|
||||
else:
|
||||
raw = dict(raw_entity)
|
||||
elif isinstance(hit, dict):
|
||||
raw = dict(hit)
|
||||
|
||||
if raw:
|
||||
hit_id = hit_id or raw.get("id")
|
||||
distance = distance if distance is not None else raw.get("distance")
|
||||
inner = raw.get("entity")
|
||||
if isinstance(inner, dict):
|
||||
payload = dict(inner)
|
||||
else:
|
||||
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
||||
|
||||
payload.setdefault("path", payload.get("source_path"))
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
return hit_id, distance, payload
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
client = self._get_client()
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
common_fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
||||
]
|
||||
|
||||
if vector:
|
||||
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
||||
if vector_dim <= 0:
|
||||
vector_dim = 4096
|
||||
fields = [
|
||||
*common_fields,
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="embedding",
|
||||
index_type="IVF_FLAT",
|
||||
index_name="vector_index",
|
||||
metric_type="COSINE",
|
||||
params={"nlist": 64},
|
||||
)
|
||||
client.create_index(collection_name, index_params=index_params)
|
||||
else:
|
||||
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
payload = dict(data)
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
payload.setdefault("vector_id", payload.get("path"))
|
||||
self._get_client().upsert(collection_name, data=[payload])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
client = self._get_client()
|
||||
escaped = path.replace('"', '\\"')
|
||||
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
search_params = {"metric_type": "COSINE"}
|
||||
output_fields = [
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
]
|
||||
raw_results = self._get_client().search(
|
||||
collection_name,
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
search_params=search_params,
|
||||
limit=top_k,
|
||||
output_fields=output_fields,
|
||||
)
|
||||
formatted: List[List[Dict[str, Any]]] = []
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
hit_id, distance, entity = self._extract_hit_payload(hit)
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
"entity": entity,
|
||||
})
|
||||
formatted.append(bucket)
|
||||
return formatted
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like \"%{escaped}%\"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
collection_name,
|
||||
filter=filter_expr,
|
||||
limit=top_k,
|
||||
output_fields=[
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
],
|
||||
)
|
||||
formatted = []
|
||||
for row in results:
|
||||
entity = dict(row)
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
formatted.append({
|
||||
"id": entity.get("path"),
|
||||
"distance": 1.0,
|
||||
"entity": entity,
|
||||
})
|
||||
return [formatted]
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
try:
|
||||
collection_names = client.list_collections()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Failed to list collections: {exc}") from exc
|
||||
|
||||
collections: List[Dict[str, Any]] = []
|
||||
total_vectors = 0
|
||||
total_estimated_memory = 0
|
||||
|
||||
for name in collection_names:
|
||||
try:
|
||||
stats = client.get_collection_stats(name) or {}
|
||||
except Exception:
|
||||
stats = {}
|
||||
row_count = self._to_int(stats.get("row_count"))
|
||||
total_vectors += row_count
|
||||
|
||||
dimension: Optional[int] = None
|
||||
is_vector_collection = False
|
||||
try:
|
||||
description = client.describe_collection(name)
|
||||
except Exception:
|
||||
description = None
|
||||
|
||||
if description:
|
||||
for field in description.get("fields", []):
|
||||
if field.get("type") == DataType.FLOAT_VECTOR:
|
||||
params = field.get("params") or {}
|
||||
dimension = self._to_int(params.get("dim")) or 4096
|
||||
is_vector_collection = True
|
||||
break
|
||||
|
||||
estimated_memory = 0
|
||||
if is_vector_collection and dimension:
|
||||
estimated_memory = row_count * dimension * 4
|
||||
total_estimated_memory += estimated_memory
|
||||
|
||||
indexes: List[Dict[str, Any]] = []
|
||||
try:
|
||||
index_names = client.list_indexes(name) or []
|
||||
except Exception:
|
||||
index_names = []
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
detail = client.describe_index(name) or {}
|
||||
except Exception:
|
||||
detail = {}
|
||||
indexes.append(
|
||||
{
|
||||
"index_name": index_name,
|
||||
"index_type": detail.get("index_type"),
|
||||
"metric_type": detail.get("metric_type"),
|
||||
"indexed_rows": self._to_int(detail.get("indexed_rows")),
|
||||
"pending_index_rows": self._to_int(detail.get("pending_index_rows")),
|
||||
"state": detail.get("state"),
|
||||
}
|
||||
)
|
||||
|
||||
collections.append(
|
||||
{
|
||||
"name": name,
|
||||
"row_count": row_count,
|
||||
"dimension": dimension if is_vector_collection else None,
|
||||
"estimated_memory_bytes": estimated_memory,
|
||||
"is_vector_collection": is_vector_collection,
|
||||
"indexes": indexes,
|
||||
}
|
||||
)
|
||||
|
||||
db_file_size = None
|
||||
try:
|
||||
if self.db_path.exists():
|
||||
db_file_size = self.db_path.stat().st_size
|
||||
except OSError:
|
||||
db_file_size = None
|
||||
|
||||
return {
|
||||
"collections": collections,
|
||||
"collection_count": len(collections),
|
||||
"total_vectors": total_vectors,
|
||||
"estimated_total_memory_bytes": total_estimated_memory,
|
||||
"db_file_size_bytes": db_file_size,
|
||||
}
|
||||
|
||||
def clear_all_data(self) -> None:
|
||||
client = self._get_client()
|
||||
for collection_name in client.list_collections():
|
||||
client.drop_collection(collection_name)
|
||||
276
domain/ai/vector_providers/milvus_server.py
Normal file
276
domain/ai/vector_providers/milvus_server.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||
|
||||
from .base import BaseVectorProvider
|
||||
|
||||
|
||||
class MilvusServerProvider(BaseVectorProvider):
|
||||
type = "milvus_server"
|
||||
label = "Milvus Server"
|
||||
description = "Remote Milvus instance accessed via URI."
|
||||
enabled = True
|
||||
config_schema: List[Dict[str, Any]] = [
|
||||
{
|
||||
"key": "uri",
|
||||
"label": "Server URI",
|
||||
"type": "text",
|
||||
"required": True,
|
||||
"placeholder": "http://localhost:19530",
|
||||
},
|
||||
{
|
||||
"key": "token",
|
||||
"label": "Token",
|
||||
"type": "password",
|
||||
"required": False,
|
||||
"placeholder": "user:password",
|
||||
},
|
||||
]
|
||||
|
||||
def __init__(self, config: Dict[str, Any] | None = None):
|
||||
super().__init__(config)
|
||||
self.client: MilvusClient | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
uri = self.config.get("uri")
|
||||
if not uri:
|
||||
raise RuntimeError("Milvus Server URI is required")
|
||||
try:
|
||||
self.client = MilvusClient(uri=uri, token=self.config.get("token"))
|
||||
except Exception as exc: # pragma: no cover - depends on remote availability
|
||||
raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc
|
||||
|
||||
def _get_client(self) -> MilvusClient:
|
||||
if not self.client:
|
||||
raise RuntimeError("Milvus Server client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
raw: Dict[str, Any] | None = None
|
||||
if hasattr(hit, "entity"):
|
||||
raw_entity = getattr(hit, "entity")
|
||||
if hasattr(raw_entity, "to_dict"):
|
||||
raw = dict(raw_entity.to_dict())
|
||||
else:
|
||||
raw = dict(raw_entity)
|
||||
elif isinstance(hit, dict):
|
||||
raw = dict(hit)
|
||||
|
||||
if raw:
|
||||
hit_id = hit_id or raw.get("id")
|
||||
distance = distance if distance is not None else raw.get("distance")
|
||||
inner = raw.get("entity")
|
||||
if isinstance(inner, dict):
|
||||
payload = dict(inner)
|
||||
else:
|
||||
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
||||
|
||||
payload.setdefault("path", payload.get("source_path"))
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
return hit_id, distance, payload
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
client = self._get_client()
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
common_fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
||||
]
|
||||
if vector:
|
||||
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
||||
if vector_dim <= 0:
|
||||
vector_dim = 4096
|
||||
fields = [
|
||||
*common_fields,
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="embedding",
|
||||
index_type="IVF_FLAT",
|
||||
index_name="vector_index",
|
||||
metric_type="COSINE",
|
||||
params={"nlist": 64},
|
||||
)
|
||||
client.create_index(collection_name, index_params=index_params)
|
||||
else:
|
||||
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
payload = dict(data)
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
payload.setdefault("vector_id", payload.get("path"))
|
||||
self._get_client().upsert(collection_name, data=[payload])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
client = self._get_client()
|
||||
escaped = path.replace('"', '\\"')
|
||||
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
search_params = {"metric_type": "COSINE"}
|
||||
output_fields = [
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
]
|
||||
raw_results = self._get_client().search(
|
||||
collection_name,
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
search_params=search_params,
|
||||
limit=top_k,
|
||||
output_fields=output_fields,
|
||||
)
|
||||
formatted: List[List[Dict[str, Any]]] = []
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
hit_id, distance, entity = self._extract_hit_payload(hit)
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
"entity": entity,
|
||||
})
|
||||
formatted.append(bucket)
|
||||
return formatted
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like \"%{escaped}%\"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
collection_name,
|
||||
filter=filter_expr,
|
||||
limit=top_k,
|
||||
output_fields=[
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
],
|
||||
)
|
||||
formatted = []
|
||||
for row in results:
|
||||
entity = dict(row)
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
formatted.append({
|
||||
"id": entity.get("path"),
|
||||
"distance": 1.0,
|
||||
"entity": entity,
|
||||
})
|
||||
return [formatted]
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
try:
|
||||
collection_names = client.list_collections()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Failed to list collections: {exc}") from exc
|
||||
|
||||
collections: List[Dict[str, Any]] = []
|
||||
total_vectors = 0
|
||||
total_estimated_memory = 0
|
||||
|
||||
for name in collection_names:
|
||||
try:
|
||||
stats = client.get_collection_stats(name) or {}
|
||||
except Exception:
|
||||
stats = {}
|
||||
row_count = self._to_int(stats.get("row_count"))
|
||||
total_vectors += row_count
|
||||
|
||||
dimension: Optional[int] = None
|
||||
is_vector_collection = False
|
||||
try:
|
||||
description = client.describe_collection(name)
|
||||
except Exception:
|
||||
description = None
|
||||
|
||||
if description:
|
||||
for field in description.get("fields", []):
|
||||
if field.get("type") == DataType.FLOAT_VECTOR:
|
||||
params = field.get("params") or {}
|
||||
dimension = self._to_int(params.get("dim")) or 4096
|
||||
is_vector_collection = True
|
||||
break
|
||||
|
||||
estimated_memory = 0
|
||||
if is_vector_collection and dimension:
|
||||
estimated_memory = row_count * dimension * 4
|
||||
total_estimated_memory += estimated_memory
|
||||
|
||||
indexes: List[Dict[str, Any]] = []
|
||||
try:
|
||||
index_names = client.list_indexes(name) or []
|
||||
except Exception:
|
||||
index_names = []
|
||||
|
||||
for index_name in index_names:
|
||||
try:
|
||||
detail = client.describe_index(name) or {}
|
||||
except Exception:
|
||||
detail = {}
|
||||
indexes.append(
|
||||
{
|
||||
"index_name": index_name,
|
||||
"index_type": detail.get("index_type"),
|
||||
"metric_type": detail.get("metric_type"),
|
||||
"indexed_rows": self._to_int(detail.get("indexed_rows")),
|
||||
"pending_index_rows": self._to_int(detail.get("pending_index_rows")),
|
||||
"state": detail.get("state"),
|
||||
}
|
||||
)
|
||||
|
||||
collections.append(
|
||||
{
|
||||
"name": name,
|
||||
"row_count": row_count,
|
||||
"dimension": dimension if is_vector_collection else None,
|
||||
"estimated_memory_bytes": estimated_memory,
|
||||
"is_vector_collection": is_vector_collection,
|
||||
"indexes": indexes,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"collections": collections,
|
||||
"collection_count": len(collections),
|
||||
"total_vectors": total_vectors,
|
||||
"estimated_total_memory_bytes": total_estimated_memory,
|
||||
"db_file_size_bytes": None,
|
||||
}
|
||||
|
||||
def clear_all_data(self) -> None:
|
||||
client = self._get_client()
|
||||
for collection_name in client.list_collections():
|
||||
client.drop_collection(collection_name)
|
||||
271
domain/ai/vector_providers/qdrant.py
Normal file
271
domain/ai/vector_providers/qdrant.py
Normal file
@@ -0,0 +1,271 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from uuid import NAMESPACE_URL, uuid5
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models as qmodels
|
||||
|
||||
from .base import BaseVectorProvider
|
||||
|
||||
|
||||
class QdrantProvider(BaseVectorProvider):
|
||||
type = "qdrant"
|
||||
label = "Qdrant"
|
||||
description = "Qdrant vector database (HTTP API)."
|
||||
enabled = True
|
||||
config_schema: List[Dict[str, Any]] = [
|
||||
{
|
||||
"key": "url",
|
||||
"label": "Server URL",
|
||||
"type": "text",
|
||||
"required": True,
|
||||
"placeholder": "http://localhost:6333",
|
||||
},
|
||||
{
|
||||
"key": "api_key",
|
||||
"label": "API Key",
|
||||
"type": "password",
|
||||
"required": False,
|
||||
},
|
||||
]
|
||||
|
||||
def __init__(self, config: Dict[str, Any] | None = None):
|
||||
super().__init__(config)
|
||||
self.client: Optional[QdrantClient] = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
url = (self.config.get("url") or "").strip()
|
||||
if not url:
|
||||
raise RuntimeError("Qdrant URL is required")
|
||||
|
||||
api_key = (self.config.get("api_key") or None) or None
|
||||
try:
|
||||
client = QdrantClient(url=url, api_key=api_key)
|
||||
client.get_collections()
|
||||
self.client = client
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
raise RuntimeError(f"Failed to connect to Qdrant at {url}: {exc}") from exc
|
||||
|
||||
def _get_client(self) -> QdrantClient:
|
||||
if not self.client:
|
||||
raise RuntimeError("Qdrant client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _vector_params(vector: bool, dim: int) -> qmodels.VectorParams:
|
||||
size = dim if vector and isinstance(dim, int) and dim > 0 else 1
|
||||
return qmodels.VectorParams(size=size, distance=qmodels.Distance.COSINE)
|
||||
|
||||
def _ensure_payload_indexes(self, client: QdrantClient, collection_name: str) -> None:
|
||||
for field in ("path", "source_path"):
|
||||
try:
|
||||
client.create_payload_index(
|
||||
collection_name=collection_name,
|
||||
field_name=field,
|
||||
field_schema="keyword",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
message = str(exc).lower()
|
||||
if "already exists" in message or "index exists" in message:
|
||||
continue
|
||||
raise
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
exists = client.collection_exists(collection_name)
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
raise RuntimeError(f"Failed to check Qdrant collection '{collection_name}': {exc}") from exc
|
||||
|
||||
if exists:
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
vectors_config = self._vector_params(vector, dim)
|
||||
try:
|
||||
client.create_collection(collection_name=collection_name, vectors_config=vectors_config)
|
||||
except Exception as exc: # pragma: no cover
|
||||
if "already exists" in str(exc).lower():
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
raise RuntimeError(f"Failed to create Qdrant collection '{collection_name}': {exc}") from exc
|
||||
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _point_id(uid: str) -> str:
|
||||
return str(uuid5(NAMESPACE_URL, uid))
|
||||
|
||||
def _prepare_point(self, data: Dict[str, Any]) -> qmodels.PointStruct:
|
||||
uid = data.get("path")
|
||||
if not uid:
|
||||
raise ValueError("Qdrant upsert requires 'path' in data")
|
||||
|
||||
embedding = data.get("embedding")
|
||||
if embedding is None:
|
||||
vector = [0.0]
|
||||
else:
|
||||
vector = [float(x) for x in embedding]
|
||||
|
||||
payload = {k: v for k, v in data.items() if k != "embedding"}
|
||||
payload.setdefault("vector_id", uid)
|
||||
source_path = payload.get("source_path") or payload.get("path")
|
||||
payload["path"] = source_path
|
||||
return qmodels.PointStruct(id=self._point_id(str(uid)), vector=vector, payload=payload)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
client = self._get_client()
|
||||
point = self._prepare_point(data)
|
||||
client.upsert(collection_name=collection_name, wait=True, points=[point])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
client = self._get_client()
|
||||
condition = qmodels.FieldCondition(
|
||||
key="path",
|
||||
match=qmodels.MatchValue(value=path),
|
||||
)
|
||||
flt = qmodels.Filter(must=[condition])
|
||||
selector = qmodels.FilterSelector(filter=flt)
|
||||
client.delete(collection_name=collection_name, points_selector=selector, wait=True)
|
||||
|
||||
def _format_search_results(self, points: Sequence[qmodels.ScoredPoint]):
|
||||
return [
|
||||
{
|
||||
"id": point.id,
|
||||
"distance": point.score,
|
||||
"entity": point.payload or {},
|
||||
}
|
||||
for point in points
|
||||
]
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
client = self._get_client()
|
||||
vector = [float(x) for x in query_embedding]
|
||||
points = client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=vector,
|
||||
limit=top_k,
|
||||
with_payload=True,
|
||||
)
|
||||
return [self._format_search_results(points)]
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
client = self._get_client()
|
||||
results: List[Dict[str, Any]] = []
|
||||
offset: Optional[str | int] = None
|
||||
remaining = max(top_k, 1)
|
||||
|
||||
while len(results) < top_k:
|
||||
batch_size = min(max(remaining * 2, 10), 200)
|
||||
records, next_offset = client.scroll(
|
||||
collection_name=collection_name,
|
||||
limit=batch_size,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
)
|
||||
if not records:
|
||||
break
|
||||
|
||||
for record in records:
|
||||
payload = record.payload or {}
|
||||
path = payload.get("path")
|
||||
if query_path and path and query_path not in path:
|
||||
continue
|
||||
results.append({"id": record.id, "distance": 1.0, "entity": payload})
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
if next_offset is None or len(results) >= top_k:
|
||||
break
|
||||
offset = next_offset
|
||||
remaining = top_k - len(results)
|
||||
|
||||
return [results]
|
||||
|
||||
def _extract_vector_config(self, vectors) -> Optional[qmodels.VectorParams]:
|
||||
if isinstance(vectors, qmodels.VectorParams):
|
||||
return vectors
|
||||
if isinstance(vectors, dict):
|
||||
for value in vectors.values():
|
||||
if isinstance(value, qmodels.VectorParams):
|
||||
return value
|
||||
return None
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = client.get_collections()
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise RuntimeError(f"Failed to list Qdrant collections: {exc}") from exc
|
||||
|
||||
collections: List[Dict[str, Any]] = []
|
||||
total_vectors = 0
|
||||
total_estimated_memory = 0
|
||||
|
||||
for description in response.collections or []:
|
||||
name = description.name
|
||||
try:
|
||||
info = client.get_collection(name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
row_count = int(info.points_count or 0)
|
||||
total_vectors += row_count
|
||||
|
||||
vector_params = self._extract_vector_config(info.config.params.vectors if info.config and info.config.params else None)
|
||||
dimension = int(vector_params.size) if vector_params and vector_params.size else None
|
||||
estimated_memory = row_count * dimension * 4 if dimension else 0
|
||||
total_estimated_memory += estimated_memory
|
||||
distance = str(vector_params.distance) if vector_params and vector_params.distance else None
|
||||
|
||||
indexed_rows = int(info.indexed_vectors_count or 0)
|
||||
pending_rows = max(row_count - indexed_rows, 0)
|
||||
|
||||
collections.append(
|
||||
{
|
||||
"name": name,
|
||||
"row_count": row_count,
|
||||
"dimension": dimension,
|
||||
"estimated_memory_bytes": estimated_memory,
|
||||
"is_vector_collection": dimension is not None and dimension > 1,
|
||||
"indexes": [
|
||||
{
|
||||
"index_name": "hnsw",
|
||||
"index_type": "HNSW",
|
||||
"metric_type": distance,
|
||||
"indexed_rows": indexed_rows,
|
||||
"pending_index_rows": pending_rows,
|
||||
"state": info.status,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"collections": collections,
|
||||
"collection_count": len(collections),
|
||||
"total_vectors": total_vectors,
|
||||
"estimated_total_memory_bytes": total_estimated_memory,
|
||||
"db_file_size_bytes": None,
|
||||
}
|
||||
|
||||
def clear_all_data(self) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = client.get_collections()
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise RuntimeError(f"Failed to list Qdrant collections: {exc}") from exc
|
||||
|
||||
for description in response.collections or []:
|
||||
try:
|
||||
client.delete_collection(description.name)
|
||||
except Exception:
|
||||
continue
|
||||
4
domain/audit/__init__.py
Normal file
4
domain/audit/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .decorator import audit
|
||||
from .types import AuditAction
|
||||
|
||||
__all__ = ["audit", "AuditAction"]
|
||||
69
domain/audit/api.py
Normal file
69
domain/audit/api.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from api import response
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import AuditService
|
||||
from .types import AuditAction
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
router = APIRouter(prefix="/api/audit", tags=["Audit"])
|
||||
|
||||
|
||||
def _parse_iso(value: Optional[str], field: str):
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
dt = datetime.fromisoformat(normalized)
|
||||
if dt.tzinfo:
|
||||
dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
return dt
|
||||
except ValueError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail=f"invalid {field}") from exc
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
@require_system_permission(SystemPermission.AUDIT_VIEW)
|
||||
async def list_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
page_num: int = Query(1, ge=1, alias="page", description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=200, description="每页条数"),
|
||||
action: AuditAction | None = Query(None, description="操作类型"),
|
||||
success: bool | None = Query(None, description="是否成功"),
|
||||
username: str | None = Query(None, description="用户名模糊匹配"),
|
||||
path: str | None = Query(None, description="路径模糊匹配"),
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
|
||||
):
|
||||
start_dt = _parse_iso(start_time, "start_time")
|
||||
end_dt = _parse_iso(end_time, "end_time")
|
||||
items, total = await AuditService.list_logs(
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
action=str(action) if action else None,
|
||||
success=success,
|
||||
username=username,
|
||||
path=path,
|
||||
start_time=start_dt,
|
||||
end_time=end_dt,
|
||||
)
|
||||
return response.success(response.page(items, total, page_num, page_size))
|
||||
|
||||
|
||||
@router.delete("/logs")
|
||||
@require_system_permission(SystemPermission.AUDIT_VIEW)
|
||||
async def clear_audit_logs(
|
||||
current_user: CurrentUser,
|
||||
start_time: str | None = Query(None, description="开始时间 (ISO 8601)"),
|
||||
end_time: str | None = Query(None, description="结束时间 (ISO 8601)"),
|
||||
):
|
||||
start_dt = _parse_iso(start_time, "start_time")
|
||||
end_dt = _parse_iso(end_time, "end_time")
|
||||
deleted_count = await AuditService.clear_logs(start_time=start_dt, end_time=end_dt)
|
||||
return response.success({"deleted_count": deleted_count})
|
||||
204
domain/audit/decorator.py
Normal file
204
domain/audit/decorator.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import inspect
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.auth import ALGORITHM
|
||||
from domain.config import ConfigService
|
||||
from models.database import UserAccount
|
||||
from .service import AuditService
|
||||
from .types import AuditAction
|
||||
|
||||
|
||||
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
|
||||
for value in bound_args.values():
|
||||
if isinstance(value, Request):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_user(request: Request | None, user_obj: Any | None) -> tuple[Optional[int], Optional[str]]:
|
||||
user_id: int | None = None
|
||||
username: str | None = None
|
||||
|
||||
if request:
|
||||
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1]
|
||||
try:
|
||||
payload = jwt.decode(token, await ConfigService.get_secret_key("SECRET_KEY"), algorithms=[ALGORITHM])
|
||||
username = payload.get("sub") or payload.get("username")
|
||||
if username:
|
||||
user = await UserAccount.get_or_none(username=username)
|
||||
user_id = user.id if user else None
|
||||
except (InvalidTokenError, Exception):
|
||||
pass
|
||||
|
||||
if user_id is None and username is None and user_obj is not None:
|
||||
user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
|
||||
username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
|
||||
if isinstance(user_obj, dict):
|
||||
user_id = user_obj.get("id", user_obj.get("user_id", user_id))
|
||||
username = user_obj.get("username", user_obj.get("name", username))
|
||||
|
||||
return user_id, username
|
||||
|
||||
|
||||
def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
|
||||
if not body_fields:
|
||||
return None
|
||||
body: Dict[str, Any] = {}
|
||||
redacts = set(redact_fields or [])
|
||||
for value in bound_args.values():
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
data = value.model_dump()
|
||||
except Exception:
|
||||
data = None
|
||||
elif hasattr(value, "dict"):
|
||||
try:
|
||||
data = value.dict()
|
||||
except Exception:
|
||||
data = None
|
||||
elif isinstance(value, dict):
|
||||
data = value
|
||||
elif hasattr(value, "__dict__"):
|
||||
data = dict(value.__dict__)
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
for field in body_fields:
|
||||
if field in data and field not in body:
|
||||
body[field] = data[field]
|
||||
if not body:
|
||||
return None
|
||||
for field in redacts:
|
||||
if field in body:
|
||||
body[field] = "<redacted>"
|
||||
return body
|
||||
|
||||
|
||||
def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
|
||||
if not request:
|
||||
return None
|
||||
params: Dict[str, Any] = {}
|
||||
query = dict(request.query_params)
|
||||
if query:
|
||||
params["query"] = query
|
||||
path_params = dict(request.path_params or {})
|
||||
if path_params:
|
||||
params["path"] = path_params
|
||||
return params or None
|
||||
|
||||
|
||||
def _get_client_ip(request: Request | None) -> str | None:
|
||||
if not request:
|
||||
return None
|
||||
cf_connecting_ip = request.headers.get("cf-connecting-ip") or request.headers.get("CF-Connecting-IP")
|
||||
if cf_connecting_ip:
|
||||
ip = cf_connecting_ip.strip()
|
||||
if ip:
|
||||
return ip
|
||||
x_real_ip = request.headers.get("x-real-ip") or request.headers.get("X-Real-IP")
|
||||
if x_real_ip:
|
||||
ip = x_real_ip.strip()
|
||||
if ip:
|
||||
return ip
|
||||
x_forwarded_for = request.headers.get("x-forwarded-for") or request.headers.get("X-Forwarded-For")
|
||||
if x_forwarded_for:
|
||||
for part in x_forwarded_for.split(","):
|
||||
ip = part.strip()
|
||||
if ip and ip.lower() != "unknown":
|
||||
return ip
|
||||
return request.client.host if request.client else None
|
||||
|
||||
|
||||
def _status_code_from_response(response: Any) -> int:
|
||||
if hasattr(response, "status_code"):
|
||||
try:
|
||||
return int(getattr(response, "status_code"))
|
||||
except Exception:
|
||||
pass
|
||||
return 200
|
||||
|
||||
|
||||
def audit(
|
||||
*,
|
||||
action: AuditAction,
|
||||
description: str | None = None,
|
||||
body_fields: list[str] | None = None,
|
||||
redact_fields: list[str] | None = None,
|
||||
user_kw: str = "current_user",
|
||||
):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
request = _extract_request(bound.arguments)
|
||||
start = time.perf_counter()
|
||||
user_info = bound.arguments.get(user_kw)
|
||||
user_id, username = await _resolve_user(request, user_info)
|
||||
request_params = _build_request_params(request)
|
||||
request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
status_code = _status_code_from_response(result)
|
||||
success = True
|
||||
error = None
|
||||
except Exception as exc: # noqa: BLE001
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
success = False
|
||||
error = str(exc)
|
||||
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
||||
try:
|
||||
await AuditService.log(
|
||||
action=action,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=_get_client_ip(request),
|
||||
method=request.method if request else "",
|
||||
path=request.url.path if request else func.__name__,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
||||
try:
|
||||
await AuditService.log(
|
||||
action=action,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=_get_client_ip(request),
|
||||
method=request.method if request else "",
|
||||
path=request.url.path if request else func.__name__,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
124
domain/audit/service.py
Normal file
124
domain/audit/service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from models.database import AuditLog
|
||||
|
||||
from .types import AuditAction
|
||||
|
||||
|
||||
class AuditService:
|
||||
@classmethod
|
||||
async def log(
|
||||
cls,
|
||||
*,
|
||||
action: AuditAction | str,
|
||||
description: Optional[str],
|
||||
user_id: Optional[int],
|
||||
username: Optional[str],
|
||||
client_ip: Optional[str],
|
||||
method: str,
|
||||
path: str,
|
||||
status_code: int,
|
||||
duration_ms: Optional[float],
|
||||
success: bool,
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
await AuditLog.create(
|
||||
action=str(action),
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
client_ip=client_ip,
|
||||
method=method,
|
||||
path=path,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
success=success,
|
||||
request_params=request_params,
|
||||
request_body=request_body,
|
||||
error=error,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _serialize(cls, log: AuditLog) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": log.id,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"action": log.action,
|
||||
"description": log.description,
|
||||
"user_id": log.user_id,
|
||||
"username": log.username,
|
||||
"client_ip": log.client_ip,
|
||||
"method": log.method,
|
||||
"path": log.path,
|
||||
"status_code": log.status_code,
|
||||
"duration_ms": log.duration_ms,
|
||||
"success": log.success,
|
||||
"request_params": log.request_params,
|
||||
"request_body": log.request_body,
|
||||
"error": log.error,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _apply_filters(
|
||||
cls,
|
||||
*,
|
||||
action: str | None = None,
|
||||
success: bool | None = None,
|
||||
username: str | None = None,
|
||||
path: str | None = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
qs = AuditLog.all()
|
||||
if action:
|
||||
qs = qs.filter(action=action)
|
||||
if success is not None:
|
||||
qs = qs.filter(success=success)
|
||||
if username:
|
||||
qs = qs.filter(username__icontains=username)
|
||||
if path:
|
||||
qs = qs.filter(path__icontains=path)
|
||||
if start_time:
|
||||
qs = qs.filter(created_at__gte=start_time)
|
||||
if end_time:
|
||||
qs = qs.filter(created_at__lte=end_time)
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
async def list_logs(
|
||||
cls,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
action: str | None = None,
|
||||
success: bool | None = None,
|
||||
username: str | None = None,
|
||||
path: str | None = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
qs = cls._apply_filters(
|
||||
action=action,
|
||||
success=success,
|
||||
username=username,
|
||||
path=path,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
total = await qs.count()
|
||||
offset = (page - 1) * page_size
|
||||
items = await qs.order_by("-created_at").offset(offset).limit(page_size)
|
||||
return [cls._serialize(log) for log in items], total
|
||||
|
||||
@classmethod
|
||||
async def clear_logs(
|
||||
cls,
|
||||
*,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
) -> int:
|
||||
qs = cls._apply_filters(start_time=start_time, end_time=end_time)
|
||||
deleted_count = await qs.delete()
|
||||
return deleted_count
|
||||
16
domain/audit/types.py
Normal file
16
domain/audit/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AuditAction(StrEnum):
|
||||
LOGIN = "login"
|
||||
LOGOUT = "logout"
|
||||
REGISTER = "register"
|
||||
READ = "read"
|
||||
CREATE = "create"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
RESET_PASSWORD = "reset_password"
|
||||
SHARE = "share"
|
||||
DOWNLOAD = "download"
|
||||
UPLOAD = "upload"
|
||||
OTHER = "other"
|
||||
49
domain/auth/__init__.py
Normal file
49
domain/auth/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from .service import (
|
||||
ALGORITHM,
|
||||
AuthService,
|
||||
authenticate_user_db,
|
||||
create_access_token,
|
||||
get_current_active_user,
|
||||
get_current_user,
|
||||
get_password_hash,
|
||||
has_users,
|
||||
register_user,
|
||||
request_password_reset,
|
||||
reset_password_with_token,
|
||||
verify_password,
|
||||
verify_password_reset_token,
|
||||
)
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
TokenData,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ALGORITHM",
|
||||
"AuthService",
|
||||
"authenticate_user_db",
|
||||
"create_access_token",
|
||||
"get_current_active_user",
|
||||
"get_current_user",
|
||||
"get_password_hash",
|
||||
"has_users",
|
||||
"register_user",
|
||||
"request_password_reset",
|
||||
"reset_password_with_token",
|
||||
"verify_password",
|
||||
"verify_password_reset_token",
|
||||
"PasswordResetConfirm",
|
||||
"PasswordResetRequest",
|
||||
"RegisterRequest",
|
||||
"Token",
|
||||
"TokenData",
|
||||
"UpdateMeRequest",
|
||||
"User",
|
||||
"UserInDB",
|
||||
]
|
||||
90
domain/auth/api.py
Normal file
90
domain/auth/api.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from .service import AuthService, get_current_active_user
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", summary="注册用户(首个用户为管理员)")
|
||||
@audit(
|
||||
action=AuditAction.REGISTER,
|
||||
description="注册用户",
|
||||
body_fields=["username", "email", "full_name"],
|
||||
redact_fields=["password"],
|
||||
)
|
||||
async def register(request: Request, data: RegisterRequest):
|
||||
user = await AuthService.register_user(data)
|
||||
return success({"username": user.username}, msg="注册成功")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
@audit(action=AuditAction.LOGIN, description="用户登录", body_fields=["username"], redact_fields=["password"])
|
||||
async def login_for_access_token(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
return await AuthService.login(form_data)
|
||||
|
||||
|
||||
@router.get("/me", summary="获取当前登录用户信息")
|
||||
@audit(action=AuditAction.READ, description="获取当前用户信息")
|
||||
async def get_me(
|
||||
request: Request, current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
profile = AuthService.get_profile(current_user)
|
||||
return success(profile)
|
||||
|
||||
|
||||
@router.put("/me", summary="更新当前登录用户信息")
|
||||
@audit(
|
||||
action=AuditAction.UPDATE,
|
||||
description="更新当前用户信息",
|
||||
body_fields=["email", "full_name"],
|
||||
redact_fields=["old_password", "new_password"],
|
||||
)
|
||||
async def update_me(
|
||||
request: Request,
|
||||
payload: UpdateMeRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
profile = await AuthService.update_me(payload, current_user)
|
||||
return success(profile)
|
||||
|
||||
|
||||
@router.post("/password-reset/request", summary="请求密码重置邮件")
|
||||
@audit(action=AuditAction.RESET_PASSWORD, description="请求密码重置邮件", body_fields=["email"])
|
||||
async def password_reset_request_endpoint(request: Request, payload: PasswordResetRequest):
|
||||
await AuthService.request_password_reset(payload)
|
||||
return success(msg="如果邮箱存在,将发送重置邮件")
|
||||
|
||||
|
||||
@router.get("/password-reset/verify", summary="校验密码重置令牌")
|
||||
@audit(action=AuditAction.RESET_PASSWORD, description="校验密码重置令牌", redact_fields=["token"])
|
||||
async def password_reset_verify(request: Request, token: str):
|
||||
user = await AuthService.verify_password_reset_token(token)
|
||||
return success({"username": user.username, "email": user.email})
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm", summary="使用令牌重置密码")
|
||||
@audit(
|
||||
action=AuditAction.RESET_PASSWORD,
|
||||
description="重置密码",
|
||||
body_fields=["token"],
|
||||
redact_fields=["token", "password"],
|
||||
)
|
||||
async def password_reset_confirm(request: Request, payload: PasswordResetConfirm):
|
||||
await AuthService.reset_password_with_token(payload)
|
||||
return success(msg="密码已重置")
|
||||
415
domain/auth/service.py
Normal file
415
domain/auth/service.py
Normal file
@@ -0,0 +1,415 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from domain.config import ConfigService
|
||||
from models.database import Role, UserAccount, UserRole
|
||||
from .types import (
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RegisterRequest,
|
||||
Token,
|
||||
TokenData,
|
||||
UpdateMeRequest,
|
||||
User,
|
||||
UserInDB,
|
||||
)
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365
|
||||
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 10
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PasswordResetEntry:
|
||||
user_id: int
|
||||
email: str
|
||||
username: str
|
||||
expires_at: datetime
|
||||
used: bool = False
|
||||
|
||||
|
||||
class PasswordResetStore:
|
||||
_tokens: dict[str, PasswordResetEntry] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def _cleanup(cls):
|
||||
now = _now()
|
||||
for token, record in list(cls._tokens.items()):
|
||||
if record.used or record.expires_at < now:
|
||||
cls._tokens.pop(token, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, user: UserAccount) -> str:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user.id:
|
||||
cls._tokens.pop(key, None)
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = _now() + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
|
||||
cls._tokens[token] = PasswordResetEntry(
|
||||
user_id=user.id,
|
||||
email=user.email or "",
|
||||
username=user.username,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
async def get(cls, token: str) -> PasswordResetEntry | None:
|
||||
async with cls._lock:
|
||||
cls._cleanup()
|
||||
record = cls._tokens.get(token)
|
||||
if not record or record.used:
|
||||
return None
|
||||
return record
|
||||
|
||||
@classmethod
|
||||
async def mark_used(cls, token: str) -> None:
|
||||
async with cls._lock:
|
||||
record = cls._tokens.get(token)
|
||||
if record:
|
||||
record.used = True
|
||||
cls._cleanup()
|
||||
|
||||
@classmethod
|
||||
async def invalidate_user(cls, user_id: int, except_token: str | None = None) -> None:
|
||||
async with cls._lock:
|
||||
for key, record in list(cls._tokens.items()):
|
||||
if record.user_id == user_id and key != except_token:
|
||||
cls._tokens.pop(key, None)
|
||||
cls._cleanup()
|
||||
|
||||
|
||||
class AuthService:
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
algorithm = ALGORITHM
|
||||
access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
password_reset_token_expire_minutes = PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
@staticmethod
|
||||
def _to_bytes(value: str) -> bytes:
|
||||
return value.encode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def get_secret_key(cls) -> str:
|
||||
return await ConfigService.get_secret_key("SECRET_KEY", None)
|
||||
|
||||
@classmethod
|
||||
def _normalize_email(cls, email: str | None) -> str:
|
||||
return (email or "").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def verify_password(cls, plain_password: str, hashed_password: str) -> bool:
|
||||
try:
|
||||
return bcrypt.checkpw(cls._to_bytes(plain_password), hashed_password.encode("utf-8"))
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_password_hash(cls, password: str) -> str:
|
||||
encoded = cls._to_bytes(password)
|
||||
if len(encoded) > 72:
|
||||
raise HTTPException(status_code=400, detail="密码过长")
|
||||
return bcrypt.hashpw(encoded, bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def get_user_db(cls, username_or_email: str) -> UserInDB | None:
|
||||
user = await UserAccount.get_or_none(username=username_or_email)
|
||||
if not user:
|
||||
user = await UserAccount.get_or_none(email=username_or_email)
|
||||
if user:
|
||||
return UserInDB(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
disabled=user.disabled,
|
||||
is_admin=user.is_admin,
|
||||
hashed_password=user.hashed_password,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def authenticate_user_db(cls, username_or_email: str, password: str) -> UserInDB | None:
|
||||
user = await cls.get_user_db(username_or_email)
|
||||
if not user:
|
||||
return None
|
||||
if not cls.verify_password(password, user.hashed_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def has_users(cls) -> bool:
|
||||
user_count = await UserAccount.all().count()
|
||||
return user_count > 0
|
||||
|
||||
@classmethod
|
||||
async def register_user(cls, payload: RegisterRequest):
|
||||
has_users = await cls.has_users()
|
||||
normalized_email = cls._normalize_email(payload.email)
|
||||
if not normalized_email:
|
||||
raise HTTPException(status_code=400, detail="邮箱不能为空")
|
||||
|
||||
if has_users:
|
||||
allow_register = str(await ConfigService.get("AUTH_ALLOW_REGISTER", "false") or "").strip().lower()
|
||||
if allow_register not in ("1", "true", "yes", "on"):
|
||||
raise HTTPException(status_code=403, detail="系统未开放注册")
|
||||
|
||||
default_role_id_raw = str(await ConfigService.get("AUTH_DEFAULT_REGISTER_ROLE_ID", "") or "").strip()
|
||||
if not default_role_id_raw:
|
||||
raise HTTPException(status_code=400, detail="未配置默认注册角色")
|
||||
try:
|
||||
default_role_id = int(default_role_id_raw)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="默认注册角色配置错误") from exc
|
||||
|
||||
role = await Role.get_or_none(id=default_role_id)
|
||||
if not role:
|
||||
raise HTTPException(status_code=400, detail="默认注册角色不存在")
|
||||
|
||||
exists = await UserAccount.get_or_none(username=payload.username)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
|
||||
existing_email = await UserAccount.get_or_none(email=normalized_email)
|
||||
if existing_email:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被使用")
|
||||
|
||||
hashed = cls.get_password_hash(payload.password)
|
||||
|
||||
# 第一个用户自动成为超级管理员(不受开放注册开关影响)
|
||||
if not has_users:
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=normalized_email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
is_admin=True,
|
||||
)
|
||||
return user
|
||||
|
||||
# 系统已初始化:按默认角色创建普通用户
|
||||
user = await UserAccount.create(
|
||||
username=payload.username,
|
||||
email=normalized_email,
|
||||
full_name=payload.full_name,
|
||||
hashed_password=hashed,
|
||||
disabled=False,
|
||||
is_admin=False,
|
||||
)
|
||||
await UserRole.create(user_id=user.id, role_id=default_role_id)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def create_access_token(cls, data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if "sub" not in to_encode and "username" in to_encode:
|
||||
to_encode["sub"] = to_encode["username"]
|
||||
expire = _now() + (expires_delta or timedelta(minutes=15))
|
||||
to_encode.update({"exp": expire})
|
||||
secret_key = await cls.get_secret_key()
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=cls.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
@classmethod
|
||||
async def login(cls, form: OAuth2PasswordRequestForm) -> Token:
|
||||
user = await cls.authenticate_user_db(form.username, form.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 更新最后登录时间
|
||||
db_user = await UserAccount.get_or_none(id=user.id)
|
||||
if db_user:
|
||||
db_user.last_login = _now()
|
||||
await db_user.save(update_fields=["last_login"])
|
||||
|
||||
access_token_expires = timedelta(minutes=cls.access_token_expire_minutes)
|
||||
access_token = await cls.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@classmethod
|
||||
def _build_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
email = cls._normalize_email(getattr(user, "email", None))
|
||||
md5_hash = hashlib.md5(email.encode("utf-8")).hexdigest()
|
||||
gravatar_url = f"https://cn.cravatar.com/avatar/{md5_hash}?s=64&d=identicon"
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": getattr(user, "email", None),
|
||||
"full_name": getattr(user, "full_name", None),
|
||||
"gravatar_url": gravatar_url,
|
||||
"is_admin": getattr(user, "is_admin", False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_profile(cls, user: User | UserInDB | UserAccount) -> dict:
|
||||
return cls._build_profile(user)
|
||||
|
||||
@classmethod
|
||||
async def update_me(cls, payload: UpdateMeRequest, current_user: User) -> dict:
|
||||
db_user = await UserAccount.get_or_none(id=current_user.id)
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
if payload.email is not None:
|
||||
exists = (
|
||||
await UserAccount.filter(email=payload.email)
|
||||
.exclude(id=db_user.id)
|
||||
.exists()
|
||||
)
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="邮箱已被占用")
|
||||
db_user.email = payload.email
|
||||
|
||||
if payload.full_name is not None:
|
||||
db_user.full_name = payload.full_name
|
||||
|
||||
if payload.new_password:
|
||||
if not payload.old_password:
|
||||
raise HTTPException(status_code=400, detail="请提供原密码")
|
||||
if not cls.verify_password(payload.old_password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="原密码错误")
|
||||
db_user.hashed_password = cls.get_password_hash(payload.new_password)
|
||||
|
||||
await db_user.save()
|
||||
return cls._build_profile(db_user)
|
||||
|
||||
@classmethod
|
||||
async def request_password_reset(cls, payload: PasswordResetRequest) -> bool:
|
||||
normalized = cls._normalize_email(payload.email)
|
||||
if not normalized:
|
||||
return False
|
||||
user = await UserAccount.get_or_none(email=normalized)
|
||||
if not user or not user.email:
|
||||
return False
|
||||
|
||||
token = await PasswordResetStore.create(user)
|
||||
try:
|
||||
await cls._send_password_reset_email(user, token)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
await PasswordResetStore.mark_used(token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
raise HTTPException(status_code=500, detail="邮件发送失败") from exc
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def verify_password_reset_token(cls, token: str) -> UserAccount:
|
||||
record = await PasswordResetStore.get(token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def reset_password_with_token(cls, payload: PasswordResetConfirm) -> None:
|
||||
record = await PasswordResetStore.get(payload.token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
if record.expires_at < _now():
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
raise HTTPException(status_code=400, detail="重置链接已过期")
|
||||
|
||||
user = await UserAccount.get_or_none(id=record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="重置链接无效")
|
||||
user.hashed_password = cls.get_password_hash(payload.password)
|
||||
await user.save(update_fields=["hashed_password"])
|
||||
await PasswordResetStore.mark_used(payload.token)
|
||||
await PasswordResetStore.invalidate_user(user.id)
|
||||
|
||||
@classmethod
|
||||
async def get_current_user(cls, token: str):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
secret_key = await cls.get_secret_key()
|
||||
payload = jwt.decode(token, secret_key, algorithms=[cls.algorithm])
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
user = await cls.get_user_db(token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
async def get_current_active_user(cls, current_user: User):
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
@classmethod
|
||||
async def _send_password_reset_email(cls, user: UserAccount, token: str) -> None:
|
||||
from domain.email import EmailService
|
||||
|
||||
app_domain = await ConfigService.get("APP_DOMAIN", None)
|
||||
base_url = (app_domain or "http://localhost:5173").rstrip("/")
|
||||
reset_link = f"{base_url}/reset-password?token={token}"
|
||||
await EmailService.enqueue_email(
|
||||
recipients=[user.email],
|
||||
subject="Foxel 密码重置",
|
||||
template="password_reset",
|
||||
context={
|
||||
"username": user.username,
|
||||
"reset_link": reset_link,
|
||||
"expire_minutes": cls.password_reset_token_expire_minutes,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _current_user_dep(token: Annotated[str, Depends(AuthService.oauth2_scheme)]):
|
||||
return await AuthService.get_current_user(token)
|
||||
|
||||
|
||||
async def _current_active_user_dep(
|
||||
current_user: Annotated[User, Depends(_current_user_dep)],
|
||||
):
|
||||
return await AuthService.get_current_active_user(current_user)
|
||||
|
||||
|
||||
# 方便依赖注入与外部使用
|
||||
get_current_user = _current_user_dep
|
||||
get_current_active_user = _current_active_user_dep
|
||||
authenticate_user_db = AuthService.authenticate_user_db
|
||||
create_access_token = AuthService.create_access_token
|
||||
register_user = AuthService.register_user
|
||||
request_password_reset = AuthService.request_password_reset
|
||||
verify_password_reset_token = AuthService.verify_password_reset_token
|
||||
reset_password_with_token = AuthService.reset_password_with_token
|
||||
has_users = AuthService.has_users
|
||||
verify_password = AuthService.verify_password
|
||||
get_password_hash = AuthService.get_password_hash
|
||||
46
domain/auth/types.py
Normal file
46
domain/auth/types.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
disabled: bool | None = None
|
||||
is_admin: bool = False
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: str
|
||||
full_name: str | None = None
|
||||
|
||||
|
||||
class UpdateMeRequest(BaseModel):
|
||||
email: str | None = None
|
||||
full_name: str | None = None
|
||||
old_password: str | None = None
|
||||
new_password: str | None = None
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
token: str
|
||||
password: str
|
||||
7
domain/backup/__init__.py
Normal file
7
domain/backup/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .service import BackupService
|
||||
from .types import BackupData
|
||||
|
||||
__all__ = [
|
||||
"BackupService",
|
||||
"BackupData",
|
||||
]
|
||||
44
domain/backup/api.py
Normal file
44
domain/backup/api.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import BackupService
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/backup",
|
||||
tags=["Backup & Restore"],
|
||||
dependencies=[Depends(get_current_active_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/export", summary="导出全站数据")
|
||||
@audit(action=AuditAction.DOWNLOAD, description="导出备份")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def export_backup(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
sections: list[str] | None = Query(default=None),
|
||||
):
|
||||
data = await BackupService.export_data(sections=sections)
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
headers = {"Content-Disposition": f"attachment; filename=foxel_backup_{timestamp}.json"}
|
||||
return JSONResponse(content=data.model_dump(), headers=headers)
|
||||
|
||||
|
||||
@router.post("/import", summary="导入数据")
|
||||
@audit(action=AuditAction.UPLOAD, description="导入备份")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def import_backup(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
file: UploadFile = File(...),
|
||||
mode: str = Form("replace"),
|
||||
):
|
||||
await BackupService.import_from_bytes(file.filename, await file.read(), mode=mode)
|
||||
return {"message": "数据导入成功。"}
|
||||
339
domain/backup/service.py
Normal file
339
domain/backup/service.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from domain.config import VERSION
|
||||
from .types import BackupData
|
||||
from models.database import (
|
||||
AIDefaultModel,
|
||||
AIModel,
|
||||
AIProvider,
|
||||
AutomationTask,
|
||||
Configuration,
|
||||
Plugin,
|
||||
ShareLink,
|
||||
StorageAdapter,
|
||||
UserAccount,
|
||||
)
|
||||
|
||||
|
||||
class BackupService:
|
||||
ALL_SECTIONS = (
|
||||
"storage_adapters",
|
||||
"user_accounts",
|
||||
"automation_tasks",
|
||||
"share_links",
|
||||
"configurations",
|
||||
"ai_providers",
|
||||
"ai_models",
|
||||
"ai_default_models",
|
||||
"plugins",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def export_data(cls, sections: list[str] | None = None) -> BackupData:
|
||||
sections = cls._normalize_sections(sections)
|
||||
section_set = set(sections)
|
||||
async with in_transaction():
|
||||
adapters = (
|
||||
await StorageAdapter.all().values()
|
||||
if "storage_adapters" in section_set
|
||||
else []
|
||||
)
|
||||
users = (
|
||||
await UserAccount.all().values()
|
||||
if "user_accounts" in section_set
|
||||
else []
|
||||
)
|
||||
tasks = (
|
||||
await AutomationTask.all().values()
|
||||
if "automation_tasks" in section_set
|
||||
else []
|
||||
)
|
||||
shares = (
|
||||
await ShareLink.all().values()
|
||||
if "share_links" in section_set
|
||||
else []
|
||||
)
|
||||
configs = (
|
||||
await Configuration.all().values()
|
||||
if "configurations" in section_set
|
||||
else []
|
||||
)
|
||||
providers = (
|
||||
await AIProvider.all().values()
|
||||
if "ai_providers" in section_set
|
||||
else []
|
||||
)
|
||||
models = (
|
||||
await AIModel.all().values() if "ai_models" in section_set else []
|
||||
)
|
||||
default_models = (
|
||||
await AIDefaultModel.all().values()
|
||||
if "ai_default_models" in section_set
|
||||
else []
|
||||
)
|
||||
plugins = (
|
||||
await Plugin.all().values() if "plugins" in section_set else []
|
||||
)
|
||||
|
||||
share_links = cls._serialize_datetime_fields(
|
||||
shares, ["created_at", "expires_at"]
|
||||
)
|
||||
ai_providers = cls._serialize_datetime_fields(
|
||||
providers, ["created_at", "updated_at"]
|
||||
)
|
||||
ai_models = cls._serialize_datetime_fields(
|
||||
models, ["created_at", "updated_at"]
|
||||
)
|
||||
ai_default_models = cls._serialize_datetime_fields(
|
||||
default_models, ["created_at", "updated_at"]
|
||||
)
|
||||
plugin_items = cls._serialize_datetime_fields(
|
||||
plugins, ["created_at", "updated_at"]
|
||||
)
|
||||
|
||||
return BackupData(
|
||||
version=VERSION,
|
||||
sections=sections,
|
||||
storage_adapters=list(adapters),
|
||||
user_accounts=list(users),
|
||||
automation_tasks=list(tasks),
|
||||
share_links=share_links,
|
||||
configurations=list(configs),
|
||||
ai_providers=ai_providers,
|
||||
ai_models=ai_models,
|
||||
ai_default_models=ai_default_models,
|
||||
plugins=plugin_items,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def import_from_bytes(
|
||||
cls, filename: str, content: bytes, mode: str = "replace"
|
||||
) -> None:
|
||||
if not filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="无效的文件类型, 请上传 .json 文件")
|
||||
try:
|
||||
raw_data = json.loads(content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="无法解析JSON文件")
|
||||
await cls.import_data(BackupData(**raw_data), mode=mode)
|
||||
|
||||
@classmethod
|
||||
async def import_data(cls, payload: BackupData, mode: str = "replace") -> None:
|
||||
sections = cls._normalize_sections(payload.sections)
|
||||
if mode not in {"replace", "merge"}:
|
||||
raise HTTPException(status_code=400, detail="无效的导入模式")
|
||||
|
||||
share_links = (
|
||||
cls._parse_datetime_fields(payload.share_links, ["created_at", "expires_at"])
|
||||
if payload.share_links
|
||||
else []
|
||||
)
|
||||
ai_providers = (
|
||||
cls._parse_datetime_fields(payload.ai_providers, ["created_at", "updated_at"])
|
||||
if payload.ai_providers
|
||||
else []
|
||||
)
|
||||
ai_models = (
|
||||
cls._parse_datetime_fields(payload.ai_models, ["created_at", "updated_at"])
|
||||
if payload.ai_models
|
||||
else []
|
||||
)
|
||||
ai_default_models = (
|
||||
cls._parse_datetime_fields(
|
||||
payload.ai_default_models, ["created_at", "updated_at"]
|
||||
)
|
||||
if payload.ai_default_models
|
||||
else []
|
||||
)
|
||||
plugins = (
|
||||
cls._parse_datetime_fields(payload.plugins, ["created_at", "updated_at"])
|
||||
if payload.plugins
|
||||
else []
|
||||
)
|
||||
|
||||
async with in_transaction() as conn:
|
||||
if mode == "replace":
|
||||
if "share_links" in sections:
|
||||
await ShareLink.all().using_db(conn).delete()
|
||||
if "automation_tasks" in sections:
|
||||
await AutomationTask.all().using_db(conn).delete()
|
||||
if "storage_adapters" in sections:
|
||||
await StorageAdapter.all().using_db(conn).delete()
|
||||
if "user_accounts" in sections:
|
||||
await UserAccount.all().using_db(conn).delete()
|
||||
if "configurations" in sections:
|
||||
await Configuration.all().using_db(conn).delete()
|
||||
if "ai_default_models" in sections:
|
||||
await AIDefaultModel.all().using_db(conn).delete()
|
||||
if "ai_models" in sections:
|
||||
await AIModel.all().using_db(conn).delete()
|
||||
if "ai_providers" in sections:
|
||||
await AIProvider.all().using_db(conn).delete()
|
||||
if "plugins" in sections:
|
||||
await Plugin.all().using_db(conn).delete()
|
||||
|
||||
if "configurations" in sections and payload.configurations:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
Configuration, payload.configurations, conn
|
||||
)
|
||||
else:
|
||||
await Configuration.bulk_create(
|
||||
[Configuration(**config) for config in payload.configurations],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "user_accounts" in sections and payload.user_accounts:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(UserAccount, payload.user_accounts, conn)
|
||||
else:
|
||||
await UserAccount.bulk_create(
|
||||
[UserAccount(**user) for user in payload.user_accounts],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "storage_adapters" in sections and payload.storage_adapters:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
StorageAdapter, payload.storage_adapters, conn
|
||||
)
|
||||
else:
|
||||
await StorageAdapter.bulk_create(
|
||||
[StorageAdapter(**adapter) for adapter in payload.storage_adapters],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "automation_tasks" in sections and payload.automation_tasks:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
AutomationTask, payload.automation_tasks, conn
|
||||
)
|
||||
else:
|
||||
await AutomationTask.bulk_create(
|
||||
[AutomationTask(**task) for task in payload.automation_tasks],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "share_links" in sections and share_links:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(ShareLink, share_links, conn)
|
||||
else:
|
||||
await ShareLink.bulk_create(
|
||||
[ShareLink(**share) for share in share_links],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "ai_providers" in sections and ai_providers:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(AIProvider, ai_providers, conn)
|
||||
else:
|
||||
await AIProvider.bulk_create(
|
||||
[AIProvider(**item) for item in ai_providers],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "ai_models" in sections and ai_models:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(AIModel, ai_models, conn)
|
||||
else:
|
||||
await AIModel.bulk_create(
|
||||
[AIModel(**item) for item in ai_models],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "ai_default_models" in sections and ai_default_models:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(
|
||||
AIDefaultModel, ai_default_models, conn
|
||||
)
|
||||
else:
|
||||
await AIDefaultModel.bulk_create(
|
||||
[AIDefaultModel(**item) for item in ai_default_models],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
if "plugins" in sections and plugins:
|
||||
if mode == "merge":
|
||||
await cls._merge_records(Plugin, plugins, conn)
|
||||
else:
|
||||
await Plugin.bulk_create(
|
||||
[Plugin(**item) for item in plugins],
|
||||
using_db=conn,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_sections(cls, sections: list[str] | None) -> list[str]:
|
||||
if not sections:
|
||||
return list(cls.ALL_SECTIONS)
|
||||
normalized = [item for item in sections if item]
|
||||
invalid = [item for item in normalized if item not in cls.ALL_SECTIONS]
|
||||
if invalid:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"无效的备份分区: {', '.join(invalid)}"
|
||||
)
|
||||
result: list[str] = []
|
||||
seen = set()
|
||||
for item in normalized:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def _merge_records(model, records: list[dict], using_db) -> None:
|
||||
for record in records:
|
||||
data = dict(record)
|
||||
record_id = data.pop("id", None)
|
||||
if record_id is None:
|
||||
await model.create(using_db=using_db, **data)
|
||||
continue
|
||||
updated = (
|
||||
await model.filter(id=record_id)
|
||||
.using_db(using_db)
|
||||
.update(**data)
|
||||
)
|
||||
if updated == 0:
|
||||
await model.create(using_db=using_db, id=record_id, **data)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_datetime_fields(
|
||||
records: list[dict], fields: list[str]
|
||||
) -> list[dict]:
|
||||
serialized: list[dict] = []
|
||||
for record in records:
|
||||
item = dict(record)
|
||||
for field in fields:
|
||||
value = item.get(field)
|
||||
if isinstance(value, datetime):
|
||||
item[field] = value.isoformat()
|
||||
serialized.append(item)
|
||||
return serialized
|
||||
|
||||
@staticmethod
|
||||
def _parse_datetime_fields(
|
||||
records: list[dict], fields: list[str]
|
||||
) -> list[dict]:
|
||||
parsed: list[dict] = []
|
||||
for record in records:
|
||||
item = dict(record)
|
||||
for field in fields:
|
||||
value = item.get(field)
|
||||
if isinstance(value, str):
|
||||
item[field] = BackupService._from_iso(value)
|
||||
parsed.append(item)
|
||||
return parsed
|
||||
|
||||
@staticmethod
|
||||
def _from_iso(value: str) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
try:
|
||||
return datetime.fromisoformat(normalized)
|
||||
except ValueError as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail="无效的日期格式") from exc
|
||||
17
domain/backup/types.py
Normal file
17
domain/backup/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BackupData(BaseModel):
|
||||
version: str | None = None
|
||||
sections: list[str] = Field(default_factory=list)
|
||||
storage_adapters: list[dict[str, Any]] = Field(default_factory=list)
|
||||
user_accounts: list[dict[str, Any]] = Field(default_factory=list)
|
||||
automation_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
share_links: list[dict[str, Any]] = Field(default_factory=list)
|
||||
configurations: list[dict[str, Any]] = Field(default_factory=list)
|
||||
ai_providers: list[dict[str, Any]] = Field(default_factory=list)
|
||||
ai_models: list[dict[str, Any]] = Field(default_factory=list)
|
||||
ai_default_models: list[dict[str, Any]] = Field(default_factory=list)
|
||||
plugins: list[dict[str, Any]] = Field(default_factory=list)
|
||||
10
domain/config/__init__.py
Normal file
10
domain/config/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import ConfigService, VERSION
|
||||
from .types import ConfigItem, LatestVersionInfo, SystemStatus
|
||||
|
||||
__all__ = [
|
||||
"ConfigService",
|
||||
"VERSION",
|
||||
"ConfigItem",
|
||||
"LatestVersionInfo",
|
||||
"SystemStatus",
|
||||
]
|
||||
83
domain/config/api.py
Normal file
83
domain/config/api.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import ConfigService
|
||||
from .types import ConfigItem
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
PUBLIC_CONFIG_KEYS = [
|
||||
"THEME_MODE",
|
||||
"THEME_PRIMARY_COLOR",
|
||||
"THEME_BORDER_RADIUS",
|
||||
"THEME_CUSTOM_TOKENS",
|
||||
"THEME_CUSTOM_CSS",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取配置")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def get_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str,
|
||||
):
|
||||
value = await ConfigService.get(key)
|
||||
return success(ConfigItem(key=key, value=value).model_dump())
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(action=AuditAction.UPDATE, description="设置配置", body_fields=["key", "value"])
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def set_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
key: str = Form(...),
|
||||
value: str = Form(""),
|
||||
):
|
||||
await ConfigService.set(key, value)
|
||||
return success(ConfigItem(key=key, value=value).model_dump())
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
@audit(action=AuditAction.READ, description="获取全部配置")
|
||||
@require_system_permission(SystemPermission.CONFIG_EDIT)
|
||||
async def get_all_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
configs = await ConfigService.get_all()
|
||||
return success(configs)
|
||||
|
||||
@router.get("/public")
|
||||
@audit(action=AuditAction.READ, description="获取公开配置")
|
||||
async def get_public_config(
|
||||
request: Request,
|
||||
):
|
||||
data = {}
|
||||
for key in PUBLIC_CONFIG_KEYS:
|
||||
value = await ConfigService.get(key)
|
||||
if value is not None:
|
||||
data[key] = value
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@audit(action=AuditAction.READ, description="获取系统状态")
|
||||
async def get_system_status(request: Request):
|
||||
status_data = await ConfigService.get_system_status()
|
||||
return success(status_data.model_dump())
|
||||
|
||||
|
||||
@router.get("/latest-version")
|
||||
@audit(action=AuditAction.READ, description="获取最新版本")
|
||||
async def get_latest_version(request: Request):
|
||||
info = await ConfigService.get_latest_version()
|
||||
return success(info.model_dump())
|
||||
111
domain/config/service.py
Normal file
111
domain/config/service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .types import LatestVersionInfo, SystemStatus
|
||||
from models.database import Configuration, UserAccount
|
||||
|
||||
load_dotenv(dotenv_path=".env")
|
||||
|
||||
VERSION = "v1.7.4"
|
||||
|
||||
|
||||
class ConfigService:
|
||||
_cache: Dict[str, Any] = {}
|
||||
_latest_version_cache: Dict[str, Any] = {"timestamp": 0.0, "data": None}
|
||||
|
||||
@classmethod
|
||||
async def get(cls, key: str, default: Optional[Any] = None) -> Any:
|
||||
if key in cls._cache:
|
||||
return cls._cache[key]
|
||||
try:
|
||||
config = await Configuration.get_or_none(key=key)
|
||||
if config:
|
||||
cls._cache[key] = config.value
|
||||
return config.value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
env_value = os.getenv(key)
|
||||
if env_value is not None:
|
||||
cls._cache[key] = env_value
|
||||
return env_value
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
async def get_secret_key(cls, key: str, default: Optional[Any] = None) -> bytes:
|
||||
value = await cls.get(key, default)
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.encode("utf-8")
|
||||
if value is None:
|
||||
raise ValueError(f"Secret key '{key}' not found in config or environment.")
|
||||
return str(value).encode("utf-8")
|
||||
|
||||
@classmethod
|
||||
async def set(cls, key: str, value: Any):
|
||||
obj, _ = await Configuration.get_or_create(key=key, defaults={"value": value})
|
||||
obj.value = value
|
||||
await obj.save()
|
||||
cls._cache[key] = value
|
||||
|
||||
@classmethod
|
||||
async def get_all(cls) -> Dict[str, Any]:
|
||||
try:
|
||||
configs = await Configuration.all()
|
||||
result = {}
|
||||
for config in configs:
|
||||
result[config.key] = config.value
|
||||
cls._cache[config.key] = config.value
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache.clear()
|
||||
|
||||
@classmethod
|
||||
async def get_system_status(cls) -> SystemStatus:
|
||||
logo = await cls.get("APP_LOGO", "/logo.svg")
|
||||
favicon = await cls.get("APP_FAVICON", logo)
|
||||
user_count = await UserAccount.all().count()
|
||||
return SystemStatus(
|
||||
version=VERSION,
|
||||
title=await cls.get("APP_NAME", "Foxel"),
|
||||
logo=logo,
|
||||
favicon=favicon,
|
||||
is_initialized=user_count > 0,
|
||||
app_domain=await cls.get("APP_DOMAIN"),
|
||||
file_domain=await cls.get("FILE_DOMAIN"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_latest_version(cls) -> LatestVersionInfo:
|
||||
current_time = time.time()
|
||||
cache = cls._latest_version_cache
|
||||
if current_time - cache["timestamp"] < 3600 and cache["data"]:
|
||||
return cache["data"]
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.github.com/repos/DrizzleTime/Foxel/releases/latest",
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
version_info = LatestVersionInfo(
|
||||
latest_version=data.get("tag_name"),
|
||||
body=data.get("body"),
|
||||
)
|
||||
cache["timestamp"] = current_time
|
||||
cache["data"] = version_info
|
||||
return version_info
|
||||
except httpx.RequestError:
|
||||
if cache["data"]:
|
||||
return cache["data"]
|
||||
return LatestVersionInfo()
|
||||
23
domain/config/types.py
Normal file
23
domain/config/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfigItem(BaseModel):
|
||||
key: str
|
||||
value: Optional[Any] = None
|
||||
|
||||
|
||||
class SystemStatus(BaseModel):
|
||||
version: str
|
||||
title: str
|
||||
logo: str
|
||||
favicon: str
|
||||
is_initialized: bool
|
||||
app_domain: Optional[str] = None
|
||||
file_domain: Optional[str] = None
|
||||
|
||||
|
||||
class LatestVersionInfo(BaseModel):
|
||||
latest_version: Optional[str] = None
|
||||
body: Optional[str] = None
|
||||
20
domain/email/__init__.py
Normal file
20
domain/email/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .service import EmailService, EmailTemplateRenderer
|
||||
from .types import (
|
||||
EmailConfig,
|
||||
EmailSecurity,
|
||||
EmailSendPayload,
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmailService",
|
||||
"EmailTemplateRenderer",
|
||||
"EmailConfig",
|
||||
"EmailSecurity",
|
||||
"EmailSendPayload",
|
||||
"EmailTemplatePreviewPayload",
|
||||
"EmailTemplateUpdate",
|
||||
"EmailTestRequest",
|
||||
]
|
||||
91
domain/email/api.py
Normal file
91
domain/email/api.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from .service import EmailService, EmailTemplateRenderer
|
||||
from .types import (
|
||||
EmailTemplatePreviewPayload,
|
||||
EmailTemplateUpdate,
|
||||
EmailTestRequest,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/email", tags=["email"])
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
@audit(action=AuditAction.CREATE, description="发送测试邮件")
|
||||
async def trigger_test_email(
|
||||
request: Request,
|
||||
payload: EmailTestRequest,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
task = await EmailService.enqueue_email(
|
||||
recipients=[str(payload.to)],
|
||||
subject=payload.subject,
|
||||
template=payload.template,
|
||||
context=payload.context,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
return success({"task_id": task.id})
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
@audit(action=AuditAction.READ, description="获取邮件模板列表")
|
||||
async def list_email_templates(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
templates = await EmailTemplateRenderer.list_templates()
|
||||
return success({"templates": templates})
|
||||
|
||||
|
||||
@router.get("/templates/{name}")
|
||||
@audit(action=AuditAction.READ, description="查看邮件模板")
|
||||
async def get_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
content = await EmailTemplateRenderer.load(name)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
return success({"name": name, "content": content})
|
||||
|
||||
|
||||
@router.post("/templates/{name}")
|
||||
@audit(action=AuditAction.UPDATE, description="更新邮件模板")
|
||||
async def update_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
payload: EmailTemplateUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
await EmailTemplateRenderer.save(name, payload.content)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
return success({"name": name})
|
||||
|
||||
|
||||
@router.post("/templates/{name}/preview")
|
||||
@audit(action=AuditAction.READ, description="预览邮件模板")
|
||||
async def preview_email_template(
|
||||
request: Request,
|
||||
name: str,
|
||||
payload: EmailTemplatePreviewPayload,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
html = await EmailTemplateRenderer.render(name, payload.context)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
return success({"html": html})
|
||||
151
domain/email/service.py
Normal file
151
domain/email/service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import asyncio
|
||||
import re
|
||||
import smtplib
|
||||
from email.message import EmailMessage
|
||||
from email.utils import formataddr
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from domain.config import ConfigService
|
||||
from .types import EmailConfig, EmailSecurity, EmailSendPayload
|
||||
|
||||
|
||||
class EmailTemplateRenderer:
|
||||
ROOT = Path("templates/email")
|
||||
|
||||
@classmethod
|
||||
def _resolve_path(cls, template_name: str) -> Path:
|
||||
if not re.fullmatch(r"[A-Za-z0-9_\-]+", template_name):
|
||||
raise ValueError("Invalid template name")
|
||||
return cls.ROOT / f"{template_name}.html"
|
||||
|
||||
@classmethod
|
||||
async def list_templates(cls) -> list[str]:
|
||||
cls.ROOT.mkdir(parents=True, exist_ok=True)
|
||||
return sorted(
|
||||
path.stem for path in cls.ROOT.glob("*.html") if path.is_file()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def load(cls, template_name: str) -> str:
|
||||
path = cls._resolve_path(template_name)
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"Email template '{template_name}' not found")
|
||||
return await asyncio.to_thread(path.read_text, encoding="utf-8")
|
||||
|
||||
@classmethod
|
||||
async def save(cls, template_name: str, content: str) -> None:
|
||||
path = cls._resolve_path(template_name)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
await asyncio.to_thread(path.write_text, content, encoding="utf-8")
|
||||
|
||||
@classmethod
|
||||
async def render(cls, template_name: str, context: Dict[str, Any]) -> str:
|
||||
raw = await cls.load(template_name)
|
||||
context = {k: str(v) for k, v in (context or {}).items()}
|
||||
return Template(raw).safe_substitute(context)
|
||||
|
||||
|
||||
class EmailService:
|
||||
CONFIG_KEY = "EMAIL_CONFIG"
|
||||
|
||||
@classmethod
|
||||
async def _load_config(cls) -> EmailConfig:
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
return EmailConfig.parse_config(raw_config)
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(html: str) -> str:
|
||||
stripped = re.sub(r"<[^>]+>", " ", html)
|
||||
return " ".join(stripped.split())
|
||||
|
||||
@classmethod
|
||||
async def _deliver(cls, config: EmailConfig, payload: EmailSendPayload, html_body: str):
|
||||
message = EmailMessage()
|
||||
message["Subject"] = payload.subject
|
||||
message["From"] = formataddr(
|
||||
(config.sender_name or str(config.sender_email), str(config.sender_email))
|
||||
)
|
||||
message["To"] = ", ".join([str(addr) for addr in payload.recipients])
|
||||
|
||||
plain_body = cls._html_to_text(html_body)
|
||||
message.set_content(plain_body or html_body)
|
||||
message.add_alternative(html_body, subtype="html")
|
||||
|
||||
await asyncio.to_thread(cls._deliver_sync, config, message)
|
||||
|
||||
@staticmethod
|
||||
def _deliver_sync(config: EmailConfig, message: EmailMessage):
|
||||
if config.security == EmailSecurity.SSL:
|
||||
smtp: smtplib.SMTP = smtplib.SMTP_SSL(
|
||||
config.host, config.port, timeout=config.timeout
|
||||
)
|
||||
else:
|
||||
smtp = smtplib.SMTP(config.host, config.port, timeout=config.timeout)
|
||||
|
||||
try:
|
||||
if config.security == EmailSecurity.STARTTLS:
|
||||
smtp.starttls()
|
||||
if config.username and config.password:
|
||||
smtp.login(config.username, config.password)
|
||||
smtp.send_message(message)
|
||||
finally:
|
||||
try:
|
||||
smtp.quit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def enqueue_email(
|
||||
cls,
|
||||
recipients: List[str],
|
||||
subject: str,
|
||||
template: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
from domain.tasks import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(
|
||||
recipients=recipients,
|
||||
subject=subject,
|
||||
template=template,
|
||||
context=context or {},
|
||||
)
|
||||
|
||||
task = await task_queue_service.add_task(
|
||||
"send_email",
|
||||
payload.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(stage="queued", percent=0.0, detail="Waiting to send"),
|
||||
)
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
async def send_from_task(cls, task_id: str, data: Dict[str, Any]):
|
||||
from domain.tasks import TaskProgress, task_queue_service
|
||||
|
||||
payload = EmailSendPayload(**data)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task_id,
|
||||
TaskProgress(stage="preparing", percent=10.0, detail="Rendering template"),
|
||||
)
|
||||
|
||||
config = await cls._load_config()
|
||||
html_body = await EmailTemplateRenderer.render(payload.template, payload.context)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task_id,
|
||||
TaskProgress(stage="sending", percent=60.0, detail="Sending message"),
|
||||
)
|
||||
|
||||
await cls._deliver(config, payload, html_body)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task_id,
|
||||
TaskProgress(stage="completed", percent=100.0, detail="Email sent"),
|
||||
)
|
||||
63
domain/email/types.py
Normal file
63
domain/email/types.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, ValidationError
|
||||
|
||||
|
||||
class EmailSecurity(str, Enum):
|
||||
NONE = "none"
|
||||
SSL = "ssl"
|
||||
STARTTLS = "starttls"
|
||||
|
||||
|
||||
class EmailConfig(BaseModel):
|
||||
host: str
|
||||
port: int = Field(..., gt=0)
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
sender_email: EmailStr
|
||||
sender_name: Optional[str] = None
|
||||
security: EmailSecurity = EmailSecurity.NONE
|
||||
timeout: float = Field(default=30.0, gt=0.0)
|
||||
|
||||
@classmethod
|
||||
def parse_config(cls, raw_config: Any) -> "EmailConfig":
|
||||
"""接受字符串或 dict 配置并解析为 EmailConfig。"""
|
||||
if raw_config is None:
|
||||
raise ValueError("Email configuration not found")
|
||||
|
||||
if isinstance(raw_config, str):
|
||||
raw_config = raw_config.strip()
|
||||
data: Any = json.loads(raw_config) if raw_config else {}
|
||||
elif isinstance(raw_config, dict):
|
||||
data = raw_config
|
||||
else:
|
||||
raise ValueError("Invalid email configuration format")
|
||||
|
||||
try:
|
||||
return cls(**data)
|
||||
except ValidationError as exc:
|
||||
raise ValueError(f"Invalid email configuration: {exc}") from exc
|
||||
|
||||
|
||||
class EmailSendPayload(BaseModel):
|
||||
recipients: List[EmailStr] = Field(..., min_length=1)
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(..., min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTestRequest(BaseModel):
|
||||
to: EmailStr
|
||||
subject: str = Field(..., min_length=1)
|
||||
template: str = Field(default="test", min_length=1)
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmailTemplateUpdate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class EmailTemplatePreviewPayload(BaseModel):
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
7
domain/offline_downloads/__init__.py
Normal file
7
domain/offline_downloads/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .service import OfflineDownloadService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
__all__ = [
|
||||
"OfflineDownloadService",
|
||||
"OfflineDownloadCreate",
|
||||
]
|
||||
44
domain/offline_downloads/api.py
Normal file
44
domain/offline_downloads/api.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from api.response import success
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_path_permission
|
||||
from domain.permission.types import PathAction
|
||||
from .service import OfflineDownloadService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_active_user)]
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/offline-downloads",
|
||||
tags=["OfflineDownloads"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@audit(
|
||||
action=AuditAction.CREATE,
|
||||
description="创建离线下载任务",
|
||||
body_fields=["url", "dest_dir", "filename"],
|
||||
)
|
||||
@require_path_permission(PathAction.WRITE, "payload.dest_dir")
|
||||
async def create_offline_download(request: Request, payload: OfflineDownloadCreate, current_user: CurrentUser):
|
||||
data = await OfflineDownloadService.create_download(payload, current_user)
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@audit(action=AuditAction.READ, description="获取离线下载列表")
|
||||
async def list_offline_downloads(request: Request, current_user: CurrentUser):
|
||||
data = OfflineDownloadService.list_downloads()
|
||||
return success(data)
|
||||
|
||||
|
||||
@router.get("/{task_id}")
|
||||
@audit(action=AuditAction.READ, description="获取离线下载详情")
|
||||
async def get_offline_download(task_id: str, request: Request, current_user: CurrentUser):
|
||||
data = OfflineDownloadService.get_download(task_id)
|
||||
return success(data)
|
||||
251
domain/offline_downloads/service.py
Normal file
251
domain/offline_downloads/service.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Annotated, AsyncIterator
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from fastapi import Depends, HTTPException
|
||||
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.tasks import Task, TaskProgress, task_queue_service
|
||||
from domain.virtual_fs import VirtualFSService
|
||||
from .types import OfflineDownloadCreate
|
||||
|
||||
|
||||
class OfflineDownloadService:
|
||||
current_user_dep = Annotated[User, Depends(get_current_active_user)]
|
||||
temp_root = Path("data/tmp/offline_downloads")
|
||||
|
||||
@classmethod
|
||||
async def create_download(cls, payload: OfflineDownloadCreate, current_user: User) -> dict:
|
||||
await cls._ensure_destination(payload.dest_dir)
|
||||
task = await task_queue_service.add_task(
|
||||
"offline_http_download",
|
||||
{
|
||||
"url": str(payload.url),
|
||||
"dest_dir": payload.dest_dir,
|
||||
"filename": payload.filename,
|
||||
},
|
||||
)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="queued",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="Waiting to start",
|
||||
),
|
||||
)
|
||||
|
||||
return {"task_id": task.id}
|
||||
|
||||
@classmethod
|
||||
def list_downloads(cls) -> list[dict]:
|
||||
tasks = [t for t in task_queue_service.get_all_tasks() if t.name == "offline_http_download"]
|
||||
return [t.dict() for t in tasks]
|
||||
|
||||
@classmethod
|
||||
def get_download(cls, task_id: str) -> dict:
|
||||
task = task_queue_service.get_task(task_id)
|
||||
if not task or task.name != "offline_http_download":
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task.dict()
|
||||
|
||||
@classmethod
|
||||
async def run_http_download(cls, task: Task):
|
||||
params = task.task_info
|
||||
url = params.get("url")
|
||||
dest_dir = params.get("dest_dir")
|
||||
filename = params.get("filename")
|
||||
|
||||
if not url or not dest_dir or not filename:
|
||||
raise ValueError("Missing required parameters for offline download")
|
||||
|
||||
cls.temp_root.mkdir(parents=True, exist_ok=True)
|
||||
temp_dir = cls.temp_root / task.id
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_file = temp_dir / "payload"
|
||||
|
||||
bytes_total: int | None = None
|
||||
bytes_done = 0
|
||||
last_update = time.monotonic()
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="downloading",
|
||||
percent=0.0,
|
||||
bytes_total=None,
|
||||
bytes_done=0,
|
||||
detail="HTTP downloading",
|
||||
),
|
||||
)
|
||||
|
||||
async def report_download(delta: int, total: int | None):
|
||||
nonlocal bytes_done, bytes_total, last_update
|
||||
if total is not None:
|
||||
bytes_total = total
|
||||
bytes_done += delta
|
||||
now = time.monotonic()
|
||||
if delta and now - last_update < 0.5:
|
||||
return
|
||||
last_update = now
|
||||
percent = None
|
||||
total_for_display = bytes_total if bytes_total is not None else None
|
||||
if bytes_total:
|
||||
percent = min(100.0, round(bytes_done / bytes_total * 100, 2))
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="downloading",
|
||||
percent=percent,
|
||||
bytes_total=total_for_display,
|
||||
bytes_done=bytes_done,
|
||||
detail="HTTP downloading",
|
||||
),
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=30)
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError(f"HTTP {resp.status} for {url}")
|
||||
content_length = resp.headers.get("Content-Length")
|
||||
total_size = int(content_length) if content_length else None
|
||||
bytes_done = 0
|
||||
async with aiofiles.open(temp_file, "wb") as f:
|
||||
async for chunk in resp.content.iter_chunked(512 * 1024):
|
||||
if not chunk:
|
||||
continue
|
||||
await f.write(chunk)
|
||||
await report_download(len(chunk), total_size)
|
||||
await report_download(0, total_size)
|
||||
|
||||
file_size = os.path.getsize(temp_file)
|
||||
bytes_done_transfer = 0
|
||||
|
||||
async def report_transfer(delta: int):
|
||||
nonlocal bytes_done_transfer
|
||||
bytes_done_transfer += delta
|
||||
percent = min(100.0, round(bytes_done_transfer / file_size * 100, 2)) if file_size else None
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="transferring",
|
||||
percent=percent,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=bytes_done_transfer,
|
||||
detail="Saving to storage",
|
||||
),
|
||||
)
|
||||
|
||||
async def chunk_iter() -> AsyncIterator[bytes]:
|
||||
async for chunk in cls._iter_file(temp_file, 512 * 1024, report_transfer):
|
||||
yield chunk
|
||||
|
||||
final_path, resolved_name = await cls._allocate_destination(dest_dir, filename)
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="transferring",
|
||||
percent=0.0,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=0,
|
||||
detail="Saving to storage",
|
||||
),
|
||||
)
|
||||
|
||||
await VirtualFSService.write_file_stream(final_path, chunk_iter())
|
||||
|
||||
await task_queue_service.update_progress(
|
||||
task.id,
|
||||
TaskProgress(
|
||||
stage="completed",
|
||||
percent=100.0,
|
||||
bytes_total=file_size or None,
|
||||
bytes_done=file_size,
|
||||
detail="Completed",
|
||||
),
|
||||
)
|
||||
await task_queue_service.update_meta(task.id, {"final_path": final_path, "filename": resolved_name})
|
||||
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
temp_dir.rmdir()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return final_path
|
||||
|
||||
@classmethod
|
||||
async def _ensure_destination(cls, dest_dir: str) -> None:
|
||||
try:
|
||||
is_dir = await VirtualFSService.path_is_directory(dest_dir)
|
||||
except HTTPException:
|
||||
is_dir = False
|
||||
if not is_dir:
|
||||
raise HTTPException(400, detail="Destination directory not found")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_path(path: str) -> str:
|
||||
if not path:
|
||||
return "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
@staticmethod
|
||||
async def _path_exists(full_path: str) -> bool:
|
||||
try:
|
||||
await VirtualFSService.stat_file(full_path)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except HTTPException as exc: # noqa: PERF203
|
||||
if exc.status_code == 404:
|
||||
return False
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def _allocate_destination(cls, dest_dir: str, filename: str) -> tuple[str, str]:
|
||||
dest_dir = cls._normalize_path(dest_dir)
|
||||
stem, suffix = cls._split_filename(filename)
|
||||
candidate = filename
|
||||
base = "" if dest_dir == "/" else dest_dir
|
||||
attempt = 0
|
||||
while await cls._path_exists(f"{base}/{candidate}" if base else f"/{candidate}"):
|
||||
attempt += 1
|
||||
if stem:
|
||||
candidate = f"{stem} ({attempt}){suffix}"
|
||||
else:
|
||||
candidate = f"file ({attempt}){suffix}" if suffix else f"file ({attempt})"
|
||||
full_path = f"{base}/{candidate}" if base else f"/{candidate}"
|
||||
return full_path, candidate
|
||||
|
||||
@staticmethod
|
||||
def _split_filename(filename: str) -> tuple[str, str]:
|
||||
if not filename:
|
||||
return "", ""
|
||||
if filename.startswith(".") and filename.count(".") == 1:
|
||||
return filename, ""
|
||||
if "." not in filename:
|
||||
return filename, ""
|
||||
stem, ext = filename.rsplit(".", 1)
|
||||
return stem, f".{ext}"
|
||||
|
||||
@staticmethod
|
||||
async def _iter_file(path: Path, chunk_size: int, report_cb) -> AsyncIterator[bytes]:
|
||||
async with aiofiles.open(path, "rb") as f:
|
||||
while True:
|
||||
chunk = await f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
await report_cb(len(chunk))
|
||||
yield chunk
|
||||
7
domain/offline_downloads/types.py
Normal file
7
domain/offline_downloads/types.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel, HttpUrl, Field
|
||||
|
||||
|
||||
class OfflineDownloadCreate(BaseModel):
|
||||
url: HttpUrl
|
||||
dest_dir: str = Field(..., min_length=1)
|
||||
filename: str = Field(..., min_length=1)
|
||||
10
domain/permission/__init__.py
Normal file
10
domain/permission/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .service import PermissionService
|
||||
from .matcher import PathMatcher
|
||||
from .decorator import require_path_permission, require_system_permission
|
||||
|
||||
__all__ = [
|
||||
"PermissionService",
|
||||
"PathMatcher",
|
||||
"require_system_permission",
|
||||
"require_path_permission",
|
||||
]
|
||||
41
domain/permission/api.py
Normal file
41
domain/permission/api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from domain.auth.service import get_current_active_user
|
||||
from domain.auth.types import User
|
||||
from .service import PermissionService
|
||||
from .types import (
|
||||
PathPermissionCheck,
|
||||
PathPermissionResult,
|
||||
UserPermissions,
|
||||
PermissionInfo,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["permissions"])
|
||||
|
||||
|
||||
@router.get("/permissions", response_model=list[PermissionInfo])
|
||||
async def get_all_permissions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> list[PermissionInfo]:
|
||||
"""获取所有权限定义"""
|
||||
return await PermissionService.get_all_permissions()
|
||||
|
||||
|
||||
@router.get("/me/permissions", response_model=UserPermissions)
|
||||
async def get_my_permissions(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> UserPermissions:
|
||||
"""获取当前用户的有效权限"""
|
||||
return await PermissionService.get_user_permissions(current_user.id)
|
||||
|
||||
|
||||
@router.post("/me/check-path", response_model=PathPermissionResult)
|
||||
async def check_path_permission(
|
||||
data: PathPermissionCheck,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> PathPermissionResult:
|
||||
"""检查当前用户对某路径的权限"""
|
||||
return await PermissionService.check_path_permission_detailed(
|
||||
current_user.id, data.path, data.action
|
||||
)
|
||||
103
domain/permission/decorator.py
Normal file
103
domain/permission/decorator.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .service import PermissionService
|
||||
|
||||
|
||||
def _get_user_id(user: Any) -> int | None:
|
||||
if user is None:
|
||||
return None
|
||||
if isinstance(user, Mapping):
|
||||
raw = user.get("id") or user.get("user_id")
|
||||
return int(raw) if isinstance(raw, int) else None
|
||||
value = getattr(user, "id", None) or getattr(user, "user_id", None)
|
||||
return int(value) if isinstance(value, int) else None
|
||||
|
||||
|
||||
def _resolve_expr(bound_args: Mapping[str, Any], expr: str) -> Any:
|
||||
parts = [p for p in (expr or "").split(".") if p]
|
||||
if not parts:
|
||||
return None
|
||||
cur: Any = bound_args.get(parts[0])
|
||||
for part in parts[1:]:
|
||||
if cur is None:
|
||||
return None
|
||||
if isinstance(cur, Mapping):
|
||||
cur = cur.get(part)
|
||||
else:
|
||||
cur = getattr(cur, part, None)
|
||||
return cur
|
||||
|
||||
|
||||
def require_system_permission(permission_code: str, *, user_kw: str = "current_user"):
|
||||
"""
|
||||
在 endpoint 内部执行系统/适配器权限校验。
|
||||
|
||||
设计目标:
|
||||
- 保持和当前“在函数体内手写 require_*”一致的行为:失败会被外层 @audit 捕获记录
|
||||
- 不依赖 FastAPI dependencies(避免权限失败发生在 endpoint 之外)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
user_id = _get_user_id(bound.arguments.get(user_kw))
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
await PermissionService.require_system_permission(user_id, permission_code)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_path_permission(action: str, path_expr: str, *, user_kw: str = "current_user"):
|
||||
"""
|
||||
在 endpoint 内部执行路径权限校验。
|
||||
|
||||
path_expr 支持:
|
||||
- "full_path"
|
||||
- "body.src" / "body.dst"
|
||||
- "payload.paths"(list[str] 会逐个检查)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
bound = inspect.signature(func).bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
user_id = _get_user_id(bound.arguments.get(user_kw))
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
value = _resolve_expr(bound.arguments, path_expr)
|
||||
paths: Iterable[Any]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
paths = value
|
||||
else:
|
||||
paths = [value]
|
||||
|
||||
for path in paths:
|
||||
if path is None:
|
||||
raise HTTPException(status_code=400, detail="Missing path")
|
||||
await PermissionService.require_path_permission(user_id, str(path), action)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
158
domain/permission/matcher.py
Normal file
158
domain/permission/matcher.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import re
|
||||
import fnmatch
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class PathMatcher:
|
||||
"""路径匹配器,支持精确匹配、通配符匹配和正则匹配"""
|
||||
|
||||
@classmethod
|
||||
def normalize_path(cls, path: str) -> str:
|
||||
"""规范化路径"""
|
||||
if not path:
|
||||
return "/"
|
||||
# 确保以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
# 移除末尾的 /(除了根路径)
|
||||
if path != "/" and path.endswith("/"):
|
||||
path = path.rstrip("/")
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def get_parent_path(cls, path: str) -> str | None:
|
||||
"""获取父目录路径"""
|
||||
path = cls.normalize_path(path)
|
||||
if path == "/":
|
||||
return None
|
||||
parent = "/".join(path.rsplit("/", 1)[:-1])
|
||||
return parent if parent else "/"
|
||||
|
||||
@classmethod
|
||||
def match_pattern(cls, path: str, pattern: str, is_regex: bool = False) -> bool:
|
||||
"""
|
||||
匹配路径和模式
|
||||
|
||||
Args:
|
||||
path: 要匹配的路径
|
||||
pattern: 匹配模式
|
||||
is_regex: 是否为正则表达式
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
path = cls.normalize_path(path)
|
||||
pattern = cls.normalize_path(pattern)
|
||||
|
||||
if is_regex:
|
||||
return cls._match_regex(path, pattern)
|
||||
else:
|
||||
return cls._match_glob(path, pattern)
|
||||
|
||||
@classmethod
|
||||
def _match_regex(cls, path: str, pattern: str) -> bool:
|
||||
"""正则表达式匹配"""
|
||||
try:
|
||||
# 限制正则表达式的复杂度,防止 ReDoS 攻击
|
||||
if len(pattern) > 500:
|
||||
return False
|
||||
regex = re.compile(pattern)
|
||||
return bool(regex.match(path))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _match_glob(cls, path: str, pattern: str) -> bool:
|
||||
"""
|
||||
通配符匹配
|
||||
|
||||
支持的语法:
|
||||
- * : 匹配单层目录中的任意字符
|
||||
- ** : 匹配任意层级目录
|
||||
- ? : 匹配单个字符
|
||||
"""
|
||||
# 精确匹配
|
||||
if pattern == path:
|
||||
return True
|
||||
|
||||
# 处理 ** 通配符
|
||||
if "**" in pattern:
|
||||
return cls._match_double_star(path, pattern)
|
||||
|
||||
# 使用 fnmatch 进行标准通配符匹配
|
||||
return fnmatch.fnmatch(path, pattern)
|
||||
|
||||
@classmethod
|
||||
def _match_double_star(cls, path: str, pattern: str) -> bool:
|
||||
"""处理 ** 通配符匹配"""
|
||||
# 将 ** 替换为特殊标记
|
||||
parts = pattern.split("**")
|
||||
|
||||
if len(parts) == 2:
|
||||
prefix, suffix = parts
|
||||
# 移除 prefix 末尾的 / 和 suffix 开头的 /
|
||||
prefix = prefix.rstrip("/") if prefix else ""
|
||||
suffix = suffix.lstrip("/") if suffix else ""
|
||||
|
||||
# 检查前缀匹配
|
||||
if prefix and not path.startswith(prefix):
|
||||
return False
|
||||
|
||||
# 如果没有后缀,只需要前缀匹配
|
||||
if not suffix:
|
||||
return True
|
||||
|
||||
# 检查后缀匹配
|
||||
remaining = path[len(prefix):].lstrip("/") if prefix else path.lstrip("/")
|
||||
|
||||
# 后缀可以出现在任意位置
|
||||
if "*" in suffix or "?" in suffix:
|
||||
# 后缀包含通配符,逐层检查
|
||||
path_parts = remaining.split("/")
|
||||
suffix_parts = suffix.split("/")
|
||||
|
||||
# 简化处理:检查路径的最后几层是否与后缀匹配
|
||||
if len(path_parts) >= len(suffix_parts):
|
||||
tail = "/".join(path_parts[-len(suffix_parts):])
|
||||
return fnmatch.fnmatch(tail, suffix)
|
||||
return False
|
||||
else:
|
||||
# 后缀是精确字符串
|
||||
return remaining.endswith(suffix) or ("/" + suffix) in remaining or remaining == suffix
|
||||
|
||||
# 多个 ** 的情况,使用简化匹配
|
||||
regex_pattern = pattern.replace("**", ".*").replace("*", "[^/]*").replace("?", ".")
|
||||
try:
|
||||
return bool(re.match(f"^{regex_pattern}$", path))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_pattern_specificity(cls, pattern: str, is_regex: bool = False) -> int:
|
||||
"""
|
||||
计算模式的具体程度(用于优先级排序)
|
||||
|
||||
返回值越大表示模式越具体
|
||||
"""
|
||||
pattern = cls.normalize_path(pattern)
|
||||
|
||||
if is_regex:
|
||||
# 正则表达式具体程度较低
|
||||
return len(pattern) // 2
|
||||
|
||||
# 精确路径最具体
|
||||
if "*" not in pattern and "?" not in pattern:
|
||||
return len(pattern) * 10
|
||||
|
||||
# 计算非通配符部分的长度
|
||||
specificity = 0
|
||||
parts = pattern.split("/")
|
||||
for part in parts:
|
||||
if part == "**":
|
||||
specificity += 1
|
||||
elif "*" in part or "?" in part:
|
||||
specificity += 5
|
||||
else:
|
||||
specificity += 10
|
||||
|
||||
return specificity
|
||||
340
domain/permission/service.py
Normal file
340
domain/permission/service.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.database import (
|
||||
UserAccount,
|
||||
UserRole,
|
||||
RolePermission,
|
||||
PathRule,
|
||||
)
|
||||
from .matcher import PathMatcher
|
||||
from .types import (
|
||||
PathAction,
|
||||
PathRuleInfo,
|
||||
PathPermissionResult,
|
||||
UserPermissions,
|
||||
PermissionInfo,
|
||||
PERMISSION_DEFINITIONS,
|
||||
)
|
||||
|
||||
|
||||
class PermissionService:
|
||||
"""权限检查服务"""
|
||||
|
||||
# 权限检查结果缓存(简单的内存缓存)
|
||||
_cache: dict[str, tuple[bool, float]] = {}
|
||||
_cache_ttl = 300 # 5分钟缓存
|
||||
|
||||
@classmethod
|
||||
async def check_path_permission(
|
||||
cls, user_id: int, path: str, action: str
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户对路径的操作权限
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
path: 要检查的路径
|
||||
action: 操作类型 (read/write/delete/share)
|
||||
|
||||
Returns:
|
||||
是否有权限
|
||||
"""
|
||||
import time
|
||||
|
||||
# 检查缓存
|
||||
cache_key = f"{user_id}:{path}:{action}"
|
||||
if cache_key in cls._cache:
|
||||
result, timestamp = cls._cache[cache_key]
|
||||
if time.time() - timestamp < cls._cache_ttl:
|
||||
return result
|
||||
|
||||
# 获取用户
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# 超级管理员直接放行
|
||||
if user.is_admin:
|
||||
cls._cache[cache_key] = (True, time.time())
|
||||
return True
|
||||
|
||||
# 获取用户所有角色
|
||||
user_roles = await UserRole.filter(user_id=user_id).prefetch_related("role")
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
if not role_ids:
|
||||
cls._cache[cache_key] = (False, time.time())
|
||||
return False
|
||||
|
||||
# 获取所有角色的路径规则
|
||||
path_rules = await PathRule.filter(role_id__in=role_ids).order_by("-priority")
|
||||
|
||||
# 规范化路径
|
||||
normalized_path = PathMatcher.normalize_path(path)
|
||||
|
||||
# 按优先级和具体程度匹配
|
||||
result = cls._match_path_rules(normalized_path, action, list(path_rules))
|
||||
|
||||
# 如果没有匹配到规则,检查父目录(继承)
|
||||
if result is None:
|
||||
parent_path = PathMatcher.get_parent_path(normalized_path)
|
||||
if parent_path:
|
||||
result = await cls.check_path_permission(user_id, parent_path, action)
|
||||
else:
|
||||
result = False # 默认拒绝
|
||||
|
||||
cls._cache[cache_key] = (result, time.time())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _match_path_rules(
|
||||
cls, path: str, action: str, rules: List[PathRule]
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
匹配路径规则
|
||||
|
||||
Returns:
|
||||
True/False 表示明确的权限结果,None 表示没有匹配到规则
|
||||
"""
|
||||
# 按优先级和具体程度排序
|
||||
sorted_rules = sorted(
|
||||
rules,
|
||||
key=lambda r: (
|
||||
r.priority,
|
||||
PathMatcher.get_pattern_specificity(r.path_pattern, r.is_regex),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
for rule in sorted_rules:
|
||||
if PathMatcher.match_pattern(path, rule.path_pattern, rule.is_regex):
|
||||
# 匹配到规则,检查具体操作权限
|
||||
if action == PathAction.READ:
|
||||
return rule.can_read
|
||||
elif action == PathAction.WRITE:
|
||||
return rule.can_write
|
||||
elif action == PathAction.DELETE:
|
||||
return rule.can_delete
|
||||
elif action == PathAction.SHARE:
|
||||
return rule.can_share
|
||||
else:
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def check_system_permission(cls, user_id: int, permission_code: str) -> bool:
|
||||
"""检查用户的系统/适配器权限"""
|
||||
# 获取用户
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# 超级管理员直接放行
|
||||
if user.is_admin:
|
||||
return True
|
||||
|
||||
# 获取用户所有角色
|
||||
user_roles = await UserRole.filter(user_id=user_id)
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
if not role_ids:
|
||||
return False
|
||||
|
||||
role_permission = await RolePermission.filter(
|
||||
role_id__in=role_ids, permission_code=permission_code
|
||||
).first()
|
||||
|
||||
return role_permission is not None
|
||||
|
||||
@classmethod
|
||||
async def require_path_permission(
|
||||
cls, user_id: int, path: str, action: str
|
||||
) -> None:
|
||||
"""要求用户具有路径权限,否则抛出 403"""
|
||||
if not await cls.check_path_permission(user_id, path, action):
|
||||
raise HTTPException(403, detail=f"没有权限执行此操作: {action}")
|
||||
|
||||
@classmethod
|
||||
async def require_system_permission(
|
||||
cls, user_id: int, permission_code: str
|
||||
) -> None:
|
||||
"""要求用户具有系统权限,否则抛出 403"""
|
||||
if not await cls.check_system_permission(user_id, permission_code):
|
||||
raise HTTPException(403, detail=f"没有权限: {permission_code}")
|
||||
|
||||
@classmethod
|
||||
async def get_user_permissions(cls, user_id: int) -> UserPermissions:
|
||||
"""获取用户的所有权限"""
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="用户不存在")
|
||||
|
||||
# 超级管理员拥有所有权限
|
||||
if user.is_admin:
|
||||
all_permission_codes = [item["code"] for item in PERMISSION_DEFINITIONS]
|
||||
all_path_rules = await PathRule.all()
|
||||
return UserPermissions(
|
||||
user_id=user_id,
|
||||
is_admin=True,
|
||||
permissions=all_permission_codes,
|
||||
path_rules=[
|
||||
PathRuleInfo(
|
||||
id=r.id,
|
||||
role_id=r.role_id,
|
||||
path_pattern=r.path_pattern,
|
||||
is_regex=r.is_regex,
|
||||
can_read=r.can_read,
|
||||
can_write=r.can_write,
|
||||
can_delete=r.can_delete,
|
||||
can_share=r.can_share,
|
||||
priority=r.priority,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in all_path_rules
|
||||
],
|
||||
)
|
||||
|
||||
# 获取用户角色
|
||||
user_roles = await UserRole.filter(user_id=user_id)
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
# 获取权限
|
||||
permissions = []
|
||||
if role_ids:
|
||||
role_permissions = await RolePermission.filter(role_id__in=role_ids)
|
||||
permissions = sorted(set(rp.permission_code for rp in role_permissions))
|
||||
|
||||
# 获取路径规则
|
||||
path_rules = []
|
||||
if role_ids:
|
||||
rules = await PathRule.filter(role_id__in=role_ids)
|
||||
path_rules = [
|
||||
PathRuleInfo(
|
||||
id=r.id,
|
||||
role_id=r.role_id,
|
||||
path_pattern=r.path_pattern,
|
||||
is_regex=r.is_regex,
|
||||
can_read=r.can_read,
|
||||
can_write=r.can_write,
|
||||
can_delete=r.can_delete,
|
||||
can_share=r.can_share,
|
||||
priority=r.priority,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in rules
|
||||
]
|
||||
|
||||
return UserPermissions(
|
||||
user_id=user_id,
|
||||
is_admin=False,
|
||||
permissions=permissions,
|
||||
path_rules=path_rules,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_all_permissions(cls) -> List[PermissionInfo]:
|
||||
"""获取所有权限定义"""
|
||||
return [
|
||||
PermissionInfo(
|
||||
code=item["code"],
|
||||
name=item["name"],
|
||||
category=item["category"],
|
||||
description=item.get("description"),
|
||||
)
|
||||
for item in PERMISSION_DEFINITIONS
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def check_path_permission_detailed(
|
||||
cls, user_id: int, path: str, action: str
|
||||
) -> PathPermissionResult:
|
||||
"""检查路径权限并返回详细结果"""
|
||||
user = await UserAccount.get_or_none(id=user_id)
|
||||
if not user:
|
||||
return PathPermissionResult(path=path, action=action, allowed=False)
|
||||
|
||||
# 超级管理员
|
||||
if user.is_admin:
|
||||
return PathPermissionResult(path=path, action=action, allowed=True)
|
||||
|
||||
# 获取用户角色
|
||||
user_roles = await UserRole.filter(user_id=user_id)
|
||||
role_ids = [ur.role_id for ur in user_roles]
|
||||
|
||||
if not role_ids:
|
||||
return PathPermissionResult(path=path, action=action, allowed=False)
|
||||
|
||||
# 获取路径规则
|
||||
path_rules = await PathRule.filter(role_id__in=role_ids).order_by("-priority")
|
||||
normalized_path = PathMatcher.normalize_path(path)
|
||||
|
||||
# 查找匹配的规则
|
||||
matched_rule = None
|
||||
for rule in sorted(
|
||||
path_rules,
|
||||
key=lambda r: (
|
||||
r.priority,
|
||||
PathMatcher.get_pattern_specificity(r.path_pattern, r.is_regex),
|
||||
),
|
||||
reverse=True,
|
||||
):
|
||||
if PathMatcher.match_pattern(
|
||||
normalized_path, rule.path_pattern, rule.is_regex
|
||||
):
|
||||
matched_rule = rule
|
||||
break
|
||||
|
||||
# 检查权限
|
||||
allowed = False
|
||||
if matched_rule:
|
||||
if action == PathAction.READ:
|
||||
allowed = matched_rule.can_read
|
||||
elif action == PathAction.WRITE:
|
||||
allowed = matched_rule.can_write
|
||||
elif action == PathAction.DELETE:
|
||||
allowed = matched_rule.can_delete
|
||||
elif action == PathAction.SHARE:
|
||||
allowed = matched_rule.can_share
|
||||
|
||||
rule_info = None
|
||||
if matched_rule:
|
||||
rule_info = PathRuleInfo(
|
||||
id=matched_rule.id,
|
||||
role_id=matched_rule.role_id,
|
||||
path_pattern=matched_rule.path_pattern,
|
||||
is_regex=matched_rule.is_regex,
|
||||
can_read=matched_rule.can_read,
|
||||
can_write=matched_rule.can_write,
|
||||
can_delete=matched_rule.can_delete,
|
||||
can_share=matched_rule.can_share,
|
||||
priority=matched_rule.priority,
|
||||
created_at=matched_rule.created_at,
|
||||
)
|
||||
|
||||
return PathPermissionResult(
|
||||
path=path, action=action, allowed=allowed, matched_rule=rule_info
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls, user_id: int | None = None) -> None:
|
||||
"""清除权限缓存"""
|
||||
if user_id is None:
|
||||
cls._cache.clear()
|
||||
else:
|
||||
# 清除特定用户的缓存
|
||||
keys_to_delete = [k for k in cls._cache if k.startswith(f"{user_id}:")]
|
||||
for k in keys_to_delete:
|
||||
del cls._cache[k]
|
||||
|
||||
@classmethod
|
||||
async def filter_paths_by_permission(
|
||||
cls, user_id: int, paths: List[str], action: str
|
||||
) -> List[str]:
|
||||
"""过滤出用户有权限的路径列表"""
|
||||
result = []
|
||||
for path in paths:
|
||||
if await cls.check_path_permission(user_id, path, action):
|
||||
result.append(path)
|
||||
return result
|
||||
107
domain/permission/types.py
Normal file
107
domain/permission/types.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# 权限操作类型
|
||||
class PathAction:
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
DELETE = "delete"
|
||||
SHARE = "share"
|
||||
|
||||
|
||||
# 系统权限代码
|
||||
class SystemPermission:
|
||||
USER_CREATE = "system.user.create"
|
||||
USER_EDIT = "system.user.edit"
|
||||
USER_DELETE = "system.user.delete"
|
||||
USER_LIST = "system.user.list"
|
||||
ROLE_MANAGE = "system.role.manage"
|
||||
CONFIG_EDIT = "system.config.edit"
|
||||
AUDIT_VIEW = "system.audit.view"
|
||||
|
||||
|
||||
# 适配器权限代码
|
||||
class AdapterPermission:
|
||||
CREATE = "adapter.create"
|
||||
EDIT = "adapter.edit"
|
||||
DELETE = "adapter.delete"
|
||||
LIST = "adapter.list"
|
||||
|
||||
|
||||
# 所有权限定义
|
||||
PERMISSION_DEFINITIONS = [
|
||||
# 系统权限
|
||||
{"code": SystemPermission.USER_CREATE, "name": "创建用户", "category": "system", "description": "允许创建新用户"},
|
||||
{"code": SystemPermission.USER_EDIT, "name": "编辑用户", "category": "system", "description": "允许编辑用户信息"},
|
||||
{"code": SystemPermission.USER_DELETE, "name": "删除用户", "category": "system", "description": "允许删除用户"},
|
||||
{"code": SystemPermission.USER_LIST, "name": "查看用户列表", "category": "system", "description": "允许查看用户列表"},
|
||||
{"code": SystemPermission.ROLE_MANAGE, "name": "管理角色和权限", "category": "system", "description": "允许管理角色和权限配置"},
|
||||
{"code": SystemPermission.CONFIG_EDIT, "name": "修改系统配置", "category": "system", "description": "允许修改系统配置"},
|
||||
{"code": SystemPermission.AUDIT_VIEW, "name": "查看审计日志", "category": "system", "description": "允许查看审计日志"},
|
||||
# 适配器权限
|
||||
{"code": AdapterPermission.CREATE, "name": "创建存储适配器", "category": "adapter", "description": "允许创建存储适配器"},
|
||||
{"code": AdapterPermission.EDIT, "name": "编辑存储适配器", "category": "adapter", "description": "允许编辑存储适配器"},
|
||||
{"code": AdapterPermission.DELETE, "name": "删除存储适配器", "category": "adapter", "description": "允许删除存储适配器"},
|
||||
{"code": AdapterPermission.LIST, "name": "查看存储适配器列表", "category": "adapter", "description": "允许查看存储适配器列表"},
|
||||
]
|
||||
|
||||
|
||||
# Pydantic 模型
|
||||
class PermissionInfo(BaseModel):
|
||||
code: str
|
||||
name: str
|
||||
category: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class PathRuleInfo(BaseModel):
|
||||
id: int
|
||||
role_id: int
|
||||
path_pattern: str
|
||||
is_regex: bool
|
||||
can_read: bool
|
||||
can_write: bool
|
||||
can_delete: bool
|
||||
can_share: bool
|
||||
priority: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PathRuleCreate(BaseModel):
|
||||
path_pattern: str
|
||||
is_regex: bool = False
|
||||
can_read: bool = True
|
||||
can_write: bool = False
|
||||
can_delete: bool = False
|
||||
can_share: bool = False
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class PathRuleUpdate(BaseModel):
|
||||
path_pattern: str | None = None
|
||||
is_regex: bool | None = None
|
||||
can_read: bool | None = None
|
||||
can_write: bool | None = None
|
||||
can_delete: bool | None = None
|
||||
can_share: bool | None = None
|
||||
priority: int | None = None
|
||||
|
||||
|
||||
class PathPermissionCheck(BaseModel):
|
||||
path: str
|
||||
action: str
|
||||
|
||||
|
||||
class PathPermissionResult(BaseModel):
|
||||
path: str
|
||||
action: str
|
||||
allowed: bool
|
||||
matched_rule: PathRuleInfo | None = None
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
user_id: int
|
||||
is_admin: bool
|
||||
permissions: list[str] # 系统/适配器权限代码列表
|
||||
path_rules: list[PathRuleInfo] # 路径权限规则
|
||||
17
domain/plugins/__init__.py
Normal file
17
domain/plugins/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Foxel 插件系统
|
||||
|
||||
提供 .foxpkg 插件包的安装、管理和运行时加载功能。
|
||||
"""
|
||||
|
||||
from .loader import PluginLoadError, PluginLoader
|
||||
from .service import PluginService
|
||||
from .startup import init_plugins, load_installed_plugins
|
||||
|
||||
__all__ = [
|
||||
"PluginLoader",
|
||||
"PluginLoadError",
|
||||
"PluginService",
|
||||
"init_plugins",
|
||||
"load_installed_plugins",
|
||||
]
|
||||
131
domain/plugins/api.py
Normal file
131
domain/plugins/api.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
插件管理 API 路由
|
||||
"""
|
||||
|
||||
from typing import Annotated, List
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Request, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from domain.audit import AuditAction, audit
|
||||
from domain.auth import User, get_current_active_user
|
||||
from domain.permission import require_system_permission
|
||||
from domain.permission.types import SystemPermission
|
||||
from .service import PluginService
|
||||
from .types import (
|
||||
PluginInstallResult,
|
||||
PluginOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/plugins", tags=["plugins"])
|
||||
|
||||
|
||||
# ========== 安装 ==========
|
||||
|
||||
|
||||
@router.post("/install", response_model=PluginInstallResult)
|
||||
@audit(action=AuditAction.CREATE, description="安装插件包")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def install_plugin(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""
|
||||
安装 .foxpkg 插件包
|
||||
|
||||
上传 .foxpkg 文件进行安装。
|
||||
"""
|
||||
content = await file.read()
|
||||
return await PluginService.install_package(content, file.filename or "plugin.foxpkg")
|
||||
|
||||
|
||||
# ========== 插件列表和详情 ==========
|
||||
|
||||
|
||||
@router.get("", response_model=List[PluginOut])
|
||||
@audit(action=AuditAction.READ, description="获取插件列表")
|
||||
async def list_plugins(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""获取已安装的插件列表"""
|
||||
return await PluginService.list_plugins()
|
||||
|
||||
|
||||
@router.get("/{key_or_id}", response_model=PluginOut)
|
||||
@audit(action=AuditAction.READ, description="获取插件详情")
|
||||
async def get_plugin(
|
||||
request: Request,
|
||||
key_or_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""获取单个插件详情"""
|
||||
return await PluginService.get_plugin(key_or_id)
|
||||
|
||||
|
||||
# ========== 插件管理 ==========
|
||||
|
||||
|
||||
@router.delete("/{key_or_id}")
|
||||
@audit(action=AuditAction.DELETE, description="卸载插件")
|
||||
@require_system_permission(SystemPermission.ROLE_MANAGE)
|
||||
async def delete_plugin(
|
||||
request: Request,
|
||||
key_or_id: str,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
"""卸载插件"""
|
||||
await PluginService.delete(key_or_id)
|
||||
return {"code": 0, "msg": "ok"}
|
||||
|
||||
|
||||
# ========== 插件资源 ==========
|
||||
|
||||
|
||||
@router.get("/{key_or_id}/bundle.js")
|
||||
async def get_bundle(request: Request, key_or_id: str):
|
||||
"""获取插件前端 bundle"""
|
||||
path = await PluginService.get_bundle_path(key_or_id)
|
||||
v = (request.query_params.get("v") or "").strip()
|
||||
cache_control = "public, max-age=31536000, immutable" if v else "no-cache"
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type="application/javascript",
|
||||
headers={"Cache-Control": cache_control},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{key}/assets/{asset_path:path}")
|
||||
async def get_asset(request: Request, key: str, asset_path: str):
|
||||
"""获取插件静态资源"""
|
||||
path = await PluginService.get_asset_path(key, asset_path)
|
||||
|
||||
# 根据扩展名确定 MIME 类型
|
||||
ext = path.suffix.lower()
|
||||
media_types = {
|
||||
".js": "application/javascript",
|
||||
".css": "text/css",
|
||||
".json": "application/json",
|
||||
".svg": "image/svg+xml",
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".ico": "image/x-icon",
|
||||
".woff": "font/woff",
|
||||
".woff2": "font/woff2",
|
||||
".ttf": "font/ttf",
|
||||
".eot": "application/vnd.ms-fontobject",
|
||||
".html": "text/html",
|
||||
".txt": "text/plain",
|
||||
".md": "text/markdown",
|
||||
}
|
||||
media_type = media_types.get(ext, "application/octet-stream")
|
||||
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type=media_type,
|
||||
headers={"Cache-Control": "public, max-age=3600"},
|
||||
)
|
||||
449
domain/plugins/loader.py
Normal file
449
domain/plugins/loader.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
插件加载器模块
|
||||
|
||||
负责:
|
||||
1. .foxpkg 解包和验证
|
||||
2. 插件文件部署
|
||||
3. 后端路由动态加载
|
||||
4. 处理器动态注册
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
import zipfile
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .types import (
|
||||
ManifestProcessorConfig,
|
||||
ManifestRouteConfig,
|
||||
PluginManifest,
|
||||
)
|
||||
|
||||
|
||||
class PluginLoadError(Exception):
|
||||
"""插件加载错误"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
"""插件加载器"""
|
||||
|
||||
PLUGINS_ROOT = Path("data/plugins")
|
||||
|
||||
# 已加载的插件模块缓存
|
||||
_loaded_modules: Dict[str, ModuleType] = {}
|
||||
# 已挂载的路由追踪
|
||||
_mounted_routers: Dict[str, List[APIRouter]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_plugin_dir(cls, plugin_key: str) -> Path:
|
||||
"""获取插件目录"""
|
||||
return cls.PLUGINS_ROOT / plugin_key
|
||||
|
||||
@classmethod
|
||||
def get_manifest_path(cls, plugin_key: str) -> Path:
|
||||
"""获取插件 manifest.json 路径"""
|
||||
return cls.get_plugin_dir(plugin_key) / "manifest.json"
|
||||
|
||||
@classmethod
|
||||
def get_frontend_bundle_path(cls, plugin_key: str, entry: Optional[str] = None) -> Path:
|
||||
"""获取前端 bundle 路径"""
|
||||
plugin_dir = cls.get_plugin_dir(plugin_key)
|
||||
if entry:
|
||||
return plugin_dir / entry
|
||||
# 默认位置
|
||||
return plugin_dir / "frontend" / "index.js"
|
||||
|
||||
@classmethod
|
||||
def get_asset_path(cls, plugin_key: str, asset_path: str) -> Path:
|
||||
"""获取静态资源路径"""
|
||||
return cls.get_plugin_dir(plugin_key) / asset_path
|
||||
|
||||
# ========== 解包和验证 ==========
|
||||
|
||||
@classmethod
|
||||
def validate_manifest(cls, manifest_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
"""验证 manifest 数据"""
|
||||
errors: List[str] = []
|
||||
|
||||
# 必需字段检查
|
||||
if not manifest_data.get("key"):
|
||||
errors.append("manifest 缺少必需字段: key")
|
||||
if not manifest_data.get("name"):
|
||||
errors.append("manifest 缺少必需字段: name")
|
||||
|
||||
# key 格式检查(Java 命名空间格式)
|
||||
key = manifest_data.get("key", "")
|
||||
if key:
|
||||
import re
|
||||
|
||||
# 格式: com.example.plugin (至少两级,每级以小写字母开头,可包含小写字母和数字)
|
||||
if not re.match(r"^[a-z][a-z0-9]*(\.[a-z][a-z0-9]*)+$", key):
|
||||
errors.append(
|
||||
"key 格式无效:必须使用命名空间格式(如 com.example.plugin),"
|
||||
"每个部分以小写字母开头,只能包含小写字母和数字,至少两级"
|
||||
)
|
||||
|
||||
# 版本格式检查(简单检查)
|
||||
version = manifest_data.get("version", "")
|
||||
if version and not isinstance(version, str):
|
||||
errors.append("version 必须是字符串")
|
||||
|
||||
# 验证 frontend 配置
|
||||
frontend = manifest_data.get("frontend")
|
||||
if frontend and isinstance(frontend, dict):
|
||||
if frontend.get("entry") and not isinstance(frontend["entry"], str):
|
||||
errors.append("frontend.entry 必须是字符串")
|
||||
if frontend.get("styles") is not None:
|
||||
if not isinstance(frontend["styles"], list) or not all(
|
||||
isinstance(x, str) for x in frontend["styles"]
|
||||
):
|
||||
errors.append("frontend.styles 必须是字符串数组")
|
||||
supported_exts = frontend.get("supportedExts") or frontend.get("supported_exts")
|
||||
if supported_exts and not isinstance(supported_exts, list):
|
||||
errors.append("frontend.supportedExts 必须是数组")
|
||||
use_system_window = frontend.get("useSystemWindow") or frontend.get("use_system_window")
|
||||
if use_system_window is not None and not isinstance(use_system_window, bool):
|
||||
errors.append("frontend.useSystemWindow 必须是布尔值")
|
||||
|
||||
# 验证 backend 配置
|
||||
backend = manifest_data.get("backend")
|
||||
if backend and isinstance(backend, dict):
|
||||
routes = backend.get("routes", [])
|
||||
if routes:
|
||||
for i, route in enumerate(routes):
|
||||
if not route.get("module"):
|
||||
errors.append(f"backend.routes[{i}] 缺少 module")
|
||||
if not route.get("prefix"):
|
||||
errors.append(f"backend.routes[{i}] 缺少 prefix")
|
||||
|
||||
processors = backend.get("processors", [])
|
||||
if processors:
|
||||
for i, proc in enumerate(processors):
|
||||
if not proc.get("module"):
|
||||
errors.append(f"backend.processors[{i}] 缺少 module")
|
||||
if not proc.get("type"):
|
||||
errors.append(f"backend.processors[{i}] 缺少 type")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
@classmethod
|
||||
def unpack_foxpkg(
|
||||
cls, file_content: bytes, target_key: Optional[str] = None
|
||||
) -> Tuple[PluginManifest, Path]:
|
||||
"""
|
||||
解包 .foxpkg 文件
|
||||
|
||||
Args:
|
||||
file_content: .foxpkg 文件内容
|
||||
target_key: 可选,指定安装的插件 key(覆盖 manifest 中的 key)
|
||||
|
||||
Returns:
|
||||
(manifest, plugin_dir) 元组
|
||||
|
||||
Raises:
|
||||
PluginLoadError: 解包或验证失败
|
||||
"""
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
|
||||
# 读取 manifest.json
|
||||
try:
|
||||
manifest_bytes = zf.read("manifest.json")
|
||||
except KeyError:
|
||||
raise PluginLoadError("插件包缺少 manifest.json")
|
||||
|
||||
try:
|
||||
manifest_data = json.loads(manifest_bytes.decode("utf-8"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise PluginLoadError(f"manifest.json 解析失败: {e}")
|
||||
|
||||
# 验证 manifest
|
||||
valid, errors = cls.validate_manifest(manifest_data)
|
||||
if not valid:
|
||||
raise PluginLoadError(f"manifest 验证失败: {'; '.join(errors)}")
|
||||
|
||||
# 解析 manifest
|
||||
try:
|
||||
manifest = PluginManifest.model_validate(manifest_data)
|
||||
except Exception as e:
|
||||
raise PluginLoadError(f"manifest 解析失败: {e}")
|
||||
|
||||
# 确定插件 key
|
||||
plugin_key = target_key or manifest.key
|
||||
|
||||
# 验证包内文件
|
||||
cls._validate_package_files(zf, manifest)
|
||||
|
||||
# 部署文件
|
||||
target_dir = cls.PLUGINS_ROOT / plugin_key
|
||||
if target_dir.exists():
|
||||
# 备份旧版本
|
||||
backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
|
||||
if backup_dir.exists():
|
||||
shutil.rmtree(backup_dir)
|
||||
shutil.move(str(target_dir), str(backup_dir))
|
||||
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
zf.extractall(target_dir)
|
||||
except Exception as e:
|
||||
# 恢复备份
|
||||
if (cls.PLUGINS_ROOT / f"{plugin_key}.backup").exists():
|
||||
shutil.rmtree(target_dir, ignore_errors=True)
|
||||
shutil.move(str(cls.PLUGINS_ROOT / f"{plugin_key}.backup"), str(target_dir))
|
||||
raise PluginLoadError(f"文件解压失败: {e}")
|
||||
|
||||
# 清理备份
|
||||
backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
|
||||
if backup_dir.exists():
|
||||
shutil.rmtree(backup_dir, ignore_errors=True)
|
||||
|
||||
return manifest, target_dir
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
raise PluginLoadError("无效的插件包格式(非 ZIP 文件)")
|
||||
|
||||
@classmethod
|
||||
def _validate_package_files(cls, zf: zipfile.ZipFile, manifest: PluginManifest) -> None:
|
||||
"""验证包内文件是否完整"""
|
||||
file_list = zf.namelist()
|
||||
|
||||
# 检查前端入口
|
||||
if manifest.frontend and manifest.frontend.entry:
|
||||
if manifest.frontend.entry not in file_list:
|
||||
raise PluginLoadError(f"前端入口文件不存在: {manifest.frontend.entry}")
|
||||
|
||||
# 检查后端模块
|
||||
if manifest.backend:
|
||||
if manifest.backend.routes:
|
||||
for route in manifest.backend.routes:
|
||||
if route.module not in file_list:
|
||||
raise PluginLoadError(f"路由模块不存在: {route.module}")
|
||||
|
||||
if manifest.backend.processors:
|
||||
for proc in manifest.backend.processors:
|
||||
if proc.module not in file_list:
|
||||
raise PluginLoadError(f"处理器模块不存在: {proc.module}")
|
||||
|
||||
# ========== 路由动态加载 ==========
|
||||
|
||||
@classmethod
|
||||
def load_route_module(cls, plugin_key: str, route_config: ManifestRouteConfig) -> APIRouter:
|
||||
"""
|
||||
动态加载插件路由模块
|
||||
|
||||
Args:
|
||||
plugin_key: 插件标识
|
||||
route_config: 路由配置
|
||||
|
||||
Returns:
|
||||
加载的 APIRouter
|
||||
"""
|
||||
module_path = cls.get_plugin_dir(plugin_key) / route_config.module
|
||||
|
||||
if not module_path.exists():
|
||||
raise PluginLoadError(f"路由模块不存在: {module_path}")
|
||||
|
||||
module_name = f"foxel_plugin_{plugin_key}_route_{module_path.stem}"
|
||||
|
||||
try:
|
||||
spec = spec_from_file_location(module_name, module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise PluginLoadError(f"无法加载路由模块: {module_path}")
|
||||
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 缓存模块
|
||||
cls._loaded_modules[f"{plugin_key}:route:{route_config.module}"] = module
|
||||
|
||||
# 获取 router
|
||||
router = getattr(module, "router", None)
|
||||
if router is None:
|
||||
raise PluginLoadError(f"路由模块缺少 'router' 对象: {module_path}")
|
||||
|
||||
if not isinstance(router, APIRouter):
|
||||
raise PluginLoadError(f"'router' 不是有效的 APIRouter 实例: {module_path}")
|
||||
|
||||
# 创建包装路由器添加前缀
|
||||
wrapper = APIRouter(prefix=route_config.prefix, tags=route_config.tags or [])
|
||||
wrapper.include_router(router)
|
||||
|
||||
return wrapper
|
||||
|
||||
except PluginLoadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise PluginLoadError(f"加载路由模块失败 [{module_path}]: {e}")
|
||||
|
||||
@classmethod
|
||||
def load_all_routes(cls, plugin_key: str, manifest: PluginManifest) -> List[APIRouter]:
|
||||
"""加载插件的所有路由"""
|
||||
routers: List[APIRouter] = []
|
||||
|
||||
if not manifest.backend or not manifest.backend.routes:
|
||||
return routers
|
||||
|
||||
for route_config in manifest.backend.routes:
|
||||
router = cls.load_route_module(plugin_key, route_config)
|
||||
routers.append(router)
|
||||
|
||||
cls._mounted_routers[plugin_key] = routers
|
||||
return routers
|
||||
|
||||
# ========== 处理器动态注册 ==========
|
||||
|
||||
@classmethod
|
||||
def load_processor_module(
|
||||
cls, plugin_key: str, processor_config: ManifestProcessorConfig
|
||||
) -> None:
|
||||
"""
|
||||
动态加载并注册处理器模块
|
||||
|
||||
Args:
|
||||
plugin_key: 插件标识
|
||||
processor_config: 处理器配置
|
||||
"""
|
||||
module_path = cls.get_plugin_dir(plugin_key) / processor_config.module
|
||||
|
||||
if not module_path.exists():
|
||||
raise PluginLoadError(f"处理器模块不存在: {module_path}")
|
||||
|
||||
module_name = f"foxel_plugin_{plugin_key}_processor_{module_path.stem}"
|
||||
|
||||
try:
|
||||
spec = spec_from_file_location(module_name, module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise PluginLoadError(f"无法加载处理器模块: {module_path}")
|
||||
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 缓存模块
|
||||
cls._loaded_modules[f"{plugin_key}:processor:{processor_config.module}"] = module
|
||||
|
||||
# 获取处理器工厂
|
||||
factory = getattr(module, "PROCESSOR_FACTORY", None)
|
||||
if factory is None:
|
||||
raise PluginLoadError(f"处理器模块缺少 'PROCESSOR_FACTORY': {module_path}")
|
||||
|
||||
# 获取配置 schema
|
||||
config_schema = getattr(module, "CONFIG_SCHEMA", [])
|
||||
processor_name = getattr(module, "PROCESSOR_NAME", processor_config.name or processor_config.type)
|
||||
supported_exts = getattr(module, "SUPPORTED_EXTS", [])
|
||||
|
||||
# 注册到处理器注册表
|
||||
from domain.processors import CONFIG_SCHEMAS, TYPE_MAP
|
||||
|
||||
processor_type = processor_config.type
|
||||
TYPE_MAP[processor_type] = factory
|
||||
|
||||
# 获取实例以读取属性
|
||||
try:
|
||||
sample = factory()
|
||||
produces_file = getattr(sample, "produces_file", False)
|
||||
supports_directory = getattr(sample, "supports_directory", False)
|
||||
except Exception:
|
||||
produces_file = False
|
||||
supports_directory = False
|
||||
|
||||
CONFIG_SCHEMAS[processor_type] = {
|
||||
"type": processor_type,
|
||||
"name": processor_name,
|
||||
"supported_exts": supported_exts,
|
||||
"config_schema": config_schema,
|
||||
"produces_file": produces_file,
|
||||
"supports_directory": supports_directory,
|
||||
"plugin": plugin_key, # 标记来源插件
|
||||
"module_path": str(module_path),
|
||||
}
|
||||
|
||||
except PluginLoadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise PluginLoadError(f"加载处理器模块失败 [{module_path}]: {e}")
|
||||
|
||||
@classmethod
|
||||
def load_all_processors(cls, plugin_key: str, manifest: PluginManifest) -> List[str]:
|
||||
"""加载插件的所有处理器,返回处理器类型列表"""
|
||||
processor_types: List[str] = []
|
||||
|
||||
if not manifest.backend or not manifest.backend.processors:
|
||||
return processor_types
|
||||
|
||||
for proc_config in manifest.backend.processors:
|
||||
cls.load_processor_module(plugin_key, proc_config)
|
||||
processor_types.append(proc_config.type)
|
||||
|
||||
return processor_types
|
||||
|
||||
# ========== 卸载 ==========
|
||||
|
||||
@classmethod
|
||||
def unload_plugin(cls, plugin_key: str, manifest: Optional[PluginManifest] = None) -> None:
|
||||
"""
|
||||
卸载插件的后端组件
|
||||
|
||||
Args:
|
||||
plugin_key: 插件标识
|
||||
manifest: 可选的 manifest,用于确定要卸载的组件
|
||||
"""
|
||||
# 卸载处理器
|
||||
if manifest and manifest.backend and manifest.backend.processors:
|
||||
from domain.processors import CONFIG_SCHEMAS, TYPE_MAP
|
||||
|
||||
for proc_config in manifest.backend.processors:
|
||||
proc_type = proc_config.type
|
||||
if proc_type in TYPE_MAP:
|
||||
del TYPE_MAP[proc_type]
|
||||
if proc_type in CONFIG_SCHEMAS:
|
||||
del CONFIG_SCHEMAS[proc_type]
|
||||
|
||||
# 清理缓存的模块
|
||||
keys_to_remove = [k for k in cls._loaded_modules if k.startswith(f"{plugin_key}:")]
|
||||
for key in keys_to_remove:
|
||||
module = cls._loaded_modules.pop(key, None)
|
||||
if module and module.__name__ in sys.modules:
|
||||
del sys.modules[module.__name__]
|
||||
|
||||
# 清理路由追踪(注意:FastAPI 不支持动态移除路由,需要重启应用)
|
||||
cls._mounted_routers.pop(plugin_key, None)
|
||||
|
||||
@classmethod
|
||||
def delete_plugin_files(cls, plugin_key: str) -> None:
|
||||
"""删除插件文件"""
|
||||
plugin_dir = cls.get_plugin_dir(plugin_key)
|
||||
if plugin_dir.exists():
|
||||
shutil.rmtree(plugin_dir)
|
||||
|
||||
# 同时删除备份
|
||||
backup_dir = cls.PLUGINS_ROOT / f"{plugin_key}.backup"
|
||||
if backup_dir.exists():
|
||||
shutil.rmtree(backup_dir)
|
||||
|
||||
# ========== 读取 manifest ==========
|
||||
|
||||
@classmethod
|
||||
def read_manifest(cls, plugin_key: str) -> Optional[PluginManifest]:
|
||||
"""从文件系统读取插件 manifest"""
|
||||
manifest_path = cls.get_manifest_path(plugin_key)
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return PluginManifest.model_validate(data)
|
||||
except Exception:
|
||||
return None
|
||||
273
domain/plugins/service.py
Normal file
273
domain/plugins/service.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
插件服务模块
|
||||
|
||||
负责插件的安装、卸载等管理操作
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .loader import PluginLoadError, PluginLoader
|
||||
from .types import (
|
||||
PluginInstallResult,
|
||||
PluginManifest,
|
||||
PluginOut,
|
||||
)
|
||||
from models.database import Plugin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginService:
|
||||
"""插件服务"""
|
||||
|
||||
_plugins_root = Path("data/plugins")
|
||||
|
||||
# ========== 工具方法 ==========
|
||||
|
||||
@classmethod
|
||||
def _get_plugin_dir(cls, plugin_key: str) -> Path:
|
||||
"""获取插件目录"""
|
||||
return cls._plugins_root / plugin_key
|
||||
|
||||
@classmethod
|
||||
def _get_bundle_path(cls, rec: Plugin) -> Path:
|
||||
"""获取前端 bundle 路径"""
|
||||
plugin_dir = cls._get_plugin_dir(rec.key)
|
||||
# 从 manifest 读取
|
||||
if rec.manifest:
|
||||
frontend = rec.manifest.get("frontend", {})
|
||||
entry = frontend.get("entry")
|
||||
if entry:
|
||||
return plugin_dir / entry
|
||||
# 默认位置
|
||||
return plugin_dir / "frontend" / "index.js"
|
||||
|
||||
@classmethod
|
||||
async def _get_by_key_or_404(cls, key: str) -> Plugin:
|
||||
"""通过 key 获取插件,不存在则返回 404"""
|
||||
rec = await Plugin.get_or_none(key=key)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
return rec
|
||||
|
||||
@classmethod
|
||||
async def _get_by_key_or_id(cls, key_or_id: Union[str, int]) -> Plugin:
|
||||
"""通过 key 或 ID 获取插件"""
|
||||
# 尝试作为 ID
|
||||
if isinstance(key_or_id, int) or (isinstance(key_or_id, str) and key_or_id.isdigit()):
|
||||
plugin_id = int(key_or_id)
|
||||
rec = await Plugin.get_or_none(id=plugin_id)
|
||||
if rec:
|
||||
return rec
|
||||
# 尝试作为 key
|
||||
if isinstance(key_or_id, str):
|
||||
rec = await Plugin.get_or_none(key=key_or_id)
|
||||
if rec:
|
||||
return rec
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
# ========== 安装 ==========
|
||||
|
||||
@classmethod
|
||||
async def install_package(cls, file_content: bytes, filename: str) -> PluginInstallResult:
|
||||
"""
|
||||
安装 .foxpkg 插件包
|
||||
|
||||
Args:
|
||||
file_content: 插件包内容
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
安装结果
|
||||
"""
|
||||
errors: List[str] = []
|
||||
|
||||
try:
|
||||
# 解包
|
||||
manifest, plugin_dir = PluginLoader.unpack_foxpkg(file_content)
|
||||
plugin_key = manifest.key
|
||||
|
||||
# 检查是否已存在
|
||||
existing = await Plugin.get_or_none(key=plugin_key)
|
||||
if existing:
|
||||
# 更新现有插件
|
||||
logger.info(f"更新插件: {plugin_key}")
|
||||
rec = existing
|
||||
else:
|
||||
# 创建新插件
|
||||
logger.info(f"安装新插件: {plugin_key}")
|
||||
rec = Plugin(key=plugin_key)
|
||||
|
||||
# 更新字段
|
||||
rec.name = manifest.name
|
||||
rec.version = manifest.version
|
||||
rec.description = manifest.description
|
||||
rec.author = manifest.author
|
||||
rec.website = manifest.website
|
||||
rec.github = manifest.github
|
||||
rec.license = manifest.license
|
||||
rec.manifest = manifest.model_dump(mode="json")
|
||||
|
||||
# 从 manifest.frontend 提取前端配置
|
||||
if manifest.frontend:
|
||||
rec.open_app = manifest.frontend.open_app or False
|
||||
rec.supported_exts = manifest.frontend.supported_exts
|
||||
rec.default_bounds = manifest.frontend.default_bounds
|
||||
rec.default_maximized = manifest.frontend.default_maximized
|
||||
rec.icon = manifest.frontend.icon
|
||||
|
||||
await rec.save()
|
||||
|
||||
# 加载后端组件(如果有)
|
||||
loaded_routes: List[str] = []
|
||||
loaded_processors: List[str] = []
|
||||
|
||||
if manifest.backend:
|
||||
# 加载路由
|
||||
if manifest.backend.routes:
|
||||
try:
|
||||
from main import app
|
||||
routers = PluginLoader.load_all_routes(plugin_key, manifest)
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
loaded_routes.append(router.prefix)
|
||||
except PluginLoadError as e:
|
||||
errors.append(f"路由加载失败: {e}")
|
||||
logger.error(f"插件 {plugin_key} 路由加载失败: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"路由加载失败: {e}")
|
||||
logger.exception(f"插件 {plugin_key} 路由加载异常")
|
||||
|
||||
# 加载处理器
|
||||
if manifest.backend.processors:
|
||||
try:
|
||||
processor_types = PluginLoader.load_all_processors(plugin_key, manifest)
|
||||
loaded_processors = processor_types
|
||||
except PluginLoadError as e:
|
||||
errors.append(f"处理器加载失败: {e}")
|
||||
logger.error(f"插件 {plugin_key} 处理器加载失败: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"处理器加载失败: {e}")
|
||||
logger.exception(f"插件 {plugin_key} 处理器加载异常")
|
||||
|
||||
# 更新加载状态
|
||||
rec.loaded_routes = loaded_routes if loaded_routes else None
|
||||
rec.loaded_processors = loaded_processors if loaded_processors else None
|
||||
await rec.save()
|
||||
|
||||
return PluginInstallResult(
|
||||
success=True,
|
||||
plugin=PluginOut.model_validate(rec),
|
||||
message="安装成功" if not errors else "安装完成,但有部分组件加载失败",
|
||||
errors=errors if errors else None,
|
||||
)
|
||||
|
||||
except PluginLoadError as e:
|
||||
logger.error(f"插件安装失败: {e}")
|
||||
return PluginInstallResult(
|
||||
success=False,
|
||||
message=str(e),
|
||||
errors=[str(e)],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("插件安装异常")
|
||||
return PluginInstallResult(
|
||||
success=False,
|
||||
message=f"安装失败: {e}",
|
||||
errors=[str(e)],
|
||||
)
|
||||
|
||||
# ========== 查询 ==========
|
||||
|
||||
@classmethod
|
||||
async def list_plugins(cls) -> List[PluginOut]:
|
||||
"""获取所有插件列表"""
|
||||
rows = await Plugin.all().order_by("-id")
|
||||
for rec in rows:
|
||||
try:
|
||||
manifest = PluginLoader.read_manifest(rec.key)
|
||||
if manifest:
|
||||
rec.manifest = manifest.model_dump(mode="json")
|
||||
except Exception:
|
||||
continue
|
||||
return [PluginOut.model_validate(r) for r in rows]
|
||||
|
||||
@classmethod
|
||||
async def get_plugin(cls, key_or_id: Union[str, int]) -> PluginOut:
|
||||
"""获取单个插件详情"""
|
||||
rec = await cls._get_by_key_or_id(key_or_id)
|
||||
try:
|
||||
manifest = PluginLoader.read_manifest(rec.key)
|
||||
if manifest:
|
||||
rec.manifest = manifest.model_dump(mode="json")
|
||||
except Exception:
|
||||
pass
|
||||
return PluginOut.model_validate(rec)
|
||||
|
||||
@classmethod
|
||||
async def get_bundle_path(cls, key_or_id: Union[str, int]) -> Path:
|
||||
"""获取插件前端 bundle 路径"""
|
||||
rec = await cls._get_by_key_or_id(key_or_id)
|
||||
bundle_path = cls._get_bundle_path(rec)
|
||||
if not bundle_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Plugin bundle not found")
|
||||
return bundle_path
|
||||
|
||||
@classmethod
|
||||
async def get_asset_path(cls, key: str, asset_path: str) -> Path:
|
||||
"""获取插件静态资源路径"""
|
||||
rec = await cls._get_by_key_or_404(key)
|
||||
plugin_dir = cls._get_plugin_dir(rec.key)
|
||||
|
||||
# 安全检查:防止路径遍历
|
||||
asset_path = asset_path.lstrip("/")
|
||||
if ".." in asset_path:
|
||||
raise HTTPException(status_code=400, detail="Invalid asset path")
|
||||
|
||||
full_path = plugin_dir / asset_path
|
||||
if not full_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
# 确保路径在插件目录内
|
||||
try:
|
||||
full_path.resolve().relative_to(plugin_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid asset path")
|
||||
|
||||
return full_path
|
||||
|
||||
# ========== 管理操作 ==========
|
||||
|
||||
@classmethod
|
||||
async def delete(cls, key_or_id: Union[str, int]) -> None:
|
||||
"""删除/卸载插件"""
|
||||
rec = await cls._get_by_key_or_id(key_or_id)
|
||||
|
||||
# 获取 manifest 用于卸载组件
|
||||
manifest: Optional[PluginManifest] = None
|
||||
if rec.manifest:
|
||||
try:
|
||||
manifest = PluginManifest.model_validate(rec.manifest)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 卸载后端组件
|
||||
if manifest:
|
||||
PluginLoader.unload_plugin(rec.key, manifest)
|
||||
|
||||
# 删除数据库记录
|
||||
await rec.delete()
|
||||
|
||||
# 删除文件
|
||||
with contextlib.suppress(Exception):
|
||||
plugin_dir = cls._get_plugin_dir(rec.key)
|
||||
if plugin_dir.exists():
|
||||
shutil.rmtree(plugin_dir)
|
||||
|
||||
logger.info(f"插件 {rec.key} 已卸载")
|
||||
115
domain/plugins/startup.py
Normal file
115
domain/plugins/startup.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
插件启动加载模块
|
||||
|
||||
负责在应用启动时加载所有已安装的插件
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from .loader import PluginLoadError, PluginLoader
|
||||
from .types import PluginManifest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load_installed_plugins(app: "FastAPI") -> Tuple[int, List[str]]:
|
||||
"""
|
||||
加载所有已安装的插件
|
||||
|
||||
Args:
|
||||
app: FastAPI 应用实例
|
||||
|
||||
Returns:
|
||||
(成功加载数量, 错误列表)
|
||||
"""
|
||||
from models.database import Plugin
|
||||
|
||||
errors: List[str] = []
|
||||
loaded_count = 0
|
||||
|
||||
try:
|
||||
plugins = await Plugin.all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件列表失败: {e}")
|
||||
return 0, [f"查询插件列表失败: {e}"]
|
||||
|
||||
for plugin in plugins:
|
||||
if not plugin.key:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 获取 manifest
|
||||
manifest = None
|
||||
if plugin.manifest:
|
||||
try:
|
||||
manifest = PluginManifest.model_validate(plugin.manifest)
|
||||
except Exception:
|
||||
# 尝试从文件系统读取
|
||||
manifest = PluginLoader.read_manifest(plugin.key)
|
||||
else:
|
||||
manifest = PluginLoader.read_manifest(plugin.key)
|
||||
|
||||
if not manifest:
|
||||
logger.warning(f"插件 {plugin.key} 缺少 manifest,跳过加载")
|
||||
continue
|
||||
|
||||
# 加载后端路由
|
||||
loaded_routes: List[str] = []
|
||||
if manifest.backend and manifest.backend.routes:
|
||||
try:
|
||||
routers = PluginLoader.load_all_routes(plugin.key, manifest)
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
loaded_routes.append(router.prefix)
|
||||
logger.info(f"插件 {plugin.key} 加载了 {len(routers)} 个路由")
|
||||
except PluginLoadError as e:
|
||||
errors.append(f"插件 {plugin.key} 路由加载失败: {e}")
|
||||
logger.error(f"插件 {plugin.key} 路由加载失败: {e}")
|
||||
|
||||
# 加载处理器
|
||||
loaded_processors: List[str] = []
|
||||
if manifest.backend and manifest.backend.processors:
|
||||
try:
|
||||
processor_types = PluginLoader.load_all_processors(plugin.key, manifest)
|
||||
loaded_processors = processor_types
|
||||
logger.info(f"插件 {plugin.key} 注册了 {len(processor_types)} 个处理器")
|
||||
except PluginLoadError as e:
|
||||
errors.append(f"插件 {plugin.key} 处理器加载失败: {e}")
|
||||
logger.error(f"插件 {plugin.key} 处理器加载失败: {e}")
|
||||
|
||||
# 更新数据库记录
|
||||
plugin.loaded_routes = loaded_routes if loaded_routes else None
|
||||
plugin.loaded_processors = loaded_processors if loaded_processors else None
|
||||
await plugin.save()
|
||||
|
||||
loaded_count += 1
|
||||
logger.info(f"插件 {plugin.key} 加载完成")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"插件 {plugin.key} 加载异常: {e}"
|
||||
errors.append(error_msg)
|
||||
logger.exception(error_msg)
|
||||
|
||||
return loaded_count, errors
|
||||
|
||||
|
||||
async def init_plugins(app: "FastAPI") -> None:
|
||||
"""
|
||||
初始化插件系统
|
||||
|
||||
在应用启动时调用
|
||||
"""
|
||||
logger.info("开始加载已安装插件...")
|
||||
|
||||
loaded_count, errors = await load_installed_plugins(app)
|
||||
|
||||
if errors:
|
||||
logger.warning(f"插件加载完成,共 {loaded_count} 个成功,{len(errors)} 个错误")
|
||||
for error in errors:
|
||||
logger.warning(f" - {error}")
|
||||
else:
|
||||
logger.info(f"插件加载完成,共 {loaded_count} 个插件")
|
||||
143
domain/plugins/types.py
Normal file
143
domain/plugins/types.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
# ========== Manifest 相关类型 ==========
|
||||
|
||||
|
||||
class ManifestFrontend(BaseModel):
|
||||
"""manifest.json 中的 frontend 配置"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
entry: Optional[str] = Field(default=None, description="前端入口文件路径")
|
||||
styles: Optional[List[str]] = Field(default=None, description="前端样式文件路径列表(相对插件根目录)")
|
||||
open_app: Optional[bool] = Field(
|
||||
default=None,
|
||||
alias="openApp",
|
||||
description="是否支持独立打开",
|
||||
)
|
||||
supported_exts: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
alias="supportedExts",
|
||||
description="支持的文件扩展名列表",
|
||||
)
|
||||
default_bounds: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
alias="defaultBounds",
|
||||
description="默认窗口尺寸",
|
||||
)
|
||||
default_maximized: Optional[bool] = Field(
|
||||
default=None,
|
||||
alias="defaultMaximized",
|
||||
description="是否默认最大化",
|
||||
)
|
||||
icon: Optional[str] = Field(default=None, description="图标路径")
|
||||
use_system_window: Optional[bool] = Field(
|
||||
default=None,
|
||||
alias="useSystemWindow",
|
||||
description="是否使用系统窗口",
|
||||
)
|
||||
|
||||
|
||||
class ManifestRouteConfig(BaseModel):
|
||||
"""manifest.json 中的路由配置"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
module: str = Field(..., description="路由模块路径")
|
||||
prefix: str = Field(..., description="路由前缀")
|
||||
tags: Optional[List[str]] = Field(default=None, description="API 标签")
|
||||
|
||||
|
||||
class ManifestProcessorConfig(BaseModel):
|
||||
"""manifest.json 中的处理器配置"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
module: str = Field(..., description="处理器模块路径")
|
||||
type: str = Field(..., description="处理器类型标识")
|
||||
name: Optional[str] = Field(default=None, description="处理器显示名称")
|
||||
|
||||
|
||||
class ManifestBackend(BaseModel):
|
||||
"""manifest.json 中的 backend 配置"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
routes: Optional[List[ManifestRouteConfig]] = Field(default=None, description="路由列表")
|
||||
processors: Optional[List[ManifestProcessorConfig]] = Field(
|
||||
default=None, description="处理器列表"
|
||||
)
|
||||
|
||||
|
||||
class ManifestDependencies(BaseModel):
|
||||
"""manifest.json 中的依赖配置"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
python: Optional[str] = Field(default=None, description="Python 版本要求")
|
||||
packages: Optional[List[str]] = Field(default=None, description="Python 包依赖列表")
|
||||
|
||||
|
||||
class PluginManifest(BaseModel):
|
||||
"""完整的 manifest.json 结构"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
foxpkg: str = Field(default="1.0", description="foxpkg 格式版本")
|
||||
key: str = Field(..., min_length=1, description="插件唯一标识")
|
||||
name: str = Field(..., min_length=1, description="插件名称")
|
||||
version: str = Field(default="1.0.0", description="插件版本")
|
||||
description: Optional[str] = Field(default=None, description="插件描述")
|
||||
i18n: Optional[Dict[str, Dict[str, str]]] = Field(
|
||||
default=None,
|
||||
description="多语言信息(name/description),例如:{'en': {'name': '...', 'description': '...'}}",
|
||||
)
|
||||
author: Optional[str] = Field(default=None, description="作者")
|
||||
website: Optional[str] = Field(default=None, description="网站")
|
||||
github: Optional[str] = Field(default=None, description="GitHub 地址")
|
||||
license: Optional[str] = Field(default=None, description="许可证")
|
||||
|
||||
frontend: Optional[ManifestFrontend] = Field(default=None, description="前端配置")
|
||||
backend: Optional[ManifestBackend] = Field(default=None, description="后端配置")
|
||||
dependencies: Optional[ManifestDependencies] = Field(default=None, description="依赖配置")
|
||||
|
||||
|
||||
# ========== API 请求/响应类型 ==========
|
||||
|
||||
|
||||
class PluginOut(BaseModel):
|
||||
"""插件输出模型"""
|
||||
|
||||
id: int
|
||||
key: str
|
||||
open_app: bool = False
|
||||
name: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
supported_exts: Optional[List[str]] = None
|
||||
default_bounds: Optional[Dict[str, Any]] = None
|
||||
default_maximized: Optional[bool] = None
|
||||
icon: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
github: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
|
||||
# 新增字段
|
||||
manifest: Optional[Dict[str, Any]] = None
|
||||
loaded_routes: Optional[List[str]] = None
|
||||
loaded_processors: Optional[List[str]] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PluginInstallResult(BaseModel):
|
||||
"""安装结果"""
|
||||
|
||||
success: bool
|
||||
plugin: Optional[PluginOut] = None
|
||||
message: Optional[str] = None
|
||||
errors: Optional[List[str]] = None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user