mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-04 06:11:32 +08:00
Compare commits
84 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ebfa1d247c | ||
|
|
cdb85ef9b7 | ||
|
|
7006522c13 | ||
|
|
530c958afc | ||
|
|
57d861b578 | ||
|
|
99664298b9 | ||
|
|
a6fe5a7022 | ||
|
|
1918dad602 | ||
|
|
69399c291e | ||
|
|
9ec33ce320 | ||
|
|
c35d3aff7d | ||
|
|
2a5744d1c4 | ||
|
|
825511506b | ||
|
|
5a98a701cb | ||
|
|
dd1fa35c73 | ||
|
|
fb572fa849 | ||
|
|
c0a473ed19 | ||
|
|
030641adc6 | ||
|
|
445ef49dc8 | ||
|
|
32d4c60541 | ||
|
|
23f865be07 | ||
|
|
5d55325c12 | ||
|
|
900330509a | ||
|
|
cfb682ae3c | ||
|
|
abae90b16d | ||
|
|
470fc37f26 | ||
|
|
7a7caef1a6 | ||
|
|
a6aecb5d89 | ||
|
|
4a004f9aa1 | ||
|
|
1a6feae23b | ||
|
|
af5b2fa2c9 | ||
|
|
eeec45274b | ||
|
|
2b48c853fe | ||
|
|
c47f696691 | ||
|
|
9a8e4c8e15 | ||
|
|
24aab9a658 | ||
|
|
afdaaffac5 | ||
|
|
fe721116e2 | ||
|
|
8e0a834daa | ||
|
|
c9fca1561c | ||
|
|
5eb2dfd822 | ||
|
|
0b837c3f80 | ||
|
|
a6cfc12443 | ||
|
|
f6d64dd850 | ||
|
|
eed62caa78 | ||
|
|
204d41d6f3 | ||
|
|
858df0548e | ||
|
|
b3da021803 | ||
|
|
d234f826f4 | ||
|
|
231b69ecf8 | ||
|
|
0a08913677 | ||
|
|
49d32813ea | ||
|
|
c5d57e97b1 | ||
|
|
da8f7539a1 | ||
|
|
64a68f1176 | ||
|
|
1199d7cc3c | ||
|
|
8a827d2acb | ||
|
|
0e8a943d7f | ||
|
|
4f62658440 | ||
|
|
6e7c3d5f6a | ||
|
|
d5062db9b6 | ||
|
|
a6ad006a49 | ||
|
|
57d593fa17 | ||
|
|
f38b5ae870 | ||
|
|
418b3ca13c | ||
|
|
09bfa85e69 | ||
|
|
62b132208b | ||
|
|
fc28f4f74e | ||
|
|
f79a52f839 | ||
|
|
94d1041961 | ||
|
|
ada32d526a | ||
|
|
ef1e38aba1 | ||
|
|
60b2d59e25 | ||
|
|
e18aa73456 | ||
|
|
24747a5f09 | ||
|
|
621dac22dc | ||
|
|
23d7004b60 | ||
|
|
c3b3d34127 | ||
|
|
18a166afb0 | ||
|
|
a41447a96d | ||
|
|
df8d543539 | ||
|
|
5ecce8e0fe | ||
|
|
00f423a622 | ||
|
|
05ce04de69 |
17
.env.example
17
.env.example
@@ -20,6 +20,9 @@ THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
||||
# 是否启用网址上下文,默认启用
|
||||
URL_CONTEXT_ENABLED=true
|
||||
URL_CONTEXT_MODELS=["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
SHOW_SEARCH_LINK=true
|
||||
SHOW_THINKING_PROCESS=true
|
||||
@@ -43,6 +46,7 @@ SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||
PICGO_API_KEY=xxxx
|
||||
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
|
||||
##########################################################################
|
||||
#########################stream_optimizer 相关配置########################
|
||||
STREAM_OPTIMIZER_ENABLED=false
|
||||
@@ -74,3 +78,16 @@ FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS=5
|
||||
# 安全设置 (JSON 字符串格式)
|
||||
# 注意:这里的示例值可能需要根据实际模型支持情况调整
|
||||
SAFETY_SETTINGS=[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]
|
||||
URL_NORMALIZATION_ENABLED=false
|
||||
# tts配置
|
||||
TTS_MODEL=gemini-2.5-flash-preview-tts
|
||||
TTS_VOICE_NAME=Zephyr
|
||||
TTS_SPEED=normal
|
||||
#########################Files API 相关配置########################
|
||||
# 是否启用文件过期自动清理
|
||||
FILES_CLEANUP_ENABLED=true
|
||||
# 文件过期清理间隔(小时)
|
||||
FILES_CLEANUP_INTERVAL_HOURS=1
|
||||
# 是否启用用户文件隔离(每个用户只能看到自己上传的文件)
|
||||
FILES_USER_ISOLATION_ENABLED=true
|
||||
##########################################################################
|
||||
22
.github/workflows/release.yml
vendored
22
.github/workflows/release.yml
vendored
@@ -3,7 +3,7 @@ name: Publish Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
- "v*" # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
|
||||
jobs:
|
||||
update-release-draft:
|
||||
@@ -15,8 +15,17 @@ jobs:
|
||||
# Step 1: 检出代码库
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# Step 2: 自动生成 Release
|
||||
# Step 2: 自动生成 Release Notes
|
||||
- name: Generate release notes
|
||||
id: changelog
|
||||
uses: mikepenz/release-changelog-builder-action@v4
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# Step 3: 自动生成 Release
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
@@ -25,15 +34,16 @@ jobs:
|
||||
with:
|
||||
tag_name: ${{ github.ref_name }}
|
||||
release_name: ${{ github.ref_name }}
|
||||
body: ${{ steps.changelog.outputs.changelog }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
|
||||
# Step 3: 可选,构建zip文件
|
||||
|
||||
# Step 4: 可选,构建zip文件
|
||||
- name: Create ZIP file
|
||||
run: |
|
||||
zip -r gemini-balance.zip . -x "*.git*" "*.github*" "*.env*" "logs/*" "tests/*"
|
||||
|
||||
# Step 4: 可选,上传构建文件
|
||||
# Step 5: 可选,上传构建文件
|
||||
- name: Upload Release Asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
@@ -41,5 +51,5 @@ jobs:
|
||||
with:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
asset_path: ./gemini-balance.zip # 替换为你的构建文件路径
|
||||
asset_name: gemini-balance.zip # 替换为你的文件名
|
||||
asset_name: gemini-balance.zip # 替换为你的文件名
|
||||
asset_content_type: application/zip
|
||||
|
||||
@@ -8,12 +8,6 @@ COPY ./VERSION /app
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY ./app /app/app
|
||||
ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
||||
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
136
README.md
136
README.md
@@ -2,6 +2,12 @@
|
||||
|
||||
# Gemini Balance - Gemini API Proxy and Load Balancer
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> ⚠️ This project is licensed under the CC BY-NC 4.0 (Attribution-NonCommercial) license. Any form of commercial resale service is prohibited. See the LICENSE file for details.
|
||||
|
||||
> I have never sold this service on any platform. If you encounter someone selling this service, they are definitely a reseller. Please be careful not to be deceived.
|
||||
@@ -11,7 +17,7 @@
|
||||
[](https://www.uvicorn.org/)
|
||||
[](https://t.me/+soaHax5lyI0wZDVl)
|
||||
|
||||
> Telegram Group: https://t.me/+soaHax5lyI0wZDVl
|
||||
> Telegram Group: <https://t.me/+soaHax5lyI0wZDVl>
|
||||
|
||||
## Project Introduction
|
||||
|
||||
@@ -40,39 +46,39 @@ app/
|
||||
|
||||
## ✨ Feature Highlights
|
||||
|
||||
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling, improving availability and concurrency.
|
||||
* **Visual Configuration Takes Effect Immediately**: Configurations modified through the admin backend take effect without restarting the service. Remember to click save for changes to apply.
|
||||
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling, improving availability and concurrency.
|
||||
* **Visual Configuration Takes Effect Immediately**: Configurations modified through the admin backend take effect without restarting the service. Remember to click save for changes to apply.
|
||||

|
||||
* **Dual Protocol API Compatibility**: Supports forwarding CHAT API requests in both Gemini and OpenAI formats.
|
||||
* **Dual Protocol API Compatibility**: Supports forwarding CHAT API requests in both Gemini and OpenAI formats.
|
||||
|
||||
```plaintext
|
||||
openai baseurl `http://localhost:8000(/hf)/v1`
|
||||
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
|
||||
```
|
||||
|
||||
* **Supports Image-Text Chat and Image Modification**: `IMAGE_MODELS` configures which models can perform image-text chat and image editing. When actually calling, use the `configured_model-image` model name to use this feature.
|
||||
* **Supports Image-Text Chat and Image Modification**: `IMAGE_MODELS` configures which models can perform image-text chat and image editing. When actually calling, use the `configured_model-image` model name to use this feature.
|
||||

|
||||

|
||||
* **Supports Web Search**: Supports web search. `SEARCH_MODELS` configures which models can perform web searches. When actually calling, use the `configured_model-search` model name to use this feature.
|
||||
* **Supports Web Search**: Supports web search. `SEARCH_MODELS` configures which models can perform web searches. When actually calling, use the `configured_model-search` model name to use this feature.
|
||||

|
||||
* **Key Status Monitoring**: Provides a `/keys_status` page (requires authentication) to view the status and usage of each Key in real-time.
|
||||
* **Key Status Monitoring**: Provides a `/keys_status` page (requires authentication) to view the status and usage of each Key in real-time.
|
||||

|
||||
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
|
||||
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
|
||||

|
||||

|
||||

|
||||
* **Support for Custom Gemini Proxy**: Supports custom Gemini proxies, such as those built on Deno or Cloudflare.
|
||||
* **OpenAI Image Generation API Compatibility**: Adapts the `imagen-3.0-generate-002` model interface to be compatible with the OpenAI image generation API, supporting client calls.
|
||||
* **Flexible Key Addition**: Flexible way to add keys using regex matching for `gemini_key`, with key deduplication.
|
||||
* **Support for Custom Gemini Proxy**: Supports custom Gemini proxies, such as those built on Deno or Cloudflare.
|
||||
* **OpenAI Image Generation API Compatibility**: Adapts the `imagen-3.0-generate-002` model interface to be compatible with the OpenAI image generation API, supporting client calls.
|
||||
* **Flexible Key Addition**: Flexible way to add keys using regex matching for `gemini_key`, with key deduplication.
|
||||

|
||||
* **OpenAI Format Embeddings API Compatibility**: Perfectly adapts to the OpenAI format `embeddings` interface, usable for local document vectorization.
|
||||
* **Streaming Response Optimization**: Optional stream output optimizer (`STREAM_OPTIMIZER_ENABLED`) to improve the experience of long-text streaming responses.
|
||||
* **Failure Retry and Key Management**: Automatically handles API request failures, retries (`MAX_RETRIES`), automatically disables Keys after too many failures (`MAX_FAILURES`), and periodically checks for recovery (`CHECK_INTERVAL_HOURS`).
|
||||
* **Docker Support**: Supports AMD and ARM architecture Docker deployments. You can also build your own Docker image.
|
||||
* **OpenAI Format Embeddings API Compatibility**: Perfectly adapts to the OpenAI format `embeddings` interface, usable for local document vectorization.
|
||||
* **Streaming Response Optimization**: Optional stream output optimizer (`STREAM_OPTIMIZER_ENABLED`) to improve the experience of long-text streaming responses.
|
||||
* **Failure Retry and Key Management**: Automatically handles API request failures, retries (`MAX_RETRIES`), automatically disables Keys after too many failures (`MAX_FAILURES`), and periodically checks for recovery (`CHECK_INTERVAL_HOURS`).
|
||||
* **Docker Support**: Supports AMD and ARM architecture Docker deployments. You can also build your own Docker image.
|
||||
> Image address: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **Automatic Model List Maintenance**: Supports fetching OpenAI and Gemini model lists, perfectly compatible with NewAPI's automatic model list fetching, no manual entry required.
|
||||
* **Support for Removing Unused Models**: Too many default models are provided, many of which are not used. You can filter them out using `FILTERED_MODELS`.
|
||||
* **Proxy Support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, convenient for use in special network environments. Supports batch adding proxies.
|
||||
* **Automatic Model List Maintenance**: Supports fetching OpenAI and Gemini model lists, perfectly compatible with NewAPI's automatic model list fetching, no manual entry required.
|
||||
* **Support for Removing Unused Models**: Too many default models are provided, many of which are not used. You can filter them out using `FILTERED_MODELS`.
|
||||
* **Proxy Support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, convenient for use in special network environments. Supports batch adding proxies.
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
@@ -80,79 +86,83 @@ app/
|
||||
|
||||
#### a) Build with Dockerfile
|
||||
|
||||
1. **Build Image**:
|
||||
1. **Build Image**:
|
||||
|
||||
```bash
|
||||
docker build -t gemini-balance .
|
||||
```
|
||||
|
||||
2. **Run Container**:
|
||||
2. **Run Container**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env gemini-balance
|
||||
```
|
||||
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host.
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables.
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host.
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables.
|
||||
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist the data:
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist the data:
|
||||
>
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
|
||||
> ```
|
||||
>
|
||||
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
|
||||
|
||||
#### b) Deploy with an Existing Docker Image
|
||||
|
||||
1. **Pull Image**:
|
||||
1. **Pull Image**:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
2. **Run Container**:
|
||||
2. **Run Container**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host (adjust as needed).
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables (ensure the `.env` file exists in the directory where the command is executed).
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host (adjust as needed).
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables (ensure the `.env` file exists in the directory where the command is executed).
|
||||
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist the data:
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist the data:
|
||||
>
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
|
||||
> ```
|
||||
>
|
||||
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
|
||||
|
||||
### Run Locally (Suitable for Development and Testing)
|
||||
|
||||
If you want to run the source code directly locally for development or testing, follow these steps:
|
||||
|
||||
1. **Ensure Prerequisites are Met**:
|
||||
* Clone the repository locally.
|
||||
* Install Python 3.9 or higher.
|
||||
* Create and configure the `.env` file in the project root directory (refer to the "Configure Environment Variables" section above).
|
||||
* Install project dependencies:
|
||||
1. **Ensure Prerequisites are Met**:
|
||||
* Clone the repository locally.
|
||||
* Install Python 3.9 or higher.
|
||||
* Create and configure the `.env` file in the project root directory (refer to the "Configure Environment Variables" section above).
|
||||
* Install project dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **Start Application**:
|
||||
2. **Start Application**:
|
||||
Run the following command in the project root directory:
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
* `app.main:app`: Specifies the location of the FastAPI application instance (the `app` object in the `main.py` file within the `app` module).
|
||||
* `--host 0.0.0.0`: Makes the application accessible from any IP address on the local network.
|
||||
* `--port 8000`: Specifies the port number the application listens on (you can change this as needed).
|
||||
* `--reload`: Enables automatic reloading. When you modify the code, the service will automatically restart, which is very suitable for development environments (remove this option in production environments).
|
||||
* `app.main:app`: Specifies the location of the FastAPI application instance (the `app` object in the `main.py` file within the `app` module).
|
||||
* `--host 0.0.0.0`: Makes the application accessible from any IP address on the local network.
|
||||
* `--port 8000`: Specifies the port number the application listens on (you can change this as needed).
|
||||
* `--reload`: Enables automatic reloading. When you modify the code, the service will automatically restart, which is very suitable for development environments (remove this option in production environments).
|
||||
|
||||
3. **Access Application**:
|
||||
3. **Access Application**:
|
||||
After the application starts, you can access `http://localhost:8000` (or the host and port you specified) through a browser or API tool.
|
||||
|
||||
### Complete Configuration List
|
||||
@@ -181,6 +191,9 @@ If you want to run the source code directly locally for development or testing,
|
||||
| `SHOW_THINKING_PROCESS` | Optional, whether to display the model's thinking process | `true` |
|
||||
| `THINKING_MODELS` | Optional, list of models that support thinking functions | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | Optional, thinking function budget mapping (model_name:budget_value) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | Optional, whether to enable intelligent URL routing mapping | `false` |
|
||||
| `URL_CONTEXT_ENABLED` | Optional, whether to enable URL context understanding | `true` |
|
||||
| `URL_CONTEXT_MODELS` | Optional, list of models that support URL context understanding | `[]` |
|
||||
| `BASE_URL` | Optional, Gemini API base URL, no modification needed by default | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | Optional, number of times a single key is allowed to fail | `3` |
|
||||
| `MAX_RETRIES` | Optional, maximum number of retries for failed API requests | `3` |
|
||||
@@ -194,6 +207,10 @@ If you want to run the source code directly locally for development or testing,
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Optional, whether to enable automatic deletion of request logs | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | Optional, automatically delete request logs older than this many days (e.g., 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | Optional, safety settings (JSON string format), used to configure content safety thresholds. Example values may need adjustment based on actual model support. | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **TTS Related** | | |
|
||||
| `TTS_MODEL` | Optional, TTS model name | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | Optional, TTS voice name | `Zephyr` |
|
||||
| `TTS_SPEED` | Optional, TTS speed | `normal` |
|
||||
| **Image Generation Related** | | |
|
||||
| `PAID_KEY` | Optional, paid API Key for advanced features like image generation | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | Optional, image generation model | `imagen-3.0-generate-002` |
|
||||
@@ -202,6 +219,7 @@ If you want to run the source code directly locally for development or testing,
|
||||
| `PICGO_API_KEY` | Optional, API Key for [PicGo](https://www.picgo.net/) image hosting | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | Optional, [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) image hosting upload address | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE` | Optional, authentication key for CloudFlare image hosting | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER` | Optional, upload folder path for CloudFlare image hosting | `""` |
|
||||
| **Stream Optimizer Related** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | Optional, whether to enable stream output optimization | `false` |
|
||||
| `STREAM_MIN_DELAY` | Optional, minimum delay for stream output | `0.016` |
|
||||
@@ -219,20 +237,20 @@ The following are the main API endpoints provided by the service:
|
||||
|
||||
### Gemini API Related (`(/gemini)/v1beta`)
|
||||
|
||||
* `GET /models`: List available Gemini models.
|
||||
* `POST /models/{model_name}:generateContent`: Generate content using the specified Gemini model.
|
||||
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation using the specified Gemini model.
|
||||
* `GET /models`: List available Gemini models.
|
||||
* `POST /models/{model_name}:generateContent`: Generate content using the specified Gemini model.
|
||||
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation using the specified Gemini model.
|
||||
|
||||
### OpenAI API Related
|
||||
|
||||
* `GET (/hf)/v1/models`: List available models (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/chat/completions`: Perform chat completion (uses Gemini format underneath, supports streaming).
|
||||
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/images/generations`: Generate images (uses Gemini format underneath).
|
||||
* `GET /openai/v1/models`: List available models (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/chat/completions`: Perform chat completion (uses OpenAI format underneath, supports streaming, can prevent truncation, and is faster).
|
||||
* `POST /openai/v1/embeddings`: Create text embeddings (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/images/generations`: Generate images (uses OpenAI format underneath).
|
||||
* `GET (/hf)/v1/models`: List available models (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/chat/completions`: Perform chat completion (uses Gemini format underneath, supports streaming).
|
||||
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/images/generations`: Generate images (uses Gemini format underneath).
|
||||
* `GET /openai/v1/models`: List available models (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/chat/completions`: Perform chat completion (uses OpenAI format underneath, supports streaming, can prevent truncation, and is faster).
|
||||
* `POST /openai/v1/embeddings`: Create text embeddings (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/images/generations`: Generate images (uses OpenAI format underneath).
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
@@ -242,9 +260,9 @@ Pull Requests or Issues are welcome.
|
||||
|
||||
Special thanks to the following projects and platforms for providing image hosting services for this project:
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) open source project
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) open source project
|
||||
|
||||
## 🙏 Thanks to Contributors
|
||||
|
||||
@@ -254,11 +272,11 @@ Thanks to all developers who contributed to this project!
|
||||
|
||||
## Thanks to Our Supporters
|
||||
|
||||
We extend our heartfelt gratitude to the following supporters for their invaluable contributions to this project:
|
||||
|
||||
A special shout-out to DigitalOcean for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
|
||||
[](https://m.do.co/c/b249dd7f3b4c)
|
||||
|
||||
A special shout-out to DigitalOcean for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
|
||||
CDN acceleration and security protection for this project are sponsored by Tencent EdgeOne.
|
||||
[](https://edgeone.ai/?from=github)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
@@ -266,7 +284,7 @@ A special shout-out to DigitalOcean for providing the rock-solid and dependable
|
||||
|
||||
## 💖 Friendly Projects
|
||||
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine: AI-driven hot event timeline generation tool
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine: AI-driven hot event timeline generation tool
|
||||
|
||||
## 🎁 Project Support
|
||||
|
||||
|
||||
14
README_ZH.md
14
README_ZH.md
@@ -1,5 +1,11 @@
|
||||
# Gemini Balance - Gemini API 代理和负载均衡器
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> ⚠️ 本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
|
||||
> 本人从未在各个平台售卖服务,如有遇到售卖此服务者,那一定是倒卖狗,大家切记不要上当受骗。
|
||||
@@ -178,6 +184,9 @@ app/
|
||||
| `SHOW_THINKING_PROCESS` | 可选,是否显示模型思考过程 | `true` |
|
||||
| `THINKING_MODELS` | 可选,支持思考功能的模型列表 | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | 可选,思考功能预算映射 (模型名:预算值) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | 可选,是否启用智能路由映射功能 | `false` |
|
||||
| `URL_CONTEXT_ENABLED` | 可选,是否启用URL上下文理解功能 | `true` |
|
||||
| `URL_CONTEXT_MODELS` | 可选,支持URL上下文理解功能的模型列表 | `[]` |
|
||||
| `BASE_URL` | 可选,Gemini API 基础 URL,默认无需修改 | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | 可选,允许单个key失败的次数 | `3` |
|
||||
| `MAX_RETRIES` | 可选,API 请求失败时的最大重试次数 | `3` |
|
||||
@@ -191,6 +200,10 @@ app/
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 可选,是否开启自动删除请求日志 | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | 可选,自动删除多少天前的请求日志 (例如 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | 可选,安全设置 (JSON 字符串格式),用于配置内容安全阈值。示例值可能需要根据实际模型支持情况调整。 | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **TTS 相关** | | |
|
||||
| `TTS_MODEL` | 可选,TTS 模型名称 | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | 可选,TTS 语音名称 | `Zephyr` |
|
||||
| `TTS_SPEED` | 可选,TTS 语速 | `normal` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 可选,付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 可选,图片生成模型 | `imagen-3.0-generate-002` |
|
||||
@@ -199,6 +212,7 @@ app/
|
||||
| `PICGO_API_KEY` | 可选,[PicGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | 可选,[CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| 可选,CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| 可选,CloudFlare图床的上传文件夹路径 | `""` |
|
||||
| **流式优化器相关** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | 可选,是否启用流式输出优化 | `false` |
|
||||
| `STREAM_MIN_DELAY` | 可选,流式输出最小延迟 | `0.016` |
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any, Dict, List, Type, get_args, get_origin
|
||||
|
||||
from pydantic import ValidationError, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -51,8 +51,8 @@ class Settings(BaseSettings):
|
||||
return v
|
||||
|
||||
# API相关配置
|
||||
API_KEYS: List[str]
|
||||
ALLOWED_TOKENS: List[str]
|
||||
API_KEYS: List[str]=[]
|
||||
ALLOWED_TOKENS: List[str]=[]
|
||||
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
@@ -63,17 +63,31 @@ class Settings(BaseSettings):
|
||||
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
|
||||
VERTEX_API_KEYS: List[str] = []
|
||||
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
||||
|
||||
|
||||
# 智能路由配置
|
||||
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
|
||||
|
||||
# 自定义 Headers
|
||||
CUSTOM_HEADERS: Dict[str, str] = {}
|
||||
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
# 是否启用网址上下文
|
||||
URL_CONTEXT_ENABLED: bool = True
|
||||
URL_CONTEXT_MODELS: List[str] = ["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
THINKING_MODELS: List[str] = []
|
||||
THINKING_BUDGET_MAP: Dict[str, float] = {}
|
||||
|
||||
# TTS相关配置
|
||||
TTS_MODEL: str = "gemini-2.5-flash-preview-tts"
|
||||
TTS_VOICE_NAME: str = "Zephyr"
|
||||
TTS_SPEED: str = "normal"
|
||||
|
||||
# 图像生成相关配置
|
||||
PAID_KEY: str = ""
|
||||
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
||||
@@ -82,6 +96,7 @@ class Settings(BaseSettings):
|
||||
PICGO_API_KEY: str = ""
|
||||
CLOUDFLARE_IMGBED_URL: str = ""
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_OPTIMIZER_ENABLED: bool = False
|
||||
@@ -111,6 +126,11 @@ class Settings(BaseSettings):
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
|
||||
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
|
||||
|
||||
# Files API
|
||||
FILES_CLEANUP_ENABLED: bool = True
|
||||
FILES_CLEANUP_INTERVAL_HOURS: int = 1
|
||||
FILES_USER_ISOLATION_ENABLED: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# 设置默认AUTH_TOKEN(如果未提供)
|
||||
@@ -128,86 +148,106 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
|
||||
logger = get_config_logger()
|
||||
try:
|
||||
# 处理 List[str]
|
||||
if target_type == List[str]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
origin_type = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
|
||||
# 处理 List 类型
|
||||
if origin_type is list:
|
||||
# 处理 List[str]
|
||||
if args and args[0] == str:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
# 处理 Dict[str, float]
|
||||
elif target_type == Dict[str, float]:
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(
|
||||
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
||||
)
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif args and get_origin(args[0]) is dict:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
valid = all(
|
||||
isinstance(item, dict)
|
||||
and all(isinstance(k, str) for k in item.keys())
|
||||
and all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif target_type == List[Dict[str, str]]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
# 验证列表中的每个元素是否为字典,并且键和值都是字符串
|
||||
valid = all(
|
||||
isinstance(item, dict)
|
||||
and all(isinstance(k, str) for k in item.keys())
|
||||
and all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
return []
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
||||
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
||||
)
|
||||
return []
|
||||
# 处理 Dict 类型
|
||||
elif origin_type is dict:
|
||||
# 处理 Dict[str, str]
|
||||
if args and args == (str, str):
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): str(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict.")
|
||||
return parsed_dict
|
||||
# 处理 Dict[str, float]
|
||||
elif args and args == (str, float):
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(
|
||||
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
||||
)
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 bool
|
||||
elif target_type == bool:
|
||||
return db_value.lower() in ("true", "1", "yes", "on")
|
||||
@@ -296,18 +336,12 @@ async def sync_initial_settings():
|
||||
if parsed_db_value != memory_value:
|
||||
# 检查类型是否匹配,以防解析函数返回了不兼容的类型
|
||||
type_match = False
|
||||
if target_type == List[str] and isinstance(
|
||||
parsed_db_value, list
|
||||
):
|
||||
type_match = True
|
||||
elif target_type == Dict[str, float] and isinstance(
|
||||
parsed_db_value, dict
|
||||
):
|
||||
type_match = True
|
||||
elif target_type not in (
|
||||
List[str],
|
||||
Dict[str, float],
|
||||
) and isinstance(parsed_db_value, target_type):
|
||||
origin_type = get_origin(target_type)
|
||||
if origin_type: # It's a generic type
|
||||
if isinstance(parsed_db_value, origin_type):
|
||||
type_match = True
|
||||
# It's a non-generic type, or a specific generic we want to handle
|
||||
elif isinstance(parsed_db_value, target_type):
|
||||
type_match = True
|
||||
|
||||
if type_match:
|
||||
|
||||
@@ -15,12 +15,12 @@ DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
DEFAULT_TOP_K = 40
|
||||
DEFAULT_FILTER_MODELS = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001",
|
||||
]
|
||||
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
|
||||
# 图像生成相关常量
|
||||
@@ -38,14 +38,14 @@ DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||
|
||||
# 正则表达式模式
|
||||
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||
IMAGE_URL_PATTERN = r"!\[(.*?)\]\((.*?)\)"
|
||||
DATA_URL_PATTERN = r"data:([^;]+);base64,(.+)"
|
||||
|
||||
# Audio/Video Settings
|
||||
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
||||
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
||||
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
|
||||
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
||||
AUDIO_FORMAT_TO_MIMETYPE = {
|
||||
@@ -63,17 +63,50 @@ VIDEO_FORMAT_TO_MIMETYPE = {
|
||||
}
|
||||
|
||||
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
|
||||
DEFAULT_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
# Known TTS voice names.
# NOTE(review): presumably the valid choices for the TTS voice setting — confirm against callers.
TTS_VOICE_NAMES = [
    "Zephyr", "Puck", "Charon", "Kore", "Fenrir",
    "Leda", "Orus", "Aoede", "Callirrhoe", "Autonoe",
    "Enceladus", "Iapetus", "Umbriel", "Algieba", "Despina",
    "Erinome", "Algenib", "Rasalgethi", "Laomedeia", "Achernar",
    "Alnilam", "Schedar", "Gacrux", "Pulcherrima", "Achird",
    "Zubenelgenubi", "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat",
]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
数据库连接池模块
|
||||
"""
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from databases import Database
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
@@ -20,9 +21,9 @@ if settings.DATABASE_TYPE == "sqlite":
|
||||
DATABASE_URL = f"sqlite:///{db_path}"
|
||||
elif settings.DATABASE_TYPE == "mysql":
|
||||
if settings.MYSQL_SOCKET:
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
|
||||
else:
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
||||
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
||||
else:
|
||||
raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
数据库模型模块
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
|
||||
import enum
|
||||
|
||||
from app.database.connection import Base
|
||||
|
||||
@@ -60,3 +61,69 @@ class RequestLog(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
||||
|
||||
|
||||
class FileState(enum.Enum):
    """Lifecycle state of an uploaded file."""

    # Member values are the same strings as the member names.
    PROCESSING = "PROCESSING"
    ACTIVE = "ACTIVE"
    FAILED = "FAILED"
|
||||
|
||||
|
||||
class FileRecord(Base):
    """
    File record table storing metadata for files uploaded to Gemini.
    """

    __tablename__ = "t_file_records"

    id = Column(Integer, primary_key=True, autoincrement=True)

    # Basic file information
    name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}")
    display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名")
    mime_type = Column(String(100), nullable=False, comment="MIME 类型")
    size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)")
    sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值")

    # State information
    state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态")

    # Timestamps
    create_time = Column(DateTime, nullable=False, comment="创建时间")
    update_time = Column(DateTime, nullable=False, comment="更新时间")
    expiration_time = Column(DateTime, nullable=False, comment="过期时间")

    # API related
    uri = Column(String(500), nullable=False, comment="文件访问 URI")
    api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key")
    upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)")

    # Extra information
    user_token = Column(String(100), nullable=True, comment="上传用户的 token")
    upload_completed = Column(DateTime, nullable=True, comment="上传完成时间")

    def __repr__(self):
        return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"

    @staticmethod
    def _rfc3339(value):
        """
        Format a datetime as a UTC timestamp string ending in "Z".

        Naive datetimes are treated as already being in UTC (the same
        assumption is_expired makes). Timezone-aware datetimes are converted
        to UTC first, so the offset is not emitted alongside the "Z" suffix
        (plain ``isoformat() + "Z"`` would yield e.g. "...+00:00Z").
        """
        if value.tzinfo is not None:
            value = value.astimezone(datetime.timezone.utc).replace(tzinfo=None)
        return value.isoformat() + "Z"

    def to_dict(self):
        """Convert the record to a dict with camelCase keys for API responses."""
        return {
            "name": self.name,
            "displayName": self.display_name,
            "mimeType": self.mime_type,
            # Size is emitted as a string, not a number.
            "sizeBytes": str(self.size_bytes),
            "createTime": self._rfc3339(self.create_time),
            "updateTime": self._rfc3339(self.update_time),
            "expirationTime": self._rfc3339(self.expiration_time),
            "sha256Hash": self.sha256_hash,
            "uri": self.uri,
            # Fall back to "PROCESSING" when no state is set.
            "state": self.state.value if self.state else "PROCESSING"
        }

    def is_expired(self):
        """Return True when the current UTC time is past expiration_time."""
        # Make the stored value timezone-aware so the comparison is valid.
        expiration_time = self.expiration_time
        if expiration_time.tzinfo is None:
            expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
        return datetime.datetime.now(datetime.timezone.utc) > expiration_time
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
数据库服务模块
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
||||
import json
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings, ErrorLog, RequestLog
|
||||
from app.database.models import Settings, ErrorLog, RequestLog, FileRecord, FileState
|
||||
from app.log.logger import get_database_logger
|
||||
|
||||
logger = get_database_logger()
|
||||
@@ -427,3 +427,264 @@ async def add_request_log(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add request log: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 文件记录相关函数 ====================
|
||||
|
||||
async def create_file_record(
    name: str,
    mime_type: str,
    size_bytes: int,
    api_key: str,
    uri: str,
    create_time: datetime,
    update_time: datetime,
    expiration_time: datetime,
    state: FileState = FileState.PROCESSING,
    display_name: Optional[str] = None,
    sha256_hash: Optional[str] = None,
    upload_url: Optional[str] = None,
    user_token: Optional[str] = None
) -> Dict[str, Any]:
    """
    Insert a file record and return the stored row.

    Args:
        name: File name (format: files/{file_id})
        mime_type: MIME type
        size_bytes: File size in bytes
        api_key: API key used for the upload
        uri: File access URI
        create_time: Creation time
        update_time: Last update time
        expiration_time: Expiration time
        state: Initial file state (defaults to FileState.PROCESSING)
        display_name: Display name
        sha256_hash: SHA256 hash
        upload_url: Temporary upload URL
        user_token: Token of the uploading user

    Returns:
        Dict[str, Any]: The record as read back via get_file_record_by_name

    Raises:
        Exception: Any error from the insert or read-back is logged and re-raised.
    """
    column_values = {
        "name": name,
        "display_name": display_name,
        "mime_type": mime_type,
        "size_bytes": size_bytes,
        "sha256_hash": sha256_hash,
        "state": state,
        "create_time": create_time,
        "update_time": update_time,
        "expiration_time": expiration_time,
        "uri": uri,
        "api_key": api_key,
        "upload_url": upload_url,
        "user_token": user_token,
    }
    try:
        await database.execute(insert(FileRecord).values(**column_values))
        # Read the row back so the caller receives the stored record.
        return await get_file_record_by_name(name)
    except Exception as e:
        logger.error(f"Failed to create file record: {str(e)}")
        raise
|
||||
|
||||
|
||||
async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
    """
    Look up a file record by its name.

    Args:
        name: File name (format: files/{file_id})

    Returns:
        Optional[Dict[str, Any]]: The record as a dict, or None when no row matches

    Raises:
        Exception: Any lookup error is logged and re-raised.
    """
    try:
        row = await database.fetch_one(
            select(FileRecord).where(FileRecord.name == name)
        )
        if not row:
            return None
        return dict(row)
    except Exception as e:
        logger.error(f"Failed to get file record by name {name}: {str(e)}")
        raise
|
||||
|
||||
|
||||
|
||||
async def update_file_record_state(
    file_name: str,
    state: FileState,
    update_time: Optional[datetime] = None,
    upload_completed: Optional[datetime] = None,
    sha256_hash: Optional[str] = None
) -> bool:
    """
    Update the state (and optional fields) of a file record.

    Args:
        file_name: File name identifying the record
        state: New state
        update_time: Update time; written only when truthy
        upload_completed: Upload completion time; written only when truthy
        sha256_hash: SHA256 hash; written only when truthy

    Returns:
        bool: True when database.execute returns a truthy result; False when
        it returns a falsy result or when any exception is raised (logged).
    """
    try:
        changes: Dict[str, Any] = {"state": state}
        optional_fields = {
            "update_time": update_time,
            "upload_completed": upload_completed,
            "sha256_hash": sha256_hash,
        }
        # Only overwrite optional columns for which a truthy value was given.
        changes.update(
            {column: value for column, value in optional_fields.items() if value}
        )

        statement = update(FileRecord).where(FileRecord.name == file_name).values(**changes)
        outcome = await database.execute(statement)

        if not outcome:
            logger.warning(f"File record not found for update: {file_name}")
            return False

        logger.info(f"Updated file record state for {file_name} to {state}")
        return True
    except Exception as e:
        logger.error(f"Failed to update file record state: {str(e)}")
        return False
|
||||
|
||||
|
||||
async def list_file_records(
    user_token: Optional[str] = None,
    api_key: Optional[str] = None,
    page_size: int = 10,
    page_token: Optional[str] = None
) -> tuple[List[Dict[str, Any]], Optional[str]]:
    """
    List non-expired file records with offset-based pagination.

    Args:
        user_token: User token (if given, only that user's files are returned)
        api_key: API key (if given, only files uploaded with that key are returned)
        page_size: Page size
        page_token: Pagination token (a stringified row offset); invalid or
            negative tokens fall back to offset 0

    Returns:
        tuple[List[Dict[str, Any]], Optional[str]]: (file list, next page token);
        the token is None when there is no further page

    Raises:
        Exception: Any query error is logged and re-raised.
    """
    try:
        logger.debug(f"list_file_records called with page_size={page_size}, page_token={page_token}")
        query = select(FileRecord).where(
            FileRecord.expiration_time > datetime.now(timezone.utc)
        )

        if user_token:
            query = query.where(FileRecord.user_token == user_token)
        if api_key:
            query = query.where(FileRecord.api_key == api_key)

        # Offset-based pagination: the page token is the row offset.
        offset = 0
        if page_token:
            try:
                offset = int(page_token)
            except ValueError:
                logger.warning(f"Invalid page token: {page_token}")
                offset = 0
            # A negative OFFSET is rejected by some databases and would yield
            # inconsistent next-page tokens; restart from the first page instead.
            if offset < 0:
                logger.warning(f"Negative page token: {page_token}")
                offset = 0

        # Order by id ascending; fetch one extra row to detect a next page.
        query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)

        results = await database.fetch_all(query)

        logger.debug(f"Query returned {len(results)} records")
        if results:
            logger.debug(f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}")

        # Pagination handling
        has_next = len(results) > page_size
        if has_next:
            # Drop the look-ahead row; the next offset advances by one full page.
            results = results[:page_size]
            next_offset = offset + page_size
            next_page_token = str(next_offset)
            logger.debug(f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}")
        else:
            next_page_token = None
            logger.debug(f"No next page, returning {len(results)} results")

        return [dict(row) for row in results], next_page_token
    except Exception as e:
        logger.error(f"Failed to list file records: {str(e)}")
        raise
|
||||
|
||||
|
||||
async def delete_file_record(name: str) -> bool:
    """
    Delete the file record with the given name.

    Args:
        name: File name

    Returns:
        bool: True when the delete statement executed without raising;
        False when an exception occurred (it is logged, not re-raised).
    """
    try:
        await database.execute(delete(FileRecord).where(FileRecord.name == name))
    except Exception as e:
        logger.error(f"Failed to delete file record: {str(e)}")
        return False
    return True
|
||||
|
||||
|
||||
async def delete_expired_file_records() -> List[Dict[str, Any]]:
    """
    Delete all expired file records and return them.

    Returns:
        List[Dict[str, Any]]: The records that were deleted (empty when none expired)

    Raises:
        Exception: Any query error is logged and re-raised.
    """
    try:
        # Use one cutoff for both statements. Evaluating "now" twice would let
        # the DELETE remove rows that expired after the SELECT ran, so they
        # would be deleted without being reported to the caller.
        cutoff = datetime.now(timezone.utc)

        # Fetch the records that are about to be deleted.
        query = select(FileRecord).where(FileRecord.expiration_time <= cutoff)
        expired_records = await database.fetch_all(query)

        if not expired_records:
            return []

        # Perform the deletion with the same cutoff.
        delete_query = delete(FileRecord).where(FileRecord.expiration_time <= cutoff)
        await database.execute(delete_query)

        logger.info(f"Deleted {len(expired_records)} expired file records")
        return [dict(record) for record in expired_records]
    except Exception as e:
        logger.error(f"Failed to delete expired file records: {str(e)}")
        raise
|
||||
|
||||
|
||||
async def get_file_api_key(name: str) -> Optional[str]:
    """
    Return the API key recorded for a file.

    Args:
        name: File name

    Returns:
        Optional[str]: The API key, or None when no non-expired record matches

    Raises:
        Exception: Any lookup error is logged and re-raised.
    """
    try:
        now = datetime.now(timezone.utc)
        # Chained where() clauses are combined with AND.
        statement = (
            select(FileRecord.api_key)
            .where(FileRecord.name == name)
            .where(FileRecord.expiration_time > now)
        )
        row = await database.fetch_one(statement)
        if not row:
            return None
        return row["api_key"]
    except Exception as e:
        logger.error(f"Failed to get file API key: {str(e)}")
        raise
|
||||
|
||||
69
app/domain/file_models.py
Normal file
69
app/domain/file_models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Files API 相关的领域模型
|
||||
"""
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
    """File upload configuration."""

    # Both fields are optional and default to None.
    mime_type: Optional[str] = Field(default=None, description="MIME 类型")
    display_name: Optional[str] = Field(default=None, description="显示名称,最多40个字符")
|
||||
|
||||
|
||||
class CreateFileRequest(BaseModel):
    """Create-file request (used to initialize an upload)."""

    # Free-form file metadata; optional, defaults to None.
    file: Optional[Dict[str, Any]] = Field(default=None, description="文件元数据")
|
||||
|
||||
|
||||
class FileMetadata(BaseModel):
    """File metadata response; field names are camelCase."""

    # Required identification and type fields
    name: str = Field(..., description="文件名称,格式: files/{file_id}")
    displayName: Optional[str] = Field(default=None, description="显示名称")
    mimeType: str = Field(..., description="MIME 类型")
    # Size is carried as a string, not an integer.
    sizeBytes: str = Field(..., description="文件大小(字节)")

    # Timestamps are carried as strings.
    createTime: str = Field(..., description="创建时间 (RFC3339)")
    updateTime: str = Field(..., description="更新时间 (RFC3339)")
    expirationTime: str = Field(..., description="过期时间 (RFC3339)")

    sha256Hash: Optional[str] = Field(default=None, description="SHA256 哈希值")
    uri: str = Field(..., description="文件访问 URI")
    state: str = Field(..., description="文件状态")

    class Config:
        # Encoder applied to datetime values during JSON serialization.
        json_encoders = {
            datetime: lambda v: v.isoformat() + "Z"
        }
|
||||
|
||||
|
||||
class ListFilesRequest(BaseModel):
    """Query parameters for listing files."""

    # Page size defaults to 10 and is constrained to the range 1..100.
    pageSize: Optional[int] = Field(default=10, ge=1, le=100, description="每页大小")
    pageToken: Optional[str] = Field(default=None, description="分页标记")
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
    """List-files response."""

    # default_factory gives each instance its own empty list.
    files: List[FileMetadata] = Field(default_factory=list, description="文件列表")
    nextPageToken: Optional[str] = Field(default=None, description="下一页标记")
|
||||
|
||||
|
||||
class UploadInitResponse(BaseModel):
    """Upload initialization response (internal use)."""

    # Both fields are required.
    file_metadata: FileMetadata
    upload_url: str
|
||||
|
||||
|
||||
class FileKeyMapping(BaseModel):
    """Mapping between a file and its API key (internal use)."""

    # All fields are required.
    file_name: str
    api_key: str
    user_token: str
    created_at: datetime
    expires_at: datetime
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
    """Delete-file response."""

    # success is required; message is optional and defaults to None.
    success: bool = Field(..., description="是否删除成功")
    message: Optional[str] = Field(default=None, description="消息")
|
||||
@@ -41,15 +41,18 @@ class GenerationConfig(BaseModel):
|
||||
responseLogprobs: Optional[bool] = None
|
||||
logprobs: Optional[int] = None
|
||||
thinkingConfig: Optional[Dict[str, Any]] = None
|
||||
# TTS相关字段
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
|
||||
role: str = "system"
|
||||
parts: List[Dict[str, Any]] | Dict[str, Any]
|
||||
role: Optional[str] = "system"
|
||||
parts: Union[List[Dict[str, Any]], Dict[str, Any]]
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
role: str
|
||||
role: Optional[str] = None
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ImageMetadata:
|
||||
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: str | None = None):
|
||||
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
self.url = url
|
||||
self.delete_url = delete_url
|
||||
|
||||
|
||||
class UploadResponse:
|
||||
def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
|
||||
self.success = success
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.data = data
|
||||
|
||||
|
||||
class ImageUploader:
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -33,3 +33,10 @@ class ImageGenerationRequest(BaseModel):
|
||||
quality: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
response_format: Optional[str] = "url"
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
    """Text-to-speech request payload."""

    # Model name; defaults to the preview TTS model.
    model: str = "gemini-2.5-flash-preview-tts"
    # Text to synthesize (required).
    input: str
    # Voice name; defaults to "Kore".
    voice: str = "Kore"
    # Output format; defaults to "wav".
    response_format: Optional[str] = "wav"
|
||||
|
||||
@@ -9,6 +9,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
from app.log.logger import get_openai_logger
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
@@ -39,13 +42,13 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
def _handle_openai_stream_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, _ = _extract_result(
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
response, model, stream=True, gemini_format=False
|
||||
)
|
||||
if not text and not tool_calls:
|
||||
if not text and not tool_calls and not reasoning_content:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "role": "assistant"}
|
||||
delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
template_chunk = {
|
||||
@@ -63,7 +66,7 @@ def _handle_openai_stream_response(
|
||||
def _handle_openai_normal_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, _ = _extract_result(
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
response, model, stream=False, gemini_format=False
|
||||
)
|
||||
return {
|
||||
@@ -77,6 +80,7 @@ def _handle_openai_normal_response(
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
@@ -156,19 +160,24 @@ def _extract_result(
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, List[Dict[str, Any]], Optional[bool]]:
|
||||
text, tool_calls = "", []
|
||||
thought = None
|
||||
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
|
||||
text, reasoning_content, tool_calls, thought = "", "", [], None
|
||||
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
if not parts:
|
||||
return "", [], None
|
||||
logger.warning("No parts found in stream response")
|
||||
return "", None, [], None
|
||||
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
if "thought" in parts[0]:
|
||||
if not gemini_format and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content = text
|
||||
text = ""
|
||||
thought = parts[0].get("thought")
|
||||
elif "executableCode" in parts[0]:
|
||||
text = _format_code_block(parts[0]["executableCode"])
|
||||
@@ -187,40 +196,40 @@ def _extract_result(
|
||||
else:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
if "thinking" in model:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = (
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = candidate["content"]["parts"][1]["text"]
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
text = ""
|
||||
if "parts" in candidate["content"]:
|
||||
for part in candidate["content"]["parts"]:
|
||||
text, reasoning_content = "", ""
|
||||
|
||||
# 使用安全的访问方式
|
||||
content = candidate.get("content", {})
|
||||
|
||||
if content and isinstance(content, dict):
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts:
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
if "thought" in part and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content += part["text"]
|
||||
else:
|
||||
text += part["text"]
|
||||
if "thought" in part and thought is None:
|
||||
thought = part.get("thought")
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
else:
|
||||
logger.warning(f"No parts found in content for model: {model}")
|
||||
else:
|
||||
logger.error(f"Invalid content structure for model: {model}")
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(
|
||||
candidate["content"]["parts"], gemini_format
|
||||
)
|
||||
|
||||
# 安全地获取 parts 用于工具调用提取
|
||||
parts = candidate.get("content", {}).get("parts", [])
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
logger.warning(f"No candidates found in response for model: {model}")
|
||||
text = "暂无返回"
|
||||
return text, tool_calls, thought
|
||||
|
||||
return text, reasoning_content, tool_calls, thought
|
||||
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
@@ -238,6 +247,7 @@ def _extract_image_data(part: dict) -> str:
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
@@ -260,8 +270,8 @@ def _extract_tool_calls(
|
||||
return []
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
|
||||
tool_calls = list()
|
||||
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
@@ -270,7 +280,7 @@ def _extract_tool_calls(
|
||||
item = part.get("functionCall", {})
|
||||
if not item or not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
|
||||
if gemini_format:
|
||||
tool_calls.append(part)
|
||||
else:
|
||||
@@ -293,7 +303,7 @@ def _extract_tool_calls(
|
||||
def _handle_gemini_stream_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, thought = _extract_result(
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
@@ -310,16 +320,18 @@ def _handle_gemini_stream_response(
|
||||
def _handle_gemini_normal_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, thought = _extract_result(
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
parts = []
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
parts = tool_calls
|
||||
else:
|
||||
part = {"text": text}
|
||||
if thought is not None:
|
||||
part["thought"] = thought
|
||||
content = {"parts": [part], "role": "model"}
|
||||
parts.append({"text": reasoning_content,"thought": thought})
|
||||
part = {"text": text}
|
||||
parts.append(part)
|
||||
content = {"parts": parts, "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
@@ -228,6 +228,10 @@ def get_request_log_logger():
|
||||
return Logger.setup_logger("request_log")
|
||||
|
||||
|
||||
def get_files_logger():
    """Return the logger registered under the name ``"files"``."""
    files_logger = Logger.setup_logger("files")
    return files_logger
|
||||
|
||||
|
||||
def get_vertex_express_logger():
    """Return the logger registered under the name ``"vertex_express"``."""
    vertex_logger = Logger.setup_logger("vertex_express")
    return vertex_logger
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.responses import RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
||||
from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_middleware_logger
|
||||
@@ -33,6 +34,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
and not request.url.path.startswith("/openai")
|
||||
and not request.url.path.startswith("/api/version/check")
|
||||
and not request.url.path.startswith("/vertex-express")
|
||||
and not request.url.path.startswith("/upload")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
@@ -52,6 +54,9 @@ def setup_middlewares(app: FastAPI) -> None:
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
# 添加智能路由中间件(必须在认证中间件之前)
|
||||
app.add_middleware(SmartRoutingMiddleware)
|
||||
|
||||
# 添加认证中间件
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
|
||||
210
app/middleware/smart_routing_middleware.py
Normal file
210
app/middleware/smart_routing_middleware.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_main_logger
|
||||
import re
|
||||
|
||||
logger = get_main_logger()
|
||||
|
||||
class SmartRoutingMiddleware(BaseHTTPMiddleware):
    """Rewrite loosely-formed API paths to their canonical route.

    The middleware does nothing unless ``settings.URL_NORMALIZATION_ENABLED``
    is truthy. When enabled, every request path is classified (Gemini,
    ``/openai/``, ``/v1/`` or a bare ``/chat/completions``) and, if it is not
    already canonical, replaced in the ASGI scope before the request is
    passed on.
    """

    # Canonical endpoints that must be passed through untouched. Compiled once
    # at class-definition time instead of on every request.
    _CORRECT_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            # Native Gemini
            r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            # Gemini with prefix
            r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            # Gemini model list
            r"^/v1beta/models$",
            # Gemini model list with prefix
            r"^/gemini/v1beta/models$",
            # v1 format
            r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            # OpenAI format
            r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            # HF format
            r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
            # Vertex Express, Gemini format
            r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
            # Vertex Express model list
            r"^/vertex-express/v1beta/models$",
            # Vertex Express, OpenAI format
            r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$",
            # Token counting is a distinct operation; without this entry a
            # well-formed countTokens path would fall into the Gemini rule
            # below and be rewritten to ":generateContent".
            r"^(/gemini|/vertex-express)?/v1beta/models/[^/:]+:countTokens$",
        )
    )

    def __init__(self, app):
        """Wrap ``app``; the routing rules are stateless."""
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Normalise the request path (when enabled) and continue the chain.

        Args:
            request: Incoming request; its ASGI scope is mutated in place when
                the path is rewritten.
            call_next: Callable that forwards the request downstream.

        Returns:
            Whatever ``call_next`` returns.
        """
        if not settings.URL_NORMALIZATION_ENABLED:
            return await call_next(request)

        logger.debug(f"request: {request}")
        original_path = str(request.url.path)
        method = request.method

        # Try to repair the URL.
        fixed_path, fix_info = self.fix_request_url(original_path, method, request)

        if fixed_path != original_path:
            logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
            if fix_info:
                logger.debug(f"Fix details: {fix_info}")

            # Rewrite both the decoded and the raw path in the ASGI scope.
            request.scope["path"] = fixed_path
            request.scope["raw_path"] = fixed_path.encode()

        return await call_next(request)

    def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
        """Choose the rewrite rule for ``path`` and apply it.

        Rules are tried in priority order; the first match wins.

        Args:
            path: Request path as received.
            method: HTTP method of the request.
            request: The request, used for query parameters and body lookup.

        Returns:
            ``(new_path, info)`` where ``info`` is a dict describing the applied
            rule, or ``None`` when the path is returned unchanged.
        """
        # Already canonical: leave untouched.
        if self.is_already_correct_format(path):
            return path, None

        lowered = path.lower()

        # 1. Highest priority: generateContent / v1beta models -> Gemini format.
        if "generatecontent" in lowered or "v1beta/models" in lowered:
            return self.fix_gemini_by_operation(path, method, request)

        # 2. Contains /openai/ -> OpenAI format.
        if "/openai/" in lowered:
            return self.fix_openai_by_operation(path, method)

        # 3. Contains /v1/ -> v1 format.
        if "/v1/" in lowered:
            return self.fix_v1_by_operation(path, method)

        # 4. Contains /chat/completions -> chat endpoint.
        if "/chat/completions" in lowered:
            return "/v1/chat/completions", {"type": "v1_chat"}

        # 5. Default: pass through as-is.
        return path, None

    def is_already_correct_format(self, path: str) -> bool:
        """Return True when ``path`` already matches a canonical endpoint."""
        return any(pattern.match(path) for pattern in self._CORRECT_PATTERNS)

    def fix_gemini_by_operation(
        self, path: str, method: str, request: Request
    ) -> tuple:
        """Build the canonical Gemini (or Vertex Express) endpoint for ``path``.

        ``GET`` requests are always sent to the model list. For other methods
        the model name and streaming flag are detected and combined into a
        ``:generateContent`` / ``:streamGenerateContent`` URL, keeping the
        ``/vertex-express`` prefix when the original path contained it.

        Returns:
            ``(new_path, info)``; ``(path, None)`` when no model name can be
            determined.
        """
        if method == "GET":
            # Key is "rule" to match the info dicts built below.
            return "/v1beta/models", {"rule": "gemini_models"}

        # Extract the model name; give up (no rewrite) if there is none.
        try:
            model_name = self.extract_model_name(path, request)
        except ValueError:
            return path, None

        is_stream = self.detect_stream_request(path, request)
        operation = "streamGenerateContent" if is_stream else "generateContent"
        mode = "stream" if is_stream else "generate"

        # Honour a vertex-express preference expressed in the original path.
        if "/vertex-express/" in path.lower():
            prefix, family = "/vertex-express", "vertex_express"
        else:
            prefix, family = "", "gemini"

        target_url = f"{prefix}/v1beta/models/{model_name}:{operation}"
        fix_info = {
            "rule": f"{family}_{mode}",
            "preference": f"{family}_format",
            "is_stream": is_stream,
            "model": model_name,
        }
        return target_url, fix_info

    @staticmethod
    def _fix_openai_style(path: str, method: str, prefix: str, label: str) -> tuple:
        """Map keyword hints in ``path`` to an OpenAI-style endpoint under ``prefix``.

        ``label`` is the prefix of the ``"type"`` value reported in the info dict.
        """
        lowered = path.lower()
        if method == "POST":
            if "chat" in lowered or "completion" in lowered:
                return f"{prefix}/chat/completions", {"type": f"{label}_chat"}
            if "embedding" in lowered:
                return f"{prefix}/embeddings", {"type": f"{label}_embeddings"}
            if "image" in lowered:
                return f"{prefix}/images/generations", {"type": f"{label}_images"}
            if "audio" in lowered:
                return f"{prefix}/audio/speech", {"type": f"{label}_audio"}
        elif method == "GET":
            if "model" in lowered:
                return f"{prefix}/models", {"type": f"{label}_models"}

        return path, None

    def fix_openai_by_operation(self, path: str, method: str) -> tuple:
        """Repair a path containing ``/openai/`` based on method and keywords."""
        return self._fix_openai_style(path, method, "/openai/v1", "openai")

    def fix_v1_by_operation(self, path: str, method: str) -> tuple:
        """Repair a path containing ``/v1/`` based on method and keywords."""
        return self._fix_openai_style(path, method, "/v1", "v1")

    def detect_stream_request(self, path: str, request: Request) -> bool:
        """Return True when the path or the ``stream`` query flag asks for streaming."""
        # 1. The path mentions "stream".
        if "stream" in path.lower():
            return True

        # 2. Query parameter stream=true.
        return request.query_params.get("stream") == "true"

    def extract_model_name(self, path: str, request: Request) -> str:
        """Find the model name used to build a Gemini API URL.

        Sources are tried in order: cached JSON body, ``model`` query
        parameter, then a ``/models/<name>`` segment of the path.

        Raises:
            ValueError: If none of the sources yields a model name.
        """
        import json

        # 1. From the request body. Only the private ``_body`` attribute is
        # consulted. NOTE(review): nothing in this class awaits the body, so
        # this likely only fires if an earlier component already read it --
        # confirm.
        raw_body = getattr(request, "_body", None)
        if raw_body:
            try:
                body = json.loads(raw_body.decode())
            except Exception as exc:
                # Best effort: an unreadable body must not break routing.
                logger.debug(f"Could not read model from request body: {exc}")
            else:
                if isinstance(body, dict) and body.get("model"):
                    return body["model"]

        # 2. From the query string.
        model_param = request.query_params.get("model")
        if model_param:
            return model_param

        # 3. From the path (for paths that already contain a model name).
        match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
        if match:
            return match.group(1)

        # 4. Nothing found.
        raise ValueError("Unable to extract model name from request")
|
||||
295
app/router/files_routes.py
Normal file
295
app/router/files_routes.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Files API 路由
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request, Query, Depends, Header, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.file_models import (
|
||||
FileMetadata,
|
||||
ListFilesResponse,
|
||||
DeleteFileResponse
|
||||
)
|
||||
from app.log.logger import get_files_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.service.files.file_upload_handler import get_upload_handler
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
router = APIRouter()
|
||||
security_service = SecurityService()
|
||||
|
||||
|
||||
@router.post("/upload/v1beta/files")
async def upload_file_init(
    request: Request,
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
    x_goog_upload_protocol: Optional[str] = Header(None),
    x_goog_upload_command: Optional[str] = Header(None),
    x_goog_upload_header_content_length: Optional[str] = Header(None),
    x_goog_upload_header_content_type: Optional[str] = Header(None),
):
    """Initialise a file upload, or forward an in-progress upload request.

    When the request carries an ``upload_id`` query parameter together with an
    ``upload`` / ``upload, finalize`` command it is handed to ``handle_upload``.
    Otherwise the upload headers and raw body are passed to
    ``files_service.initialize_upload``.

    Args:
        request: Incoming request; its body, URL and query parameters are read.
        auth_token: Value produced by the authentication dependency; used as
            the user token.
        x_goog_upload_protocol: Upload protocol header, defaults to
            ``"resumable"`` when absent.
        x_goog_upload_command: Upload command header, defaults to ``"start"``
            when absent.
        x_goog_upload_header_content_length: Forwarded only when present.
        x_goog_upload_header_content_type: Forwarded only when present.

    Returns:
        A ``JSONResponse`` with the service payload and headers, or an
        ``{"error": {"message": ...}}`` body with the failure status code.
    """
    # Credentials are kept out of the log: neither the auth token nor the query
    # string (which may carry ``key``) is written.
    logger.debug(
        f"Upload file request: {request.method=}, {request.url.path=}, "
        f"{x_goog_upload_protocol=}, {x_goog_upload_command=}, "
        f"{x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}"
    )

    # An upload_id means this is an actual upload request, not an initialisation.
    if request.query_params.get("upload_id") and x_goog_upload_command in ("upload", "upload, finalize"):
        logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.")
        return await handle_upload(
            upload_path="v1beta/files",
            request=request,
            key=request.query_params.get("key"),
            auth_token=auth_token,
        )

    try:
        body = await request.body()

        # Host URL of this request, handed to the service.
        request_host = f"{request.url.scheme}://{request.url.netloc}"
        logger.info(f"Request host: {request_host}")

        headers = {
            "x-goog-upload-protocol": x_goog_upload_protocol or "resumable",
            "x-goog-upload-command": x_goog_upload_command or "start",
        }
        if x_goog_upload_header_content_length:
            headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length
        if x_goog_upload_header_content_type:
            headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type

        files_service = await get_files_service()
        response_data, response_headers = await files_service.initialize_upload(
            headers=headers,
            body=body,
            user_token=auth_token,  # the auth token doubles as the user token
            request_host=request_host,
        )

        logger.info(f"Upload initialization response: {response_data}")
        logger.info(f"Upload initialization response headers: {response_headers}")

        return JSONResponse(content=response_data, headers=response_headers)

    except HTTPException as e:
        logger.error(f"Upload initialization failed: {e.detail}")
        return JSONResponse(
            content={"error": {"message": e.detail}},
            status_code=e.status_code,
        )
    except Exception as e:
        logger.error(f"Unexpected error in upload initialization: {str(e)}")
        return JSONResponse(
            content={"error": {"message": "Internal server error"}},
            status_code=500,
        )
|
||||
|
||||
|
||||
@router.get("/v1beta/files")
async def list_files(
    page_size: int = Query(10, ge=1, le=100, description="每页大小", alias="pageSize"),
    page_token: Optional[str] = Query(None, description="分页标记", alias="pageToken"),
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> ListFilesResponse:
    """List files, one page at a time.

    Args:
        page_size: Page size (query alias ``pageSize``), between 1 and 100.
        page_token: Pagination token (query alias ``pageToken``).
        auth_token: Value produced by the authentication dependency.

    Returns:
        The result of ``files_service.list_files``; on failure a
        ``JSONResponse`` with an ``{"error": {"message": ...}}`` body.
    """
    # The auth token is deliberately not written to the log.
    logger.debug(f"List files: {page_size=}, {page_token=}")
    try:
        # The auth token acts as the user token only when user isolation is on.
        user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None
        files_service = await get_files_service()
        return await files_service.list_files(
            page_size=page_size,
            page_token=page_token,
            user_token=user_token,
        )

    except HTTPException as e:
        logger.error(f"List files failed: {e.detail}")
        return JSONResponse(
            content={"error": {"message": e.detail}},
            status_code=e.status_code,
        )
    except Exception as e:
        logger.error(f"Unexpected error in list files: {str(e)}")
        return JSONResponse(
            content={"error": {"message": "Internal server error"}},
            status_code=500,
        )
|
||||
|
||||
|
||||
@router.get("/v1beta/files/{file_id:path}")
async def get_file(
    file_id: str,
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> FileMetadata:
    """Fetch the metadata of one file.

    Args:
        file_id: Path remainder after ``/v1beta/files/``; looked up under the
            name ``files/{file_id}``.
        auth_token: Value produced by the authentication dependency; used as
            the user token.

    Returns:
        The result of ``files_service.get_file``; on failure a ``JSONResponse``
        with an ``{"error": {"message": ...}}`` body.
    """
    # The auth token is deliberately not written to the log.
    logger.debug(f"Get file request: {file_id=}")
    try:
        files_service = await get_files_service()
        return await files_service.get_file(f"files/{file_id}", auth_token)

    except HTTPException as e:
        logger.error(f"Get file failed: {e.detail}")
        return JSONResponse(
            content={"error": {"message": e.detail}},
            status_code=e.status_code,
        )
    except Exception as e:
        logger.error(f"Unexpected error in get file: {str(e)}")
        return JSONResponse(
            content={"error": {"message": "Internal server error"}},
            status_code=500,
        )
|
||||
|
||||
|
||||
@router.delete("/v1beta/files/{file_id:path}")
async def delete_file(
    file_id: str,
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> DeleteFileResponse:
    """Delete one file.

    Args:
        file_id: Path remainder after ``/v1beta/files/``; deleted under the
            name ``files/{file_id}``.
        auth_token: Value produced by the authentication dependency; used as
            the user token.

    Returns:
        A ``DeleteFileResponse`` carrying the service's success flag and a
        matching message; on failure a ``JSONResponse`` with an
        ``{"error": {"message": ...}}`` body.
    """
    # The auth token is deliberately not written to the log.
    logger.info(f"Delete file: {file_id=}")
    try:
        files_service = await get_files_service()
        success = await files_service.delete_file(f"files/{file_id}", auth_token)

        return DeleteFileResponse(
            success=success,
            message="File deleted successfully" if success else "Failed to delete file",
        )

    except HTTPException as e:
        logger.error(f"Delete file failed: {e.detail}")
        return JSONResponse(
            content={"error": {"message": e.detail}},
            status_code=e.status_code,
        )
    except Exception as e:
        logger.error(f"Unexpected error in delete file: {str(e)}")
        return JSONResponse(
            content={"error": {"message": "Internal server error"}},
            status_code=500,
        )
|
||||
|
||||
|
||||
# Catch-all route for upload requests
@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"])
async def handle_upload(
    upload_path: str,
    request: Request,
    key: Optional[str] = Query(None),  # taken from the query string
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
):
    """Proxy an upload request to the URL stored for its upload session.

    Args:
        upload_path: Path remainder after ``/upload/``; only logged.
        request: Incoming request; must carry an ``upload_id`` query parameter.
        key: Optional ``key`` query parameter. It is accepted but not used for
            the upstream call, and its value is never logged.
        auth_token: Value produced by the authentication dependency.

    Returns:
        The result of ``upload_handler.proxy_upload_request``. On failure, a
        ``JSONResponse`` with an ``{"error": {"message": ...}}`` body: 400 when
        ``upload_id`` is missing, 404 when no session is found, 500 otherwise.
    """
    try:
        # Only report whether a key was supplied; the value is a credential.
        logger.info(
            f"Handling upload request: {request.method} {upload_path}, "
            f"key_provided={key is not None}"
        )

        upload_id = request.query_params.get("upload_id")
        if not upload_id:
            raise HTTPException(status_code=400, detail="Missing upload_id")

        # The session holds the real API key and the original upload URL.
        files_service = await get_files_service()
        session_info = await files_service.get_upload_session(upload_id)
        if not session_info:
            logger.error(f"No session found for upload_id: {upload_id}")
            raise HTTPException(status_code=404, detail="Upload session not found")

        real_api_key = session_info["api_key"]
        upload_url = session_info["upload_url"]  # forwarded unchanged

        logger.info(f"Using real API key for upload: {real_api_key[:8]}...{real_api_key[-4:]}")

        # Proxy the upload request.
        upload_handler = get_upload_handler()
        return await upload_handler.proxy_upload_request(
            request=request,
            upload_url=upload_url,
            files_service=files_service,
        )

    except HTTPException as e:
        logger.error(f"Upload handling failed: {e.detail}")
        return JSONResponse(
            content={"error": {"message": e.detail}},
            status_code=e.status_code,
        )
    except Exception as e:
        logger.error(f"Unexpected error in upload handling: {str(e)}")
        return JSONResponse(
            content={"error": {"message": "Internal server error"}},
            status_code=500,
        )
|
||||
|
||||
|
||||
# Routes with a /gemini prefix, added for compatibility
@router.post("/gemini/upload/v1beta/files")
async def gemini_upload_file_init(
    request: Request,
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
    x_goog_upload_protocol: Optional[str] = Header(None),
    x_goog_upload_command: Optional[str] = Header(None),
    x_goog_upload_header_content_length: Optional[str] = Header(None),
    x_goog_upload_header_content_type: Optional[str] = Header(None),
):
    """Initialise a file upload (``/gemini`` prefix); delegates to ``upload_file_init``."""
    return await upload_file_init(
        request=request,
        auth_token=auth_token,
        x_goog_upload_protocol=x_goog_upload_protocol,
        x_goog_upload_command=x_goog_upload_command,
        x_goog_upload_header_content_length=x_goog_upload_header_content_length,
        x_goog_upload_header_content_type=x_goog_upload_header_content_type,
    )
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files")
async def gemini_list_files(
    page_size: int = Query(10, ge=1, le=100, alias="pageSize"),
    page_token: Optional[str] = Query(None, alias="pageToken"),
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> ListFilesResponse:
    """List files (``/gemini`` prefix); delegates to ``list_files``."""
    return await list_files(
        page_size=page_size,
        page_token=page_token,
        auth_token=auth_token,
    )
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files/{file_id:path}")
async def gemini_get_file(
    file_id: str,
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> FileMetadata:
    """Fetch file metadata (``/gemini`` prefix); delegates to ``get_file``."""
    return await get_file(file_id=file_id, auth_token=auth_token)
|
||||
|
||||
|
||||
@router.delete("/gemini/v1beta/files/{file_id:path}")
async def gemini_delete_file(
    file_id: str,
    auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> DeleteFileResponse:
    """Delete a file (``/gemini`` prefix); delegates to ``delete_file``."""
    return await delete_file(file_id=file_id, auth_token=auth_token)
|
||||
@@ -8,6 +8,7 @@ from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
@@ -109,11 +110,41 @@ async def generate_content(
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
|
||||
# 检测是否为原生Gemini TTS请求
|
||||
is_native_tts = False
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
is_native_tts = True
|
||||
logger.info("Detected native Gemini TTS request")
|
||||
logger.info(f"TTS responseModalities: {response_modalities}")
|
||||
logger.info(f"TTS speechConfig: {speech_config}")
|
||||
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
|
||||
# 所有原生TTS请求都使用TTS增强服务
|
||||
if is_native_tts:
|
||||
try:
|
||||
logger.info("Using native TTS enhanced service")
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"Native TTS processing failed, falling back to standard service: {e}")
|
||||
|
||||
# 使用标准服务处理所有其他请求(非TTS)
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
@@ -151,6 +182,35 @@ async def stream_generate_content(
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:countTokens")
@router_v1beta.post("/models/{model_name}:countTokens")
@RetryHandler(key_arg="api_key")
async def count_tokens(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a Gemini token-count request.

    Rejects the request with HTTP 400 when ``model_service`` reports the model
    as unsupported; otherwise returns the result of
    ``chat_service.count_tokens``.
    """
    async with handle_route_errors(
        logger, "gemini_count_tokens", failure_message="Token counting failed"
    ):
        logger.info(f"Handling Gemini token count request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")

        model_supported = await model_service.check_model_support(model_name)
        if not model_supported:
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        return await chat_service.count_tokens(
            model=model_name,
            request=request,
            api_key=api_key,
        )
|
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
||||
@@ -269,7 +329,7 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
parts=[{"text": "hi"}],
|
||||
)
|
||||
],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
|
||||
)
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
@@ -279,7 +339,9 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
)
|
||||
|
||||
if response:
|
||||
return JSONResponse({"status": "valid"})
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return JSONResponse({"status": "valid"})
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
|
||||
@@ -314,7 +376,7 @@ async def verify_selected_keys(
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
|
||||
)
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
@@ -322,6 +384,8 @@ async def verify_selected_keys(
|
||||
api_key
|
||||
)
|
||||
successful_keys.append(api_key)
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return api_key, "valid", None
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
@@ -7,6 +7,7 @@ from app.domain.openai_models import (
|
||||
ChatRequest,
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
@@ -14,6 +15,7 @@ from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.tts.tts_service import TTSService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
|
||||
@@ -24,6 +26,7 @@ security_service = SecurityService()
|
||||
model_service = ModelService()
|
||||
embedding_service = EmbeddingService()
|
||||
image_create_service = ImageCreateService()
|
||||
tts_service = TTSService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
@@ -41,6 +44,11 @@ async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_mana
|
||||
return OpenAIChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
async def get_tts_service():
    """Return the module-level TTS service instance."""
    return tts_service
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
@@ -147,3 +155,21 @@ async def get_keys_list(
|
||||
},
|
||||
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/v1/audio/speech")
@router.post("/hf/v1/audio/speech")
async def text_to_speech(
    request: TTSRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Handle an OpenAI-style TTS request.

    Passes the request and API key to ``tts_service.create_tts`` and returns
    the resulting audio with media type ``audio/wav``.
    """
    async with handle_route_errors(logger, "text_to_speech"):
        logger.info(f"Handling TTS request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")

        wav_bytes = await tts_service.create_tts(request, api_key)
        return Response(content=wav_bytes, media_type="audio/wav")
|
||||
|
||||
@@ -8,7 +8,7 @@ from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes, files_routes
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats.stats_service import StatsService
|
||||
|
||||
@@ -34,6 +34,7 @@ def setup_routers(app: FastAPI) -> None:
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
app.include_router(vertex_express_routes.router)
|
||||
app.include_router(files_routes.router)
|
||||
|
||||
setup_page_routes(app)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.error_log.error_log_service import delete_old_error_logs
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
||||
from app.service.files.files_service import get_files_service
|
||||
|
||||
logger = Logger.setup_logger("scheduler")
|
||||
|
||||
@@ -96,6 +97,26 @@ async def check_failed_keys():
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_expired_files():
    """Scheduled job that removes expired file records.

    Obtains the files service, asks it to clean up expired files and logs
    how many were removed. Any exception is logged with its traceback and
    not re-raised.
    """
    logger.info("Starting scheduled cleanup for expired files...")
    try:
        service = await get_files_service()
        removed = await service.cleanup_expired_files()
        summary = (
            f"Successfully cleaned up {removed} expired files."
            if removed > 0
            else "No expired files to clean up."
        )
        logger.info(summary)
    except Exception as exc:
        logger.error(
            f"An error occurred during the scheduled file cleanup: {str(exc)}", exc_info=True
        )
|
||||
|
||||
|
||||
def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
@@ -134,6 +155,20 @@ def setup_scheduler():
|
||||
logger.info(
|
||||
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
||||
)
|
||||
|
||||
# 新增:添加文件过期清理的定时任务,每小时执行一次
|
||||
if getattr(settings, 'FILES_CLEANUP_ENABLED', True):
|
||||
cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1)
|
||||
scheduler.add_job(
|
||||
cleanup_expired_files,
|
||||
"interval",
|
||||
hours=cleanup_interval,
|
||||
id="cleanup_expired_files_job",
|
||||
name="Cleanup Expired Files",
|
||||
)
|
||||
logger.info(
|
||||
f"File cleanup job scheduled to run every {cleanup_interval} hour(s)."
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduler started with all jobs.")
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.database.services import add_error_log, add_request_log, get_file_api_key
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -27,10 +27,74 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
    """Collect the ``files/...`` names referenced by ``fileData`` parts.

    A part contributes a name when it is a dict holding
    ``fileData.fileUri`` of the form ``<settings.BASE_URL>/files/<id>``;
    the returned name is the ``files/<id>`` portion. URIs that do not match
    are logged as warnings and skipped. Malformed entries (non-dict content,
    missing or ``None`` ``parts``, non-dict ``fileData``) are ignored instead
    of raising.

    Args:
        contents: Request contents as plain dicts.

    Returns:
        The referenced file names, in order of appearance.
    """
    file_names = []
    for content in contents or []:
        if not isinstance(content, dict):
            continue
        parts = content.get("parts")
        if not isinstance(parts, (list, tuple)):
            # Covers both a missing key and an explicit None.
            continue
        for part in parts:
            if not isinstance(part, dict):
                continue
            file_data = part.get("fileData")
            if not isinstance(file_data, dict) or "fileUri" not in file_data:
                continue
            file_uri = file_data["fileUri"]
            # Expected shape: <BASE_URL>/files/{file_id}
            match = (
                re.match(rf"{re.escape(settings.BASE_URL)}/(files/.*)", file_uri)
                if isinstance(file_uri, str)
                else None
            )
            if not match:
                logger.warning(f"Invalid file URI: {file_uri}")
                continue
            file_id = match.group(1)
            file_names.append(file_id)
            logger.info(f"Found file reference: {file_id}")
    return file_names
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
    """Return a copy of a JSON-Schema-like dict without unsupported keywords.

    Keywords listed in ``unsupported_fields`` (those the Gemini API does not
    accept) are dropped at every nesting level. Dict values and dict items
    inside list values are cleaned recursively; everything else is kept as is.

    The keys of a ``properties`` mapping are user-chosen property names, not
    schema keywords, so they are always kept (a parameter literally called
    ``examples`` or ``const`` must survive); only their schemas are cleaned.

    Args:
        obj: Any value; non-dict values are returned unchanged.

    Returns:
        A new cleaned dict, or ``obj`` itself when it is not a dict.
    """
    if not isinstance(obj, dict):
        return obj

    # JSON Schema keywords not supported by the Gemini API
    unsupported_fields = {
        "exclusiveMaximum", "exclusiveMinimum", "const", "examples",
        "contentEncoding", "contentMediaType", "if", "then", "else",
        "allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
        "$id", "$ref", "$comment", "readOnly", "writeOnly"
    }

    cleaned = {}
    for key, value in obj.items():
        if key in unsupported_fields:
            continue
        if key == "properties" and isinstance(value, dict):
            # Do not filter property names against the keyword list.
            cleaned[key] = {
                name: _clean_json_schema_properties(schema)
                for name, schema in value.items()
            }
        elif isinstance(value, dict):
            cleaned[key] = _clean_json_schema_properties(value)
        elif isinstance(value, list):
            cleaned[key] = [_clean_json_schema_properties(item) for item in value]
        else:
            cleaned[key] = value

    return cleaned
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
|
||||
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
    """Tell whether any part of any content entry carries a ``functionCall`` key.

    Anything that is not a non-empty list of dicts with a non-empty list
    under ``parts`` is skipped rather than treated as an error.
    """
    if not (contents and isinstance(contents, list)):
        return False

    def _usable_parts(entry):
        # Yield the parts list only when the entry has the expected shape.
        if entry and isinstance(entry, dict) and "parts" in entry:
            candidate = entry.get("parts", [])
            if candidate and isinstance(candidate, list):
                return candidate
        return []

    return any(
        isinstance(piece, dict) and "functionCall" in piece
        for entry in contents
        for piece in _usable_parts(entry)
    )
|
||||
|
||||
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
record = dict()
|
||||
for item in tools:
|
||||
@@ -40,7 +104,15 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
functions.extend(v)
|
||||
# 清理每个函数声明中的不支持字段
|
||||
cleaned_functions = []
|
||||
for func in v:
|
||||
if isinstance(func, dict):
|
||||
cleaned_func = _clean_json_schema_properties(func)
|
||||
cleaned_functions.append(cleaned_func)
|
||||
else:
|
||||
cleaned_functions.append(func)
|
||||
functions.extend(cleaned_functions)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
@@ -62,15 +134,32 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
if tool.get("functionDeclarations") or _has_function_call(payload.get("contents", [])):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
    """Strip the trailing feature suffixes from a requested model name.

    The suffixes ``-search``, ``-image`` and ``-non-thinking`` are removed
    from the end of ``model`` repeatedly, so they are handled in any order
    and combination: both ``x-search-non-thinking`` and
    ``x-non-thinking-search`` yield ``x``.

    Args:
        model: Model name as requested by the client.

    Returns:
        The name with every trailing feature suffix removed.
    """
    suffixes = ("-search", "-image", "-non-thinking")
    stripped = True
    while stripped:
        stripped = False
        for suffix in suffixes:
            if model.endswith(suffix):
                model = model[: -len(suffix)]
                stripped = True
    return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
@@ -78,21 +167,61 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop empty or non-dict parts, and contents left without any part.

    Each surviving entry is a shallow copy of the original whose ``parts``
    list holds only the non-empty dict parts; the input is not modified.
    Entries without a list under ``parts`` are discarded.
    """
    kept = []
    for entry in contents or []:
        has_part_list = bool(entry) and "parts" in entry and isinstance(entry.get("parts"), list)
        if not has_part_list:
            continue

        usable = [piece for piece in entry["parts"] if isinstance(piece, dict) and piece]
        if not usable:
            continue

        trimmed = entry.copy()
        trimmed["parts"] = usable
        kept.append(trimmed)
    return kept
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
if "maxOutputTokens" in request_dict["generationConfig"]:
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
# 检查是否为TTS模型
|
||||
is_tts_model = "tts" in model.lower()
|
||||
|
||||
if is_tts_model:
|
||||
# TTS模型使用简化的payload,不包含tools和safetySettings
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
else:
|
||||
# 非TTS模型使用完整的payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
@@ -109,9 +238,18 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
else:
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
if model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
if "gemini-2.5-pro" in model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000),
|
||||
"includeThoughts": True
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
return payload
|
||||
|
||||
@@ -152,6 +290,17 @@ class GeminiChatService:
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
payload = _build_payload(model, request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
@@ -195,10 +344,69 @@ class GeminiChatService:
|
||||
request_time=request_datetime
|
||||
)
|
||||
|
||||
async def count_tokens(
    self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
    """Count tokens for the request contents.

    Only the filtered ``contents`` of the request are sent. On failure the
    error is logged and stored through ``add_error_log`` (with the status
    code parsed from the error text, defaulting to 500) and re-raised. A
    request-log entry with the latency is written in every case.
    """
    # The countTokens API only needs the contents
    raw_contents = request.model_dump().get("contents", [])
    payload = {"contents": _filter_empty_parts(raw_contents)}
    started = time.perf_counter()
    request_datetime = datetime.datetime.now()
    succeeded = False
    status_code = None

    try:
        result = await self.api_client.count_tokens(payload, model, api_key)
        succeeded, status_code = True, 200
        return result
    except Exception as e:
        succeeded = False
        error_log_msg = str(e)
        logger.error(f"Count tokens API call failed with error: {error_log_msg}")
        found = re.search(r"status code (\d+)", error_log_msg)
        status_code = int(found.group(1)) if found else 500

        await add_error_log(
            gemini_key=api_key,
            model_name=model,
            error_type="gemini-count-tokens",
            error_log=error_log_msg,
            error_code=status_code,
            request_msg=payload
        )
        raise e
    finally:
        latency_ms = int((time.perf_counter() - started) * 1000)
        await add_request_log(
            model_name=model,
            api_key=api_key,
            is_success=succeeded,
            status_code=status_code,
            latency_ms=latency_ms,
            request_time=request_datetime
        )
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
|
||||
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
payload = _build_payload(model, request)
|
||||
|
||||
@@ -26,16 +26,43 @@ from app.service.key.key_manager import KeyManager
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片、音频或视频部分 (inline_data)"""
|
||||
for content in contents:
|
||||
if content and "parts" in content and isinstance(content["parts"], list):
|
||||
for part in content["parts"]:
|
||||
if isinstance(part, dict) and "inline_data" in part:
|
||||
def _has_media_parts(messages: List[Dict[str, Any]]) -> bool:
    """Tell whether any message part contains ``image_url`` or ``inline_data``."""
    return any(
        "image_url" in piece or "inline_data" in piece
        for entry in messages
        if "parts" in entry
        for piece in entry["parts"]
    )
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
    """Return a copy of a JSON-Schema-like dict without unsupported keywords.

    Keywords listed in ``unsupported_fields`` (those the Gemini API does not
    accept) are dropped at every nesting level. Dict values and dict items
    inside list values are cleaned recursively; everything else is kept as is.

    The keys of a ``properties`` mapping are user-chosen property names, not
    schema keywords, so they are always kept (a parameter literally called
    ``examples`` or ``const`` must survive); only their schemas are cleaned.

    Args:
        obj: Any value; non-dict values are returned unchanged.

    Returns:
        A new cleaned dict, or ``obj`` itself when it is not a dict.
    """
    if not isinstance(obj, dict):
        return obj

    # JSON Schema keywords not supported by the Gemini API
    unsupported_fields = {
        "exclusiveMaximum", "exclusiveMinimum", "const", "examples",
        "contentEncoding", "contentMediaType", "if", "then", "else",
        "allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
        "$id", "$ref", "$comment", "readOnly", "writeOnly"
    }

    cleaned = {}
    for key, value in obj.items():
        if key in unsupported_fields:
            continue
        if key == "properties" and isinstance(value, dict):
            # Do not filter property names against the keyword list.
            cleaned[key] = {
                name: _clean_json_schema_properties(schema)
                for name, schema in value.items()
            }
        elif isinstance(value, dict):
            cleaned[key] = _clean_json_schema_properties(value)
        elif isinstance(value, list):
            cleaned[key] = [_clean_json_schema_properties(item) for item in value]
        else:
            cleaned[key] = value

    return cleaned
|
||||
|
||||
|
||||
def _build_tools(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -60,6 +87,10 @@ def _build_tools(
|
||||
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 将 request 中的 tools 合并到 tools 中
|
||||
if request.tools:
|
||||
@@ -76,6 +107,8 @@ def _build_tools(
|
||||
):
|
||||
function.pop("parameters", None)
|
||||
|
||||
# 清理函数中的不支持字段
|
||||
function = _clean_json_schema_properties(function)
|
||||
function_declarations.append(function)
|
||||
|
||||
if function_declarations:
|
||||
@@ -97,10 +130,23 @@ def _build_tools(
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext",None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
    """Strip the trailing feature suffixes from a requested model name.

    The suffixes ``-search``, ``-image`` and ``-non-thinking`` are removed
    from the end of ``model`` repeatedly, so they are handled in any order
    and combination: both ``x-search-non-thinking`` and
    ``x-non-thinking-search`` yield ``x``.

    Args:
        model: Model name as requested by the client.

    Returns:
        The name with every trailing feature suffix removed.
    """
    suffixes = ("-search", "-image", "-non-thinking")
    stripped = True
    while stripped:
        stripped = False
        for suffix in suffixes:
            if model.endswith(suffix):
                model = model[: -len(suffix)]
                stripped = True
    return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
# if (
|
||||
@@ -113,6 +159,23 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _validate_and_set_max_tokens(
    payload: Dict[str, Any],
    max_tokens: Optional[int],
    logger_instance
) -> None:
    """Copy a valid ``max_tokens`` into ``payload["generationConfig"]``.

    ``None`` leaves the payload untouched. A value that is zero or negative
    is reported through ``logger_instance.warning`` and not written, so the
    field stays unset. Any other value is stored as ``maxOutputTokens``.
    """
    if max_tokens is None:
        return

    rejected = max_tokens <= 0
    if not rejected:
        payload["generationConfig"]["maxOutputTokens"] = max_tokens
        return

    # Leave maxOutputTokens unset for invalid values
    logger_instance.warning(f"Invalid max_tokens value: {max_tokens}, will not set maxOutputTokens")
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest,
|
||||
messages: List[Dict[str, Any]],
|
||||
@@ -130,16 +193,27 @@ def _build_payload(
|
||||
"tools": _build_tools(request, messages),
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.max_tokens is not None:
|
||||
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
|
||||
|
||||
# 处理 max_tokens 参数
|
||||
_validate_and_set_max_tokens(payload, request.max_tokens, logger)
|
||||
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if request.model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if request.model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
|
||||
}
|
||||
if "gemini-2.5-pro" in request.model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
|
||||
if _get_real_model(request.model) in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000),
|
||||
"includeThoughts": True
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)}
|
||||
|
||||
if (
|
||||
instruction
|
||||
@@ -206,27 +280,53 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
usage_metadata = response.get("usageMetadata", {})
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
|
||||
# 尝试处理响应,捕获可能的响应处理异常
|
||||
try:
|
||||
result = self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
return result
|
||||
except Exception as response_error:
|
||||
logger.error(f"Response processing failed for model {model}: {str(response_error)}")
|
||||
|
||||
# 记录详细的错误信息
|
||||
if "parts" in str(response_error):
|
||||
logger.error("Response structure issue - missing or invalid parts")
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
logger.error(f"Content structure: {content}")
|
||||
|
||||
# 重新抛出异常
|
||||
raise response_error
|
||||
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
logger.error(f"API call failed for model {model}: {error_log_msg}")
|
||||
|
||||
# 特别记录 max_tokens 相关的错误
|
||||
gen_config = payload.get('generationConfig', {})
|
||||
if "maxOutputTokens" in gen_config:
|
||||
logger.error(f"Request had maxOutputTokens: {gen_config['maxOutputTokens']}")
|
||||
|
||||
# 如果是响应处理错误,记录更多信息
|
||||
if "parts" in error_log_msg:
|
||||
logger.error("This is likely a response processing error")
|
||||
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
status_code = int(match.group(1)) if match else 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -240,6 +340,8 @@ class OpenAIChatService:
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(f"Normal completion finished - Success: {is_success}, Latency: {latency_ms}ms")
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
|
||||
@@ -28,9 +28,51 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
    """Return a copy of a JSON-Schema-like dict without unsupported keywords.

    Keywords listed in ``unsupported_fields`` (those the Gemini API does not
    accept) are dropped at every nesting level. Dict values and dict items
    inside list values are cleaned recursively; everything else is kept as is.

    The keys of a ``properties`` mapping are user-chosen property names, not
    schema keywords, so they are always kept (a parameter literally called
    ``examples`` or ``const`` must survive); only their schemas are cleaned.

    Args:
        obj: Any value; non-dict values are returned unchanged.

    Returns:
        A new cleaned dict, or ``obj`` itself when it is not a dict.
    """
    if not isinstance(obj, dict):
        return obj

    # JSON Schema keywords not supported by the Gemini API
    unsupported_fields = {
        "exclusiveMaximum", "exclusiveMinimum", "const", "examples",
        "contentEncoding", "contentMediaType", "if", "then", "else",
        "allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
        "$id", "$ref", "$comment", "readOnly", "writeOnly"
    }

    cleaned = {}
    for key, value in obj.items():
        if key in unsupported_fields:
            continue
        if key == "properties" and isinstance(value, dict):
            # Do not filter property names against the keyword list.
            cleaned[key] = {
                name: _clean_json_schema_properties(schema)
                for name, schema in value.items()
            }
        elif isinstance(value, dict):
            cleaned[key] = _clean_json_schema_properties(value)
        elif isinstance(value, list):
            cleaned[key] = [_clean_json_schema_properties(item) for item in value]
        else:
            cleaned[key] = value

    return cleaned
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
|
||||
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
    """Tell whether any part of any content entry carries a ``functionCall`` key.

    Anything that is not a non-empty list of dicts with a non-empty list
    under ``parts`` is skipped rather than treated as an error.
    """
    if not (contents and isinstance(contents, list)):
        return False

    def _usable_parts(entry):
        # Yield the parts list only when the entry has the expected shape.
        if entry and isinstance(entry, dict) and "parts" in entry:
            candidate = entry.get("parts", [])
            if candidate and isinstance(candidate, list):
                return candidate
        return []

    return any(
        isinstance(piece, dict) and "functionCall" in piece
        for entry in contents
        for piece in _usable_parts(entry)
    )
|
||||
|
||||
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
record = dict()
|
||||
for item in tools:
|
||||
@@ -40,7 +82,15 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
functions.extend(v)
|
||||
# 清理每个函数声明中的不支持字段
|
||||
cleaned_functions = []
|
||||
for func in v:
|
||||
if isinstance(func, dict):
|
||||
cleaned_func = _clean_json_schema_properties(func)
|
||||
cleaned_functions.append(cleaned_func)
|
||||
else:
|
||||
cleaned_functions.append(func)
|
||||
functions.extend(cleaned_functions)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
@@ -62,15 +112,32 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
if tool.get("functionDeclarations") or _has_function_call(payload.get("contents", [])):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
    """Strip the trailing feature suffixes from a requested model name.

    The suffixes ``-search``, ``-image`` and ``-non-thinking`` are removed
    from the end of ``model`` repeatedly, so they are handled in any order
    and combination: both ``x-search-non-thinking`` and
    ``x-non-thinking-search`` yield ``x``.

    Args:
        model: Model name as requested by the client.

    Returns:
        The name with every trailing feature suffix removed.
    """
    suffixes = ("-search", "-image", "-non-thinking")
    stripped = True
    while stripped:
        stripped = False
        for suffix in suffixes:
            if model.endswith(suffix):
                model = model[: -len(suffix)]
                stripped = True
    return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
@@ -80,7 +147,7 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
@@ -98,10 +165,29 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
# 处理思考配置:优先使用客户端提供的配置,否则使用默认配置
|
||||
client_thinking_config = None
|
||||
if request.generationConfig and request.generationConfig.thinkingConfig:
|
||||
client_thinking_config = request.generationConfig.thinkingConfig
|
||||
|
||||
if client_thinking_config is not None:
|
||||
# 客户端提供了思考配置,直接使用
|
||||
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
|
||||
else:
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
if model.endswith("-non-thinking"):
|
||||
if "gemini-2.5-pro" in model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000),
|
||||
"includeThoughts": True
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@@ -40,6 +40,13 @@ class GeminiApiClient(ApiClient):
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
def _prepare_headers(self) -> Dict[str, str]:
    """Build the extra HTTP headers to send with an upstream request.

    Returns an empty dict unless ``settings.CUSTOM_HEADERS`` is set, in which
    case its entries are copied into the result and the use is logged.
    """
    prepared = {}
    if not settings.CUSTOM_HEADERS:
        return prepared
    prepared.update(settings.CUSTOM_HEADERS)
    logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
    return prepared
|
||||
|
||||
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取可用的 Gemini 模型列表"""
|
||||
timeout = httpx.Timeout(timeout=5)
|
||||
@@ -52,10 +59,11 @@ class GeminiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models?key={api_key}&pageSize=1000"
|
||||
try:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -69,7 +77,7 @@ class GeminiApiClient(ApiClient):
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -78,13 +86,36 @@ class GeminiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(f"API call failed - Status: {response.status_code}, Content: {error_content}")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应结构的基本信息
|
||||
if not response_data.get("candidates"):
|
||||
logger.warning("No candidates found in API response")
|
||||
|
||||
return response_data
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Request timeout: {e}")
|
||||
raise Exception(f"Request timeout: {e}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {e}")
|
||||
raise Exception(f"Request error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
raise
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
@@ -98,9 +129,10 @@ class GeminiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
@@ -108,6 +140,27 @@ class GeminiApiClient(ApiClient):
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
    """Call the ``countTokens`` endpoint for *model* and return the decoded reply.

    Args:
        payload: JSON body sent with the request.
        model: Model name; it is passed through ``_get_real_model`` before use.
        api_key: Key appended to the request URL.

    Returns:
        The parsed JSON response body.

    Raises:
        Exception: If the endpoint answers with a status other than 200.
    """
    request_timeout = httpx.Timeout(self.timeout, read=self.timeout)
    model = self._get_real_model(model)

    proxy_to_use = None
    if settings.PROXIES:
        # Either derive the proxy from the key or pick one at random.
        proxy_to_use = (
            settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
            if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY
            else random.choice(settings.PROXIES)
        )
        logger.info(f"Using proxy for counting tokens: {proxy_to_use}")

    headers = self._prepare_headers()
    endpoint = f"{self.base_url}/models/{model}:countTokens?key={api_key}"
    async with httpx.AsyncClient(timeout=request_timeout, proxy=proxy_to_use) as client:
        response = await client.post(endpoint, json=payload, headers=headers)
        if response.status_code == 200:
            return response.json()
        error_content = response.text
        raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
@@ -116,6 +169,13 @@ class OpenaiApiClient(ApiClient):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
def _prepare_headers(self, api_key: str) -> Dict[str, str]:
    """Build the HTTP headers for a request authenticated with *api_key*.

    Args:
        api_key: Key placed in the ``Authorization: Bearer`` header.

    Returns:
        A new dict with the ``Authorization`` header, updated with
        ``settings.CUSTOM_HEADERS`` when that setting is non-empty (a custom
        header with the same name replaces the default one).
    """
    headers: Dict[str, str] = {"Authorization": f"Bearer {api_key}"}
    if settings.CUSTOM_HEADERS:
        headers.update(settings.CUSTOM_HEADERS)
        # Log the header names only: the values may carry credentials and this
        # runs on every outgoing request.
        logger.info(f"Using custom headers: {list(settings.CUSTOM_HEADERS)}")
    return headers
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
@@ -127,9 +187,9 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/models"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
@@ -147,9 +207,9 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
@@ -166,9 +226,9 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
@@ -188,9 +248,9 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/embeddings"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
payload = {
|
||||
"input": input,
|
||||
"model": model,
|
||||
@@ -212,9 +272,9 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/images/generations"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
|
||||
1
app/service/files/__init__.py
Normal file
1
app/service/files/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Intentionally empty __init__.py file
|
||||
247
app/service/files/file_upload_handler.py
Normal file
247
app/service/files/file_upload_handler.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
文件上传处理器
|
||||
处理 Google 的可恢复上传协议
|
||||
"""
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import Request, Response, HTTPException
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.log.logger import get_files_logger
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
|
||||
class FileUploadHandler:
    """Relay resumable-upload requests to the upstream upload URL.

    Each chunk received from the client is forwarded to the real upload URL.
    When the chunk carrying the ``finalize`` command succeeds, the file
    metadata returned by the upstream is stored through
    ``db_services.create_file_record`` together with the upload session data
    (API key, user token) obtained from the files service.
    """

    # Client request headers copied onto the upstream request.
    _UPLOAD_REQUEST_HEADERS = (
        "x-goog-upload-command",
        "x-goog-upload-offset",
        "content-type",
        "content-length",
    )

    # Upstream response headers describing the wire form of the upstream body.
    # httpx hands back an already-decoded body, so relaying these would
    # advertise an encoding/length that no longer matches what we send.
    _STALE_RESPONSE_HEADERS = frozenset(
        {"content-encoding", "content-length", "transfer-encoding", "connection"}
    )

    # Lifetime assumed for a file when the upstream omits ``expirationTime``.
    _DEFAULT_FILE_TTL = timedelta(hours=48)

    def __init__(self):
        # Nominal chunk size; not referenced by the methods of this class.
        self.chunk_size = 8 * 1024 * 1024  # 8MB

    @classmethod
    def _relayable_headers(cls, upstream_headers) -> dict:
        """Return the upstream response headers that are safe to send to the client."""
        return {
            name: value
            for name, value in upstream_headers.items()
            if name.lower() not in cls._STALE_RESPONSE_HEADERS
        }

    @staticmethod
    def _parse_expiration_time(raw_value, fallback: datetime) -> datetime:
        """Parse an RFC 3339 timestamp, tolerating any fractional-second precision.

        Args:
            raw_value: Timestamp text such as ``2025-07-11T02:02:52.531916141Z``,
                or a falsy value when the upstream sent none.
            fallback: Value returned when *raw_value* is missing or unparseable.

        Returns:
            A timezone-aware datetime (UTC is assumed when no offset is given).
        """
        if not raw_value:
            return fallback

        text = str(raw_value).strip()
        if text[-1:] in ("Z", "z"):
            text = text[:-1] + "+00:00"

        # Normalise the fraction to exactly six digits: older fromisoformat()
        # implementations accept only 3 or 6, and nanoseconds never fit.
        head, dot, tail = text.partition(".")
        if dot:
            digit_count = len(tail) - len(tail.lstrip("0123456789"))
            fraction = tail[:digit_count][:6].ljust(6, "0")
            text = f"{head}.{fraction}{tail[digit_count:]}"

        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            logger.warning(f"Unparseable expirationTime {raw_value!r}, using default expiration")
            return fallback

        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _to_file_state(state_name):
        """Map the upstream state string to ``FileState`` (unknown -> PROCESSING)."""
        known_states = {
            "ACTIVE": FileState.ACTIVE,
            "PROCESSING": FileState.PROCESSING,
            "FAILED": FileState.FAILED,
        }
        state = known_states.get(state_name)
        if state is None:
            logger.warning(f"Unknown file state: {state_name}, defaulting to PROCESSING")
            return FileState.PROCESSING
        return state

    async def _record_finalized_upload(self, upload_url: str, response, files_service) -> None:
        """Create the database record for a finished upload (best effort).

        Failures are logged and swallowed so that the client still receives
        the upstream response for its final chunk.
        """
        try:
            response_data = response.json()
            logger.debug(f"Upload complete response data: {response_data}")
            file_data = response_data.get("file", {})

            real_file_name = file_data.get("name")
            if not (real_file_name and files_service):
                logger.warning(f"Missing real_file_name or files_service: real_file_name={real_file_name}, files_service={files_service}")
                return

            logger.info(f"Upload completed, file name: {real_file_name}")

            session_info = await files_service.get_upload_session(upload_url)
            # The session holds the API key and user token, so log only whether
            # it was found rather than its contents.
            logger.debug(f"Retrieved session info for {upload_url}: found={session_info is not None}")
            if not session_info:
                logger.warning(f"No upload session found for URL: {upload_url}")
                return

            now = datetime.now(timezone.utc)
            expiration_time = self._parse_expiration_time(
                file_data.get("expirationTime"), now + self._DEFAULT_FILE_TTL
            )

            # The upstream may still report PROCESSING right after the upload.
            file_state = file_data.get("state", "PROCESSING")
            logger.debug(f"File state from Google: {file_state}")

            await db_services.create_file_record(
                name=real_file_name,
                mime_type=file_data.get("mimeType", session_info["mime_type"]),
                size_bytes=int(file_data.get("sizeBytes", session_info["size_bytes"])),
                api_key=session_info["api_key"],
                uri=file_data.get("uri", f"{settings.BASE_URL}/{real_file_name}"),
                create_time=now,
                update_time=now,
                expiration_time=expiration_time,
                state=self._to_file_state(file_state),
                display_name=file_data.get("displayName", session_info.get("display_name", "")),
                sha256_hash=file_data.get("sha256Hash"),
                user_token=session_info["user_token"]
            )
            logger.info(f"Created file record: name={real_file_name}, api_key={session_info['api_key'][:8]}...{session_info['api_key'][-4:]}")
        except Exception as e:
            logger.error(f"Failed to create file record: {str(e)}", exc_info=True)

    async def handle_upload_chunk(
        self,
        upload_url: str,
        request: Request,
        files_service=None
    ) -> Response:
        """Forward one upload chunk to the upstream and relay its answer.

        Args:
            upload_url: Real upstream upload URL the chunk is posted to.
            request: Incoming FastAPI request carrying the chunk.
            files_service: Files service used to look up the upload session
                when the upload is finalized; without it no record is created.

        Returns:
            Response: The upstream status, body and (filtered) headers.

        Raises:
            HTTPException: With the upstream status when it is not 200/201/308,
                or with status 500 on any unexpected error.
        """
        try:
            # Copy the upload-related headers using canonical capitalisation.
            headers = {}
            for header in self._UPLOAD_REQUEST_HEADERS:
                if header in request.headers:
                    key = "-".join(word.capitalize() for word in header.split("-"))
                    headers[key] = request.headers[header]

            body = await request.body()

            # The last chunk is announced by a "finalize" upload command.
            upload_command = headers.get("X-Goog-Upload-Command", "")
            is_final = "finalize" in upload_command
            logger.debug(f"Upload command: {upload_command}, is_final: {is_final}")

            async with AsyncClient() as client:
                response = await client.post(
                    upload_url,
                    headers=headers,
                    content=body,
                    timeout=300.0  # 5 minutes
                )

            if response.status_code not in [200, 201, 308]:
                logger.error(f"Upload chunk failed: {response.status_code} - {response.text}")
                raise HTTPException(status_code=response.status_code, detail="Upload failed")

            if is_final and response.status_code in [200, 201]:
                logger.debug(f"Upload finalized with status {response.status_code}")
                await self._record_finalized_upload(upload_url, response, files_service)
            else:
                logger.debug(f"Upload chunk processed: is_final={is_final}, status={response.status_code}")

            response_headers = self._relayable_headers(response.headers)

            # 308 (Resume Incomplete): make sure the client sees an upload status.
            if response.status_code == 308:
                if "x-goog-upload-status" not in response_headers:
                    response_headers["x-goog-upload-status"] = "active"

            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to handle upload chunk: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def proxy_upload_request(
        self,
        request: Request,
        upload_url: str,
        files_service=None
    ) -> Response:
        """Dispatch an upload request: GET queries status, anything else uploads.

        Args:
            request: Incoming FastAPI request.
            upload_url: Target upstream upload URL.
            files_service: Files service passed on to ``handle_upload_chunk``.

        Returns:
            Response: The proxied upstream response.

        Raises:
            HTTPException: Propagated unchanged from the called handler, or
                status 500 on any other error.
        """
        logger.debug(f"Proxy upload request: {request.method}, {upload_url}")
        try:
            if request.method == "GET":
                return await self._get_upload_status(upload_url)

            # POST/PUT: forward the chunk.
            return await self.handle_upload_chunk(upload_url, request, files_service)

        except HTTPException:
            # Keep the status chosen by the handler instead of masking it as 500.
            raise
        except Exception as e:
            logger.error(f"Failed to proxy upload request: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def _get_upload_status(self, upload_url: str) -> Response:
        """Fetch the upload status from the upstream with a plain GET.

        Args:
            upload_url: Upstream upload URL.

        Returns:
            Response: The upstream status, body and (filtered) headers.

        Raises:
            HTTPException: With status 500 when the request fails.
        """
        try:
            async with AsyncClient() as client:
                response = await client.get(upload_url)

            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=self._relayable_headers(response.headers)
            )
        except Exception as e:
            logger.error(f"Failed to get upload status: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
|
||||
# 单例实例
|
||||
# Module-level FileUploadHandler shared by all callers; stays None until
# get_upload_handler() creates it.
_upload_handler_instance: Optional[FileUploadHandler] = None
|
||||
|
||||
|
||||
def get_upload_handler() -> FileUploadHandler:
    """Return the shared FileUploadHandler, constructing it on the first call."""
    global _upload_handler_instance
    handler = _upload_handler_instance
    if handler is None:
        handler = FileUploadHandler()
        _upload_handler_instance = handler
    return handler
|
||||
498
app/service/files/files_service.py
Normal file
498
app/service/files/files_service.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
文件管理服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from httpx import AsyncClient
|
||||
import asyncio
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.domain.file_models import FileMetadata, ListFilesResponse
|
||||
from fastapi import HTTPException
|
||||
from app.log.logger import get_files_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
# 全局上傳會話存儲
|
||||
# In-memory upload sessions keyed by upload_id; each value holds the API key,
# user token and file details recorded by FilesService.initialize_upload.
_upload_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
# Held by the FilesService coroutines while they read or modify _upload_sessions.
_upload_sessions_lock = asyncio.Lock()
|
||||
|
||||
|
||||
class FilesService:
    """Service layer of the Files API proxy.

    Starts resumable uploads against the upstream Files endpoint, keeps the
    in-memory upload sessions keyed by ``upload_id``, and implements the
    get / list / delete / state-check / cleanup operations on top of the local
    file records.
    """

    # Upstream endpoint that starts a resumable upload.
    _UPLOAD_ENDPOINT = "https://generativelanguage.googleapis.com/upload/v1beta/files"
    # Origin swapped for the caller's host in the upload URL returned to clients.
    _UPSTREAM_ORIGIN = "https://generativelanguage.googleapis.com"
    # Upload sessions older than this are dropped by _cleanup_expired_sessions.
    _SESSION_TTL = timedelta(hours=1)

    def __init__(self):
        self.api_client = GeminiApiClient(base_url=settings.BASE_URL)
        # Resolved lazily by _get_key_manager().
        self.key_manager = None

    async def _get_key_manager(self):
        """Return the KeyManager instance, fetching it on first use."""
        if not self.key_manager:
            self.key_manager = await get_key_manager_instance(
                settings.API_KEYS,
                settings.VERTEX_API_KEYS
            )
        return self.key_manager

    @staticmethod
    def _redact_url_key(url: str) -> str:
        """Return *url* with the value of its ``key`` query parameter hidden (for logs)."""
        import re
        return re.sub(r"([?&]key=)[^&]+", r"\1***", url)

    @staticmethod
    def _extract_upload_id(url: str) -> Optional[str]:
        """Return the ``upload_id`` query parameter of *url*, or None if absent."""
        import urllib.parse
        query = urllib.parse.urlparse(url).query
        return urllib.parse.parse_qs(query).get('upload_id', [None])[0]

    @staticmethod
    def _utc_timestamp(value: datetime) -> str:
        """Format *value* as ISO 8601 with a trailing ``Z``.

        Naive datetimes are emitted as they are; aware ones are converted to
        UTC first so the output never ends in ``+00:00Z``.
        """
        if value.tzinfo is not None:
            value = value.astimezone(timezone.utc).replace(tzinfo=None)
        return value.isoformat() + "Z"

    @staticmethod
    def _is_expired(file_record) -> bool:
        """Tell whether the record's expiration time has passed (naive = UTC)."""
        expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
        if expiration_time.tzinfo is None:
            expiration_time = expiration_time.replace(tzinfo=timezone.utc)
        return expiration_time <= datetime.now(timezone.utc)

    @staticmethod
    async def _sync_record_state(file_name: str, google_state: str) -> None:
        """Persist *google_state* when it is ACTIVE or FAILED; ignore other states."""
        new_state = {"ACTIVE": FileState.ACTIVE, "FAILED": FileState.FAILED}.get(google_state)
        if new_state is not None:
            await db_services.update_file_record_state(
                file_name=file_name,
                state=new_state,
                update_time=datetime.now(timezone.utc)
            )

    async def initialize_upload(
        self,
        headers: Dict[str, str],
        body: Optional[bytes],
        user_token: str,
        request_host: Optional[str] = None
    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
        """Start a resumable upload upstream and register its session.

        Args:
            headers: Request headers of the client (lower-case names are read).
            body: Raw request body; ``displayName`` is read from it when it is JSON.
            user_token: Token of the calling user; stored in the session and
                substituted for the real API key in the returned upload URL.
            request_host: Base URL of this proxy. When given, the upstream
                origin in the upload URL is replaced by it.

        Returns:
            Tuple[Dict[str, Any], Dict[str, str]]: (response body, response headers).

        Raises:
            HTTPException: 503 when no API key is available, the upstream
                status when initialization fails, 500 on any other error.
        """
        try:
            key_manager = await self._get_key_manager()
            api_key = await key_manager.get_next_key()
            if not api_key:
                raise HTTPException(status_code=503, detail="No available API keys")

            forward_headers = {
                "X-Goog-Upload-Protocol": headers.get("x-goog-upload-protocol", "resumable"),
                "X-Goog-Upload-Command": headers.get("x-goog-upload-command", "start"),
                "Content-Type": headers.get("content-type", "application/json"),
            }
            # Optional headers announcing the size and type of the file itself.
            optional_headers = (
                ("x-goog-upload-header-content-length", "X-Goog-Upload-Header-Content-Length"),
                ("x-goog-upload-header-content-type", "X-Goog-Upload-Header-Content-Type"),
            )
            for incoming_name, forwarded_name in optional_headers:
                if incoming_name in headers:
                    forward_headers[forwarded_name] = headers[incoming_name]

            async with AsyncClient() as client:
                response = await client.post(
                    self._UPLOAD_ENDPOINT,
                    headers=forward_headers,
                    content=body,
                    params={"key": api_key}
                )

            if response.status_code != 200:
                logger.error(f"Upload initialization failed: {response.status_code} - {response.text}")
                raise HTTPException(status_code=response.status_code, detail="Upload initialization failed")

            upload_url = response.headers.get("x-goog-upload-url")
            if not upload_url:
                raise HTTPException(status_code=500, detail="No upload URL in response")

            # The upstream URL embeds the real API key, so redact it in logs.
            logger.info(f"Original upload URL from Google: {self._redact_url_key(upload_url)}")

            # The database record is created only once the upload completes.
            logger.info(f"Upload initialized with API key: {api_key[:8]}...{api_key[-4:]}")

            # The initialization response body may be empty, so nothing is parsed.
            response_data = {}

            # Best effort: pick up the display name from a JSON request body.
            display_name = ""
            if body:
                try:
                    request_data = json.loads(body)
                    display_name = request_data.get("displayName", "")
                except Exception as e:
                    logger.debug(f"Could not read displayName from request body: {str(e)}")

            upload_id = self._extract_upload_id(upload_url)
            if upload_id:
                async with _upload_sessions_lock:
                    _upload_sessions[upload_id] = {
                        "api_key": api_key,
                        "user_token": user_token,
                        "display_name": display_name,
                        "mime_type": headers.get("x-goog-upload-header-content-type", "application/octet-stream"),
                        "size_bytes": int(headers.get("x-goog-upload-header-content-length", "0")),
                        "created_at": datetime.now(timezone.utc),
                        "upload_url": upload_url
                    }
                    logger.info(f"Stored upload session for upload_id={upload_id}: api_key={api_key[:8]}...{api_key[-4:]}")
                    logger.debug(f"Total active sessions: {len(_upload_sessions)}")
            else:
                logger.warning(f"No upload_id found in upload URL: {self._redact_url_key(upload_url)}")

            # Drop stale sessions. Awaited directly (the lock is free here): a
            # fire-and-forget task without a kept reference may be garbage
            # collected before it finishes.
            await self._cleanup_expired_sessions()

            # Point the client at this proxy and hide the real API key:
            #   https://generativelanguage.googleapis.com/upload/v1beta/files?key=<api key>&upload_id=...
            #   -> <request_host>/upload/v1beta/files?key=<user token>&upload_id=...
            proxy_upload_url = upload_url
            if request_host:
                import re

                proxy_upload_url = upload_url.replace(
                    self._UPSTREAM_ORIGIN,
                    request_host.rstrip('/')
                )
                # A callable replacement keeps the token literal even if it
                # contains backslashes.
                proxy_upload_url = re.sub(
                    r'(\?|&)key=([^&]+)',
                    lambda match: f"{match.group(1)}key={user_token}",
                    proxy_upload_url,
                    count=1
                )
                logger.info(
                    f"Replaced upload URL: {self._redact_url_key(upload_url)} -> {self._redact_url_key(proxy_upload_url)}"
                )

            return response_data, {
                "X-Goog-Upload-URL": proxy_upload_url,
                "X-Goog-Upload-Status": "active"
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to initialize upload: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def _cleanup_expired_sessions(self):
        """Remove upload sessions older than one hour; errors are only logged."""
        try:
            async with _upload_sessions_lock:
                now = datetime.now(timezone.utc)
                expired_keys = [
                    key
                    for key, session in _upload_sessions.items()
                    if now - session["created_at"] > self._SESSION_TTL
                ]
                for key in expired_keys:
                    del _upload_sessions[key]

                if expired_keys:
                    logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions")
        except Exception as e:
            logger.error(f"Error cleaning up upload sessions: {str(e)}")

    async def get_upload_session(self, key: str) -> Optional[Dict[str, Any]]:
        """Look up an upload session by ``upload_id`` or by a full upload URL.

        Args:
            key: Either the upload id itself or a URL containing an
                ``upload_id`` query parameter.

        Returns:
            The stored session dict, or None when no session matches.
        """
        async with _upload_sessions_lock:
            # Direct hit on the key as given.
            session = _upload_sessions.get(key)
            if session:
                logger.debug(f"Found session by direct key {key}")
                return session

            # Otherwise, if it is a URL, try the upload_id it carries.
            if key.startswith("http"):
                upload_id = self._extract_upload_id(key)
                if upload_id:
                    session = _upload_sessions.get(upload_id)
                    if session:
                        logger.debug(f"Found session by upload_id {upload_id} from URL")
                        return session

            logger.debug(f"No session found for key: {key}")
            return None

    async def get_file(self, file_name: str, user_token: str) -> FileMetadata:
        """Fetch a file's metadata from the upstream and refresh its stored state.

        Args:
            file_name: File name in the form ``files/{file_id}``.
            user_token: Token of the calling user (not used for the lookup).

        Returns:
            FileMetadata: Metadata as reported by the upstream.

        Raises:
            HTTPException: 404 when the record is missing or expired, the
                upstream status when the fetch fails, 500 on any other error.
        """
        try:
            file_record = await db_services.get_file_record_by_name(file_name)
            if not file_record:
                raise HTTPException(status_code=404, detail="File not found")

            if self._is_expired(file_record):
                raise HTTPException(status_code=404, detail="File has expired")

            # The file is only visible to the API key that uploaded it.
            api_key = file_record["api_key"]

            async with AsyncClient() as client:
                response = await client.get(
                    f"{settings.BASE_URL}/{file_name}",
                    params={"key": api_key}
                )

            if response.status_code != 200:
                logger.error(f"Failed to get file: {response.status_code} - {response.text}")
                raise HTTPException(status_code=response.status_code, detail="Failed to get file")

            file_data = response.json()

            # Compare against the stored state's name; a record without a
            # state counts as different so that it gets updated.
            google_state = file_data.get("state", "PROCESSING")
            stored_state = file_record.get("state")
            stored_state_name = getattr(stored_state, "value", stored_state) if stored_state else None
            if google_state != stored_state_name:
                logger.info(f"File state changed from {stored_state} to {google_state}")
                await self._sync_record_state(file_name, google_state)

            return FileMetadata(
                name=file_data["name"],
                displayName=file_data.get("displayName"),
                mimeType=file_data["mimeType"],
                sizeBytes=str(file_data["sizeBytes"]),
                createTime=file_data["createTime"],
                updateTime=file_data["updateTime"],
                expirationTime=file_data["expirationTime"],
                sha256Hash=file_data.get("sha256Hash"),
                uri=file_data["uri"],
                state=google_state
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to get file {file_name}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def list_files(
        self,
        page_size: int = 10,
        page_token: Optional[str] = None,
        user_token: Optional[str] = None
    ) -> ListFilesResponse:
        """List file records from the local database.

        Args:
            page_size: Page size passed to the database query.
            page_token: Pagination token passed to the database query.
            user_token: Optional user token passed to the database query as a filter.

        Returns:
            ListFilesResponse: The files of this page and the next page token.

        Raises:
            HTTPException: 500 on any error.
        """
        try:
            logger.debug(f"list_files called with page_size={page_size}, page_token={page_token}")

            files, next_page_token = await db_services.list_file_records(
                user_token=user_token,
                page_size=page_size,
                page_token=page_token
            )

            logger.debug(f"Database returned {len(files)} files, next_page_token={next_page_token}")

            file_list = [
                FileMetadata(
                    name=file_record["name"],
                    displayName=file_record.get("display_name"),
                    mimeType=file_record["mime_type"],
                    sizeBytes=str(file_record["size_bytes"]),
                    createTime=self._utc_timestamp(file_record["create_time"]),
                    updateTime=self._utc_timestamp(file_record["update_time"]),
                    expirationTime=self._utc_timestamp(file_record["expiration_time"]),
                    sha256Hash=file_record.get("sha256_hash"),
                    uri=file_record["uri"],
                    state=file_record["state"].value if file_record.get("state") else "ACTIVE"
                )
                for file_record in files
            ]

            response = ListFilesResponse(
                files=file_list,
                nextPageToken=next_page_token
            )

            logger.debug(f"Returning response with {len(response.files)} files, nextPageToken={response.nextPageToken}")

            return response

        except Exception as e:
            logger.error(f"Failed to list files: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def delete_file(self, file_name: str, user_token: str) -> bool:
        """Delete a file upstream and remove its database record.

        Args:
            file_name: Name of the file.
            user_token: Token of the calling user (not used for the lookup).

        Returns:
            bool: True when the record was removed.

        Raises:
            HTTPException: 404 when the record is missing, the upstream status
                when the upstream delete fails for a file that has not yet
                expired, 500 on any other error.
        """
        try:
            file_record = await db_services.get_file_record_by_name(file_name)
            if not file_record:
                raise HTTPException(status_code=404, detail="File not found")

            # The file must be deleted with the API key that uploaded it.
            api_key = file_record["api_key"]

            async with AsyncClient() as client:
                response = await client.delete(
                    f"{settings.BASE_URL}/{file_name}",
                    params={"key": api_key}
                )

            if response.status_code not in [200, 204]:
                logger.error(f"Failed to delete file: {response.status_code} - {response.text}")
                # An expired file is removed locally even when the upstream
                # delete failed.
                if self._is_expired(file_record):
                    await db_services.delete_file_record(file_name)
                    return True
                raise HTTPException(status_code=response.status_code, detail="Failed to delete file")

            await db_services.delete_file_record(file_name)
            return True

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to delete file {file_name}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

    async def check_file_state(self, file_name: str, api_key: str) -> str:
        """Query the upstream state of a file and store ACTIVE/FAILED locally.

        Args:
            file_name: Name of the file.
            api_key: API key used for the upstream request.

        Returns:
            str: The upstream state, or ``"UNKNOWN"`` when it cannot be determined.
        """
        try:
            async with AsyncClient() as client:
                response = await client.get(
                    f"{settings.BASE_URL}/{file_name}",
                    params={"key": api_key}
                )

            if response.status_code != 200:
                logger.error(f"Failed to check file state: {response.status_code}")
                return "UNKNOWN"

            google_state = response.json().get("state", "PROCESSING")
            await self._sync_record_state(file_name, google_state)
            return google_state

        except Exception as e:
            logger.error(f"Failed to check file state: {str(e)}")
            return "UNKNOWN"

    async def cleanup_expired_files(self) -> int:
        """Remove expired file records and try to delete the files upstream.

        Returns:
            int: Number of expired records removed (0 on error).
        """
        try:
            expired_files = await db_services.delete_expired_file_records()
            if not expired_files:
                return 0

            # One client serves all deletions; a failure for one file is
            # logged and does not stop the remaining ones.
            async with AsyncClient() as client:
                for file_record in expired_files:
                    try:
                        api_key = file_record["api_key"]
                        file_name = file_record["name"]
                        await client.delete(
                            f"{settings.BASE_URL}/{file_name}",
                            params={"key": api_key}
                        )
                    except Exception as e:
                        logger.error(f"Failed to delete file {file_record['name']} from API: {str(e)}")

            return len(expired_files)

        except Exception as e:
            logger.error(f"Failed to cleanup expired files: {str(e)}")
            return 0
|
||||
|
||||
|
||||
# 单例实例
|
||||
# Module-level FilesService shared by all callers; stays None until
# get_files_service() creates it.
_files_service_instance: Optional[FilesService] = None
|
||||
|
||||
|
||||
async def get_files_service() -> FilesService:
    """Return the shared FilesService, constructing it on the first call."""
    global _files_service_instance
    service = _files_service_instance
    if service is None:
        service = FilesService()
        _files_service_instance = service
    return service
|
||||
@@ -121,6 +121,7 @@ class ImageCreateService:
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from itertools import cycle
|
||||
from typing import Dict
|
||||
from typing import Dict, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_key_manager_logger
|
||||
@@ -34,7 +34,7 @@ class KeyManager:
|
||||
return next(self.key_cycle)
|
||||
|
||||
async def get_next_vertex_key(self) -> str:
|
||||
"""获取下一个 Vertex API key"""
|
||||
"""获取下一个 Vertex Express API key"""
|
||||
async with self.vertex_key_cycle_lock:
|
||||
return next(self.vertex_key_cycle)
|
||||
|
||||
@@ -98,7 +98,7 @@ class KeyManager:
|
||||
return current_key
|
||||
|
||||
async def get_next_working_vertex_key(self) -> str:
|
||||
"""获取下一可用的 Vertex API key"""
|
||||
"""获取下一可用的 Vertex Express API key"""
|
||||
initial_key = await self.get_next_vertex_key()
|
||||
current_key = initial_key
|
||||
|
||||
@@ -124,12 +124,12 @@ class KeyManager:
|
||||
return ""
|
||||
|
||||
async def handle_vertex_api_failure(self, api_key: str, retries: int) -> str:
|
||||
"""处理 Vertex API 调用失败"""
|
||||
"""处理 Vertex Express API 调用失败"""
|
||||
async with self.vertex_failure_count_lock:
|
||||
self.vertex_key_failure_counts[api_key] += 1
|
||||
if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"Vertex API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
f"Vertex Express API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
|
||||
def get_fail_count(self, key: str) -> int:
|
||||
@@ -156,7 +156,7 @@ class KeyManager:
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||
|
||||
async def get_vertex_keys_by_status(self) -> dict:
|
||||
"""获取分类后的 Vertex API key 列表,包括失败次数"""
|
||||
"""获取分类后的 Vertex Express API key 列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
invalid_keys = {}
|
||||
|
||||
@@ -185,12 +185,12 @@ class KeyManager:
|
||||
|
||||
_singleton_instance = None
|
||||
_singleton_lock = asyncio.Lock()
|
||||
_preserved_failure_counts: Dict[str, int] | None = None
|
||||
_preserved_vertex_failure_counts: Dict[str, int] | None = None
|
||||
_preserved_old_api_keys_for_reset: list | None = None
|
||||
_preserved_vertex_old_api_keys_for_reset: list | None = None
|
||||
_preserved_next_key_in_cycle: str | None = None
|
||||
_preserved_vertex_next_key_in_cycle: str | None = None
|
||||
_preserved_failure_counts: Union[Dict[str, int], None] = None
|
||||
_preserved_vertex_failure_counts: Union[Dict[str, int], None] = None
|
||||
_preserved_old_api_keys_for_reset: Union[list, None] = None
|
||||
_preserved_vertex_old_api_keys_for_reset: Union[list, None] = None
|
||||
_preserved_next_key_in_cycle: Union[str, None] = None
|
||||
_preserved_vertex_next_key_in_cycle: Union[str, None] = None
|
||||
|
||||
|
||||
async def get_key_manager_instance(
|
||||
@@ -213,7 +213,7 @@ async def get_key_manager_instance(
|
||||
)
|
||||
if vertex_api_keys is None:
|
||||
raise ValueError(
|
||||
"Vertex API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
"Vertex Express API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
)
|
||||
|
||||
if not api_keys:
|
||||
@@ -222,12 +222,12 @@ async def get_key_manager_instance(
|
||||
)
|
||||
if not vertex_api_keys:
|
||||
logger.warning(
|
||||
"Initializing KeyManager with an empty list of Vertex API keys."
|
||||
"Initializing KeyManager with an empty list of Vertex Express API keys."
|
||||
)
|
||||
|
||||
_singleton_instance = KeyManager(api_keys, vertex_api_keys)
|
||||
logger.info(
|
||||
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex API keys."
|
||||
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex Express API keys."
|
||||
)
|
||||
|
||||
# 1. 恢复失败计数
|
||||
@@ -349,7 +349,7 @@ async def get_key_manager_instance(
|
||||
break
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex API keys. "
|
||||
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex Express API keys. "
|
||||
"New cycle will start from the beginning of the new list."
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -357,7 +357,7 @@ async def get_key_manager_instance(
|
||||
f"Error determining start key for new Vertex key cycle from preserved state: {e}. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
|
||||
|
||||
if start_key_for_new_vertex_cycle and _singleton_instance.vertex_api_keys:
|
||||
try:
|
||||
target_idx = _singleton_instance.vertex_api_keys.index(
|
||||
@@ -370,25 +370,25 @@ async def get_key_manager_instance(
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex API keys during cycle advancement. "
|
||||
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex Express API keys during cycle advancement. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
except StopIteration:
|
||||
logger.error(
|
||||
"StopIteration while advancing Vertex key cycle, implies empty new Vertex API key list previously missed."
|
||||
"StopIteration while advancing Vertex key cycle, implies empty new Vertex Express API key list previously missed."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error advancing new Vertex key cycle: {e}. Cycle will start from beginning."
|
||||
)
|
||||
)
|
||||
else:
|
||||
if _singleton_instance.vertex_api_keys:
|
||||
logger.info(
|
||||
"New Vertex key cycle will start from the beginning of the new Vertex API key list (no specific start key determined or needed)."
|
||||
"New Vertex key cycle will start from the beginning of the new Vertex Express API key list (no specific start key determined or needed)."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"New Vertex key cycle not applicable as the new Vertex API key list is empty."
|
||||
"New Vertex key cycle not applicable as the new Vertex Express API key list is empty."
|
||||
)
|
||||
|
||||
# 清理所有保存的状态
|
||||
@@ -409,16 +409,20 @@ async def reset_key_manager_instance():
|
||||
if _singleton_instance:
|
||||
# 1. 保存失败计数
|
||||
_preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
|
||||
_preserved_vertex_failure_counts = _singleton_instance.vertex_key_failure_counts.copy()
|
||||
_preserved_vertex_failure_counts = (
|
||||
_singleton_instance.vertex_key_failure_counts.copy()
|
||||
)
|
||||
|
||||
# 2. 保存旧的 API keys 列表
|
||||
_preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
|
||||
_preserved_vertex_old_api_keys_for_reset = _singleton_instance.vertex_api_keys.copy()
|
||||
_preserved_vertex_old_api_keys_for_reset = (
|
||||
_singleton_instance.vertex_api_keys.copy()
|
||||
)
|
||||
|
||||
# 3. 保存 key_cycle 的下一个 key 提示
|
||||
try:
|
||||
if _singleton_instance.api_keys:
|
||||
_preserved_next_key_in_cycle = (
|
||||
_preserved_next_key_in_cycle = (
|
||||
await _singleton_instance.get_next_key()
|
||||
)
|
||||
else:
|
||||
@@ -427,7 +431,7 @@ async def reset_key_manager_instance():
|
||||
logger.warning(
|
||||
"Could not preserve next key hint: key cycle was empty or exhausted in old instance."
|
||||
)
|
||||
_preserved_next_key_in_cycle = None
|
||||
_preserved_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_next_key_in_cycle = None
|
||||
@@ -443,12 +447,11 @@ async def reset_key_manager_instance():
|
||||
except StopIteration:
|
||||
logger.warning(
|
||||
"Could not preserve next key hint: Vertex key cycle was empty or exhausted in old instance."
|
||||
)
|
||||
)
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
|
||||
|
||||
_singleton_instance = None
|
||||
logger.info(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# app/service/stats_service.py
|
||||
|
||||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
|
||||
@@ -195,10 +196,11 @@ class StatsService:
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get API call details for period '{period}': {e}")
|
||||
logger.error(
|
||||
f"Failed to get API call details for period '{period}': {e}")
|
||||
raise
|
||||
|
||||
async def get_key_usage_details_last_24h(self, key: str) -> dict | None:
|
||||
async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]:
|
||||
"""
|
||||
获取指定 API 密钥在过去 24 小时内按模型统计的调用次数。
|
||||
|
||||
@@ -218,7 +220,8 @@ class StatsService:
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.model_name, func.count(RequestLog.id).label("call_count")
|
||||
RequestLog.model_name, func.count(
|
||||
RequestLog.id).label("call_count")
|
||||
)
|
||||
.where(
|
||||
RequestLog.api_key == key,
|
||||
@@ -237,7 +240,8 @@ class StatsService:
|
||||
)
|
||||
return {}
|
||||
|
||||
usage_details = {row["model_name"]: row["call_count"] for row in results}
|
||||
usage_details = {row["model_name"]: row["call_count"]
|
||||
for row in results}
|
||||
logger.info(
|
||||
f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
|
||||
)
|
||||
|
||||
363
app/service/tts/native/README.md
Normal file
363
app/service/tts/native/README.md
Normal file
@@ -0,0 +1,363 @@
|
||||
# 原生Gemini TTS功能
|
||||
|
||||
这个模块为Gemini Balance项目添加了原生Gemini TTS(Text-to-Speech)功能,支持单人和多人语音合成,采用智能检测和继承模式设计,保持与原始代码的完全兼容性。
|
||||
|
||||
## 🎯 设计原则
|
||||
|
||||
- **智能检测**:自动检测所有原生Gemini TTS格式的请求(包含responseModalities和speechConfig)
|
||||
- **继承而非修改**:所有扩展都继承自原始类,不修改源码
|
||||
- **完全兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **动态模型选择**:支持用户在请求URL中指定不同的TTS模型
|
||||
- **自动回退**:原生TTS处理失败时自动回退到标准服务
|
||||
- **完整日志记录**:包含请求日志、错误日志和性能监控
|
||||
- **易于维护**:更新原始代码时不会产生冲突
|
||||
|
||||
## 📁 文件结构
|
||||
|
||||
```
|
||||
app/service/tts/
|
||||
├── tts_service.py # 原有的OpenAI兼容TTS服务
|
||||
└── native/ # 原生Gemini TTS扩展
|
||||
├── __init__.py # 模块初始化
|
||||
├── README.md # 使用说明(本文件)
|
||||
├── tts_models.py # TTS数据模型(继承自原始模型)
|
||||
├── tts_response_handler.py # TTS响应处理器(继承自原始处理器)
|
||||
├── tts_chat_service.py # TTS聊天服务(继承自原始服务)
|
||||
└── tts_routes.py # TTS路由扩展和依赖注入
|
||||
```
|
||||
|
||||
## 🚀 原生Gemini TTS功能
|
||||
|
||||
### 智能检测机制(当前实现)
|
||||
|
||||
原生Gemini TTS功能通过智能检测自动启用,无需任何配置:
|
||||
|
||||
1. **自动启用**:
|
||||
```bash
|
||||
# 直接启动服务,原生TTS功能自动可用
|
||||
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
2. **无需配置**:
|
||||
- 不需要环境变量
|
||||
- 不需要修改配置文件
|
||||
- 完全基于请求内容智能判断
|
||||
|
||||
### 工作原理
|
||||
|
||||
系统会智能检测请求内容:
|
||||
- **原生TTS请求**:包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` → 使用TTS增强服务
|
||||
- **单人TTS**:包含 `voiceConfig.prebuiltVoiceConfig`
|
||||
- **多人TTS**:包含 `multiSpeakerVoiceConfig`
|
||||
- **普通请求**:非TTS模型,或未同时包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` 的请求 → 使用原有Gemini聊天服务
|
||||
|
||||
```python
|
||||
# app/router/gemini_routes.py 中的智能检测逻辑
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
# 使用TTS增强服务
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
return await tts_service.generate_content(...)
|
||||
# 否则使用原有服务
|
||||
```
|
||||
|
||||
## 📝 使用示例
|
||||
|
||||
### 1. 原生Gemini单人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `voiceConfig.prebuiltVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Hello, this is a single speaker test."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. 原生Gemini多人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `multiSpeakerVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Alice: Hello everyone, welcome to our show today.\nBob: Hi Alice, and hello to all our listeners! Today we are talking about AI development."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "Alice",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Puck"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"speaker": "Bob",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. OpenAI兼容TTS请求(使用原有服务)
|
||||
|
||||
OpenAI兼容格式的TTS请求使用不同的API路径,不受本模块影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1/audio/speech" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-token" \
|
||||
-d '{
|
||||
"model": "tts-1",
|
||||
"input": "这是一个OpenAI兼容格式的TTS测试。",
|
||||
"voice": "alloy"
|
||||
}' \
|
||||
--output openai_tts_test.wav
|
||||
```
|
||||
|
||||
**注意**:OpenAI兼容TTS请求:
|
||||
- 使用路径:`/v1/audio/speech`
|
||||
- 使用Authorization头而不是x-goog-api-key
|
||||
- 返回音频文件而不是JSON响应
|
||||
- 不受本模块的TTS增强服务影响
|
||||
|
||||
### 4. 普通文本生成(使用原有服务)
|
||||
|
||||
非TTS模型的请求会使用原有的Gemini聊天服务,完全不受影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "请简单介绍一下人工智能的发展历程。"
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": 200,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## 🔧 技术实现
|
||||
|
||||
### 继承关系
|
||||
|
||||
```
|
||||
GeminiChatService
|
||||
↓ (继承)
|
||||
TTSGeminiChatService
|
||||
├── 重写 generate_content() 方法
|
||||
├── 添加 _handle_tts_request() 方法
|
||||
└── 集成完整的日志记录功能
|
||||
|
||||
GeminiResponseHandler
|
||||
↓ (继承)
|
||||
TTSResponseHandler
|
||||
└── 重写 handle_response() 方法
|
||||
|
||||
GenerationConfig (Pydantic模型)
|
||||
↓ (扩展)
|
||||
TTSGenerationConfig
|
||||
├── responseModalities: List[str]
|
||||
└── speechConfig: Dict[str, Any]
|
||||
```
|
||||
|
||||
### 工作流程
|
||||
|
||||
1. **请求接收**:系统接收到API请求
|
||||
2. **智能检测**:
|
||||
- 检查模型名称是否包含 "tts"
|
||||
- 如果是TTS模型,从 `request.generationConfig` 检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig`
|
||||
3. **服务选择**:
|
||||
- **原生TTS请求**:使用 `TTSGeminiChatService` 增强服务
|
||||
- **普通请求**:使用原有 `GeminiChatService`
|
||||
4. **请求处理**:
|
||||
- **原生TTS**:使用 `_handle_tts_request()` 特殊处理
|
||||
- **其他请求**:使用标准 `generate_content()` 方法
|
||||
5. **字段处理**:从 `request.generationConfig` 直接获取TTS字段(`responseModalities`, `speechConfig`)
|
||||
6. **API调用**:构建优化的payload并调用Gemini API
|
||||
7. **自动回退**:如果原生TTS处理失败,自动回退到标准服务
|
||||
8. **响应处理**:
|
||||
- **TTS响应**:检测音频数据,直接返回原始响应
|
||||
- **普通响应**:使用标准处理方法
|
||||
9. **日志记录**:记录请求时间、成功状态、错误信息到数据库
|
||||
|
||||
## 📊 功能特性
|
||||
|
||||
### ✅ 已实现功能
|
||||
|
||||
- **智能原生TTS支持**:支持单人和多人语音合成
|
||||
- **单人TTS**:支持 `voiceConfig.prebuiltVoiceConfig` 配置
|
||||
- **多人TTS**:支持 `multiSpeakerVoiceConfig` 配置
|
||||
- **智能检测机制**:自动检测所有原生Gemini TTS格式的请求
|
||||
- **动态模型选择**:支持用户在URL中指定不同TTS模型
|
||||
- **完全向后兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **自动回退机制**:原生TTS处理失败时自动使用标准服务
|
||||
- **完整日志记录**:请求日志、错误日志、性能监控
|
||||
- **API配额管理**:自动重试和密钥轮换
|
||||
- **零配置启用**:无需环境变量或配置文件修改
|
||||
- **错误处理**:完整的异常捕获和错误记录
|
||||
|
||||
### 🎵 支持的语音配置
|
||||
|
||||
#### 单人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 多人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "角色名称",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### API要求
|
||||
- 确保API密钥有TTS权限
|
||||
- TTS功能需要使用TTS模型(模型名称中包含 "tts",例如 `gemini-2.5-flash-preview-tts`)
|
||||
- 注意API配额限制(免费版每天15次)
|
||||
|
||||
### 性能考虑
|
||||
- TTS响应通常比文本响应更大(音频数据)
|
||||
- 建议监控API调用频率和成功率
|
||||
- 扩展功能不影响原始功能的性能和稳定性
|
||||
|
||||
### 部署建议
|
||||
- 生产环境建议先测试普通功能
|
||||
- 逐步启用TTS功能并监控日志
|
||||
- 定期检查API配额使用情况
|
||||
|
||||
## 📈 监控和调试
|
||||
|
||||
### 日志查看
|
||||
- **服务器日志**:查看TTS请求处理过程
|
||||
- **管理界面**:在"API 调用详情"中查看请求记录
|
||||
- **错误日志**:查看失败请求的详细信息
|
||||
|
||||
### 调试技巧
|
||||
```bash
|
||||
# 启用详细日志
|
||||
export LOG_LEVEL=DEBUG
|
||||
|
||||
# 查看实时日志
|
||||
tail -f logs/app.log
|
||||
|
||||
# 多人TTS功能无需配置,自动启用
|
||||
# 可通过请求内容智能检测
|
||||
```
|
||||
|
||||
## 🔄 TTS系统对比
|
||||
|
||||
项目中现在有三套TTS系统,各自服务不同的用途:
|
||||
|
||||
| TTS类型 | 路径 | 模型选择 | 语音配置 | 使用场景 | 我们的影响 |
|
||||
|---------|------|----------|----------|----------|------------|
|
||||
| **OpenAI兼容TTS** | `/v1/audio/speech` | 固定配置文件 | 单人语音 | OpenAI API兼容 | ✅ 无影响 |
|
||||
| **Gemini单人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 单人语音 | 原生Gemini TTS | ✅ 我们的增强 |
|
||||
| **Gemini多人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 多人语音 | 对话场景 | ✅ 我们的增强 |
|
||||
|
||||
### 智能路由机制
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[API请求] --> B{路径检查}
|
||||
B -->|/v1/audio/speech| C[OpenAI兼容TTS服务]
|
||||
B -->|/v1beta/models/{model}:generateContent| D{模型名包含'tts'?}
|
||||
D -->|否| E[标准Gemini聊天服务]
|
||||
D -->|是| F{包含responseModalities和speechConfig?}
|
||||
F -->|否| G[标准Gemini聊天服务]
|
||||
F -->|是| H[原生TTS增强服务]
|
||||
H --> I{处理成功?}
|
||||
I -->|是| J[返回原生TTS响应]
|
||||
I -->|否| K[自动回退到标准服务]
|
||||
C --> L[完成]
|
||||
E --> L
|
||||
G --> L
|
||||
J --> L
|
||||
K --> L
|
||||
```
|
||||
|
||||
## 🎉 成功案例
|
||||
|
||||
基于智能检测的原生Gemini TTS解决方案已经成功实现:
|
||||
|
||||
- ✅ **零配置启用**:无需任何环境变量或配置修改
|
||||
- ✅ **智能检测**:自动检测所有原生Gemini TTS格式的请求
|
||||
- ✅ **完全向后兼容**:所有原有TTS功能零影响
|
||||
- ✅ **动态模型选择**:支持用户指定不同TTS模型
|
||||
- ✅ **自动回退机制**:处理失败时自动使用标准服务
|
||||
- ✅ **单人和多人语音合成**:支持所有原生Gemini TTS场景
|
||||
- ✅ **完整日志记录**:可在管理界面查看所有请求
|
||||
- ✅ **错误处理完善**:API配额和重试机制
|
||||
- ✅ **易于维护**:更新原始代码无冲突
|
||||
|
||||
这个实现展示了如何在不修改原始代码的情况下,优雅地扩展复杂系统的功能,同时保持完美的向后兼容性。
|
||||
19
app/service/tts/native/__init__.py
Normal file
19
app/service/tts/native/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
原生Gemini TTS功能模块
|
||||
Native Gemini TTS functionality for both single and multi-speaker scenarios
|
||||
"""
|
||||
|
||||
from .tts_chat_service import TTSGeminiChatService
|
||||
from .tts_models import TTSGenerationConfig, MultiSpeakerVoiceConfig, SpeechConfig, TTSRequest
|
||||
from .tts_response_handler import TTSResponseHandler
|
||||
from .tts_routes import get_tts_chat_service
|
||||
|
||||
__all__ = [
|
||||
"TTSGeminiChatService",
|
||||
"TTSGenerationConfig",
|
||||
"MultiSpeakerVoiceConfig",
|
||||
"SpeechConfig",
|
||||
"TTSRequest",
|
||||
"TTSResponseHandler",
|
||||
"get_tts_chat_service"
|
||||
]
|
||||
151
app/service/tts/native/tts_chat_service.py
Normal file
151
app/service/tts/native/tts_chat_service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
原生Gemini TTS聊天服务扩展
|
||||
继承自原始聊天服务,添加原生Gemini TTS支持(单人和多人),保持向后兼容
|
||||
"""
|
||||
|
||||
import time
|
||||
import datetime
|
||||
from typing import Any, Dict
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.database.services import add_request_log, add_error_log
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSGeminiChatService(GeminiChatService):
    """
    Gemini chat service with native TTS support.

    Subclasses GeminiChatService. Requests whose model name contains "tts"
    take a TTS-specific path; every other request is delegated to the parent.
    """

    def __init__(self, base_url: str, key_manager):
        """Initialise the parent service, then install the TTS response handler."""
        super().__init__(base_url, key_manager)
        # Replace the original response handler with the TTS-aware one.
        self.response_handler = TTSResponseHandler()
        logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")

    async def generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """
        Generate content, routing TTS models through the TTS-specific path.

        Any exception is logged and re-raised unchanged.
        """
        try:
            # Debug logging of the incoming request.
            logger.info(f"TTS request model: {model}")
            logger.info(f"TTS request generationConfig: {request.generationConfig}")

            # Non-TTS models are handled entirely by the parent implementation.
            if "tts" not in model.lower():
                return await super().generate_content(model, request, api_key)

            logger.info("Detected TTS model, applying TTS-specific processing")
            return await self._handle_tts_request(model, request, api_key)
        except Exception as e:
            logger.error(f"TTS API call failed with error: {e}")
            raise

    async def _handle_tts_request(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
        """
        Send a TTS request and record request/error logs for it.

        Builds a reduced payload (contents, generationConfig and, when present,
        systemInstruction; no tools or safetySettings), calls the API client and
        passes the result to the response handler. On failure an error log entry
        is written and the exception is re-raised; a request log entry is
        written in every case.
        """
        # Timing and bookkeeping used by the request log in ``finally``.
        started_at = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            from app.service.chat.gemini_chat_service import _filter_empty_parts

            request_dict = request.model_dump(exclude_none=False)

            # Simplified TTS payload.
            payload = {
                "contents": _filter_empty_parts(request_dict.get("contents", [])),
                "generationConfig": request_dict.get("generationConfig", {}),
            }

            # systemInstruction is only forwarded when the request carries one.
            system_instruction = request_dict.get("systemInstruction")
            if system_instruction:
                payload["systemInstruction"] = system_instruction

            # The dumped value may be an explicit None; normalise it to a dict.
            if payload["generationConfig"] is None:
                payload["generationConfig"] = {}

            # Copy the TTS fields straight from the parsed request object.
            generation_config = request.generationConfig
            if not generation_config:
                logger.warning("No generationConfig found in request, TTS fields may be missing")
            else:
                modalities = generation_config.responseModalities
                if modalities:
                    payload["generationConfig"]["responseModalities"] = modalities
                    logger.info(f"Added responseModalities: {modalities}")

                speech_config = generation_config.speechConfig
                if speech_config:
                    payload["generationConfig"]["speechConfig"] = speech_config
                    logger.info(f"Added speechConfig: {speech_config}")

            logger.info(f"TTS payload before API call: {payload}")

            response = await self.api_client.generate_content(payload, model, api_key)

            # Reaching this point means the API call did not raise.
            is_success = True
            status_code = 200

            return self.response_handler.handle_response(response, model, False, None)

        except Exception as e:
            is_success = False
            error_msg = str(e)

            # Recover the HTTP status from the error text when it is present;
            # otherwise fall back to 500.
            import re

            match = re.search(r"status code (\d+)", error_msg)
            status_code = int(match.group(1)) if match else 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="tts-api-error",
                error_log=error_msg,
                error_code=status_code,
                request_msg=request.model_dump(exclude_none=False)
            )

            logger.error(f"TTS API call failed: {error_msg}")
            raise

        finally:
            # Always record the request, successful or not.
            latency_ms = int((time.perf_counter() - started_at) * 1000)

            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )
|
||||
37
app/service/tts/native/tts_config.py
Normal file
37
app/service/tts/native/tts_config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
TTS扩展配置
|
||||
控制是否启用TTS功能
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
class TTSConfig:
    """Configuration helpers for the TTS extension."""

    @staticmethod
    def is_tts_enabled() -> bool:
        """
        Tell whether TTS is enabled through the ENABLE_TTS environment variable.

        The variable defaults to "false"; the values "true", "1", "yes" and
        "on" (compared case-insensitively) enable the feature.
        """
        flag = os.getenv("ENABLE_TTS", "false")
        return flag.lower() in ("true", "1", "yes", "on")

    @staticmethod
    def get_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
        """Factory: build the TTS chat service when TTS is enabled, else the plain one."""
        service_cls = TTSGeminiChatService if TTSConfig.is_tts_enabled() else GeminiChatService
        return service_cls(base_url, key_manager)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
    """Convenience wrapper that builds a chat service via TTSConfig.get_chat_service."""
    chat_service = TTSConfig.get_chat_service(base_url, key_manager)
    return chat_service
|
||||
36
app/service/tts/native/tts_models.py
Normal file
36
app/service/tts/native/tts_models.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
原生Gemini TTS扩展数据模型
|
||||
继承自原始模型,添加原生Gemini TTS相关字段,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.domain.gemini_models import GenerationConfig as BaseGenerationConfig
|
||||
|
||||
|
||||
class TTSGenerationConfig(BaseGenerationConfig):
    """
    Generation config with TTS support.

    Extends the project's base GenerationConfig with the two optional
    TTS-related fields below.
    """

    # Requested response modalities; optional, defaults to None.
    responseModalities: Optional[List[str]] = None
    # Free-form speech configuration mapping; optional, defaults to None.
    speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MultiSpeakerVoiceConfig(BaseModel):
    """Voice configuration for multi-speaker synthesis."""

    # Required list of per-speaker voice configuration mappings.
    speakerVoiceConfigs: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SpeechConfig(BaseModel):
    """Speech configuration; both fields are optional."""

    # Multi-speaker voice settings, when provided.
    multiSpeakerVoiceConfig: Optional[MultiSpeakerVoiceConfig] = None
    # Free-form voice configuration mapping, when provided.
    voiceConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
    """Request model for a TTS call."""

    # Required list of content mappings.
    contents: List[Dict[str, Any]]
    # Required generation config, including the TTS-specific fields.
    generationConfig: TTSGenerationConfig
|
||||
53
app/service/tts/native/tts_response_handler.py
Normal file
53
app/service/tts/native/tts_response_handler.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
原生Gemini TTS响应处理器扩展
|
||||
继承自原始响应处理器,添加原生Gemini TTS支持,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.log.logger import get_gemini_logger
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSResponseHandler(GeminiResponseHandler):
    """
    Response handler with TTS support.

    Extends GeminiResponseHandler: responses that carry audio data are returned
    untouched, everything else goes through the parent handler.
    """

    def handle_response(
        self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return audio (TTS) responses as-is; delegate all others to the parent."""
        if not self._is_tts_response(response):
            # Not a TTS response: use the standard processing.
            return super().handle_response(response, model, stream, usage_metadata)

        logger.info("Detected TTS response with audio data, returning original response")
        return response

    def _is_tts_response(self, response: Dict[str, Any]) -> bool:
        """
        Tell whether the first candidate contains an inline audio part.

        A part counts as audio when it has an "inlineData" entry whose
        "mimeType" starts with "audio/". Any error while inspecting the
        response is logged as a warning and treated as "not TTS".
        """
        try:
            candidates = response.get("candidates")
            if not (candidates and len(candidates) > 0):
                return False

            content = candidates[0].get("content")
            if not content:
                return False

            parts = content.get("parts")
            if not (parts and len(parts) > 0):
                return False

            return any(
                "inlineData" in part
                and part["inlineData"].get("mimeType", "").startswith("audio/")
                for part in parts
            )
        except Exception as e:
            logger.warning(f"Error checking TTS response: {e}")
            return False
|
||||
24
app/service/tts/native/tts_routes.py
Normal file
24
app/service/tts/native/tts_routes.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
TTS路由扩展
|
||||
提供原生Gemini TTS增强服务,支持单人和多人语音
|
||||
"""
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from app.config.config import settings
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
async def get_key_manager():
    """
    Return the shared key manager instance.

    ``get_key_manager_instance`` is a coroutine function, so it has to be
    awaited here; returning the bare call would hand callers an un-awaited
    coroutine object instead of the key manager itself.
    """
    return await get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_tts_chat_service(key_manager: KeyManager = Depends(get_key_manager)) -> TTSGeminiChatService:
    """
    Build a native Gemini TTS chat service (single- and multi-speaker).

    A new TTSGeminiChatService is created on every call, using
    ``settings.BASE_URL`` and the supplied key manager.
    """
    tts_service = TTSGeminiChatService(settings.BASE_URL, key_manager)
    return tts_service
|
||||
|
||||
|
||||
95
app/service/tts/tts_service.py
Normal file
95
app/service/tts/tts_service.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import datetime
|
||||
import io
|
||||
import re
|
||||
import time
|
||||
import wave
|
||||
from typing import Optional
|
||||
|
||||
from google import genai
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import TTS_VOICE_NAMES
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.openai_models import TTSRequest
|
||||
from app.log.logger import get_openai_logger
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _create_wav_file(audio_data: bytes) -> bytes:
|
||||
"""Creates a WAV file in memory from raw audio data."""
|
||||
with io.BytesIO() as wav_file:
|
||||
with wave.open(wav_file, "wb") as wf:
|
||||
wf.setnchannels(1) # Mono
|
||||
wf.setsampwidth(2) # 16-bit
|
||||
wf.setframerate(24000) # 24kHz sample rate
|
||||
wf.writeframes(audio_data)
|
||||
return wav_file.getvalue()
|
||||
|
||||
|
||||
class TTSService:
    """Generate speech audio through the Google Gemini SDK and log each request."""

    async def create_tts(self, request: TTSRequest, api_key: str) -> Optional[bytes]:
        """Create WAV audio for ``request.input`` using the Google Gemini SDK.

        Every call writes a request log; failed calls (an exception, or a
        response without audio) additionally write an error log.

        Args:
            request: TTS request. ``input`` is the text to speak; ``voice`` is
                used when it is one of ``TTS_VOICE_NAMES``, otherwise
                ``settings.TTS_VOICE_NAME`` is used.
            api_key: Gemini API key used for the call and recorded in the logs.

        Returns:
            WAV bytes on success, or ``None`` when the response carries no
            inline audio data.

        Raises:
            Exception: Any error raised while calling the SDK or packaging the
                audio is logged and re-raised unchanged.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        error_log_msg = ""
        try:
            client = genai.Client(api_key=api_key)
            voice_name = (
                request.voice
                if request.voice in TTS_VOICE_NAMES
                else settings.TTS_VOICE_NAME
            )
            response = await client.aio.models.generate_content(
                model=settings.TTS_MODEL,
                contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
                config={
                    "response_modalities": ["Audio"],
                    "speech_config": {
                        "voice_config": {
                            "prebuilt_voice_config": {
                                "voice_name": voice_name,
                            }
                        }
                    },
                },
            )
            if (
                response.candidates
                and response.candidates[0].content.parts
                and response.candidates[0].content.parts[0].inline_data
            ):
                raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
                is_success = True
                status_code = 200
                return _create_wav_file(raw_audio_data)

            # The call succeeded but returned no audio. Record a reason so the
            # error log written in ``finally`` is not an empty message.
            error_log_msg = "No audio data returned in TTS response"
            logger.warning(f"TTSService: {error_log_msg}")
            return None
        except Exception as e:
            is_success = False
            error_log_msg = f"Generic error: {e}"
            logger.error(f"An error occurred in TTSService: {error_log_msg}")
            # Prefer the integer ``code`` attribute that google-genai API
            # errors carry; their message text does not contain the phrase
            # "status code", so the regex alone would report 500 for them.
            sdk_code = getattr(e, "code", None)
            if isinstance(sdk_code, int):
                status_code = sdk_code
            else:
                match = re.search(r"status code (\d+)", str(e))
                status_code = int(match.group(1)) if match else 500
            raise
        finally:
            # Runs on success, on the no-audio path, and while an exception
            # propagates, so every call is logged with its latency.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            if not is_success:
                await add_error_log(
                    gemini_key=api_key,
                    model_name=settings.TTS_MODEL,
                    error_type="google-tts",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=request.input,
                )
            await add_request_log(
                model_name=settings.TTS_MODEL,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )
|
||||
@@ -5,14 +5,17 @@ const ARRAY_INPUT_CLASS = "array-input";
|
||||
const MAP_ITEM_CLASS = "map-item";
|
||||
const MAP_KEY_INPUT_CLASS = "map-key-input";
|
||||
const MAP_VALUE_INPUT_CLASS = "map-value-input";
|
||||
const CUSTOM_HEADER_ITEM_CLASS = "custom-header-item";
|
||||
const CUSTOM_HEADER_KEY_INPUT_CLASS = "custom-header-key-input";
|
||||
const CUSTOM_HEADER_VALUE_INPUT_CLASS = "custom-header-value-input";
|
||||
const SAFETY_SETTING_ITEM_CLASS = "safety-setting-item";
|
||||
const SHOW_CLASS = "show"; // For modals
|
||||
const API_KEY_REGEX = /AIzaSy\S{33}/g;
|
||||
const PROXY_REGEX =
|
||||
/(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g;
|
||||
const VERTEX_API_KEY_REGEX = /AQ\.[a-zA-Z0-9_]{50}/g; // 新增 Vertex API Key 正则
|
||||
const VERTEX_API_KEY_REGEX = /AQ\.[a-zA-Z0-9_\-]{50}/g; // 新增 Vertex Express API Key 正则
|
||||
const MASKED_VALUE = "••••••••";
|
||||
|
||||
|
||||
// DOM Elements - Global Scope for frequently accessed elements
|
||||
const safetySettingsContainer = document.getElementById(
|
||||
"SAFETY_SETTINGS_container"
|
||||
@@ -31,8 +34,8 @@ const bulkDeleteProxyModal = document.getElementById("bulkDeleteProxyModal");
|
||||
const bulkDeleteProxyInput = document.getElementById("bulkDeleteProxyInput");
|
||||
const resetConfirmModal = document.getElementById("resetConfirmModal");
|
||||
const configForm = document.getElementById("configForm"); // Added for frequent use
|
||||
|
||||
// Vertex API Key Modal Elements
|
||||
|
||||
// Vertex Express API Key Modal Elements
|
||||
const vertexApiKeyModal = document.getElementById("vertexApiKeyModal");
|
||||
const vertexApiKeyBulkInput = document.getElementById("vertexApiKeyBulkInput");
|
||||
const bulkDeleteVertexApiKeyModal = document.getElementById(
|
||||
@@ -41,7 +44,7 @@ const bulkDeleteVertexApiKeyModal = document.getElementById(
|
||||
const bulkDeleteVertexApiKeyInput = document.getElementById(
|
||||
"bulkDeleteVertexApiKeyInput"
|
||||
);
|
||||
|
||||
|
||||
// Model Helper Modal Elements
|
||||
const modelHelperModal = document.getElementById("modelHelperModal");
|
||||
const modelHelperTitleElement = document.getElementById("modelHelperTitle");
|
||||
@@ -383,9 +386,15 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
addSafetySettingBtn.addEventListener("click", () => addSafetySettingItem());
|
||||
}
|
||||
|
||||
// Add Custom Header button
|
||||
const addCustomHeaderBtn = document.getElementById("addCustomHeaderBtn");
|
||||
if (addCustomHeaderBtn) {
|
||||
addCustomHeaderBtn.addEventListener("click", () => addCustomHeaderItem());
|
||||
}
|
||||
|
||||
initializeSensitiveFields(); // Initialize sensitive field handling
|
||||
|
||||
// Vertex API Key Modal Elements and Events
|
||||
|
||||
// Vertex Express API Key Modal Elements and Events
|
||||
const addVertexApiKeyBtn = document.getElementById("addVertexApiKeyBtn");
|
||||
const closeVertexApiKeyModalBtn = document.getElementById(
|
||||
"closeVertexApiKeyModalBtn"
|
||||
@@ -408,7 +417,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
const confirmBulkDeleteVertexApiKeyBtn = document.getElementById(
|
||||
"confirmBulkDeleteVertexApiKeyBtn"
|
||||
);
|
||||
|
||||
|
||||
if (addVertexApiKeyBtn) {
|
||||
addVertexApiKeyBtn.addEventListener("click", () => {
|
||||
openModal(vertexApiKeyModal);
|
||||
@@ -428,7 +437,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
"click",
|
||||
handleBulkAddVertexApiKeys
|
||||
);
|
||||
|
||||
|
||||
if (bulkDeleteVertexApiKeyBtn) {
|
||||
bulkDeleteVertexApiKeyBtn.addEventListener("click", () => {
|
||||
openModal(bulkDeleteVertexApiKeyModal);
|
||||
@@ -448,7 +457,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
"click",
|
||||
handleBulkDeleteVertexApiKeys
|
||||
);
|
||||
|
||||
|
||||
// Model Helper Modal Event Listeners
|
||||
if (closeModelHelperModalBtn) {
|
||||
closeModelHelperModalBtn.addEventListener("click", () =>
|
||||
@@ -691,12 +700,26 @@ async function initConfig() {
|
||||
) {
|
||||
config.THINKING_BUDGET_MAP = {}; // 默认为空对象
|
||||
}
|
||||
// --- 新增:处理 CUSTOM_HEADERS 默认值 ---
|
||||
if (
|
||||
!config.CUSTOM_HEADERS ||
|
||||
typeof config.CUSTOM_HEADERS !== "object" ||
|
||||
config.CUSTOM_HEADERS === null
|
||||
) {
|
||||
config.CUSTOM_HEADERS = {}; // 默认为空对象
|
||||
}
|
||||
// --- 新增:处理 SAFETY_SETTINGS 默认值 ---
|
||||
if (!config.SAFETY_SETTINGS || !Array.isArray(config.SAFETY_SETTINGS)) {
|
||||
config.SAFETY_SETTINGS = []; // 默认为空数组
|
||||
}
|
||||
// --- 结束:处理 SAFETY_SETTINGS 默认值 ---
|
||||
|
||||
if (typeof config.URL_CONTEXT_ENABLED === "undefined") {
|
||||
config.URL_CONTEXT_ENABLED = true;
|
||||
}
|
||||
if (!config.URL_CONTEXT_MODELS || !Array.isArray(config.URL_CONTEXT_MODELS)) {
|
||||
config.URL_CONTEXT_MODELS = [];
|
||||
}
|
||||
|
||||
// --- 新增:处理自动删除错误日志配置的默认值 ---
|
||||
if (typeof config.AUTO_DELETE_ERROR_LOGS_ENABLED === "undefined") {
|
||||
config.AUTO_DELETE_ERROR_LOGS_ENABLED = false;
|
||||
@@ -756,6 +779,7 @@ async function initConfig() {
|
||||
VERTEX_EXPRESS_BASE_URL: "", // 确保默认值存在
|
||||
THINKING_MODELS: [],
|
||||
THINKING_BUDGET_MAP: {},
|
||||
CUSTOM_HEADERS: {},
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED: false,
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS: 7, // 新增默认值
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED: false, // 新增默认值
|
||||
@@ -765,7 +789,7 @@ async function initConfig() {
|
||||
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: 5,
|
||||
// --- 结束:处理假流式配置的默认值 ---
|
||||
};
|
||||
|
||||
|
||||
populateForm(defaultConfig);
|
||||
if (configForm) {
|
||||
// Ensure form exists
|
||||
@@ -854,6 +878,26 @@ function populateForm(config) {
|
||||
'<div class="text-gray-500 text-sm italic">请在上方添加思考模型,预算将自动关联。</div>';
|
||||
}
|
||||
|
||||
// Populate CUSTOM_HEADERS
|
||||
const customHeadersContainer = document.getElementById(
|
||||
"CUSTOM_HEADERS_container"
|
||||
);
|
||||
let customHeadersAdded = false;
|
||||
if (
|
||||
customHeadersContainer &&
|
||||
config.CUSTOM_HEADERS &&
|
||||
typeof config.CUSTOM_HEADERS === "object"
|
||||
) {
|
||||
for (const [key, value] of Object.entries(config.CUSTOM_HEADERS)) {
|
||||
createAndAppendCustomHeaderItem(key, value);
|
||||
customHeadersAdded = true;
|
||||
}
|
||||
}
|
||||
if (!customHeadersAdded && customHeadersContainer) {
|
||||
customHeadersContainer.innerHTML =
|
||||
'<div class="text-gray-500 text-sm italic">添加自定义请求头,例如 X-Api-Key: your-key</div>';
|
||||
}
|
||||
|
||||
// 4. Populate other array fields (excluding THINKING_MODELS)
|
||||
for (const [key, value] of Object.entries(config)) {
|
||||
if (Array.isArray(value) && key !== "THINKING_MODELS") {
|
||||
@@ -1177,25 +1221,21 @@ function handleBulkDeleteProxies() {
|
||||
}
|
||||
bulkDeleteProxyInput.value = "";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Handles the bulk addition of Vertex API keys from the modal input.
|
||||
* Handles the bulk addition of Vertex Express API keys from the modal input.
|
||||
*/
|
||||
function handleBulkAddVertexApiKeys() {
|
||||
const vertexApiKeyContainer = document.getElementById(
|
||||
"VERTEX_API_KEYS_container"
|
||||
);
|
||||
if (
|
||||
!vertexApiKeyBulkInput ||
|
||||
!vertexApiKeyContainer ||
|
||||
!vertexApiKeyModal
|
||||
) {
|
||||
if (!vertexApiKeyBulkInput || !vertexApiKeyContainer || !vertexApiKeyModal) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const bulkText = vertexApiKeyBulkInput.value;
|
||||
const extractedKeys = bulkText.match(VERTEX_API_KEY_REGEX) || [];
|
||||
|
||||
|
||||
const currentKeyInputs = vertexApiKeyContainer.querySelectorAll(
|
||||
`.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
|
||||
);
|
||||
@@ -1206,16 +1246,16 @@ function handleBulkAddVertexApiKeys() {
|
||||
: input.value;
|
||||
})
|
||||
.filter((key) => key && key.trim() !== "" && key !== MASKED_VALUE);
|
||||
|
||||
|
||||
const combinedKeys = new Set([...currentKeys, ...extractedKeys]);
|
||||
const uniqueKeys = Array.from(combinedKeys);
|
||||
|
||||
|
||||
vertexApiKeyContainer.innerHTML = ""; // Clear existing items
|
||||
|
||||
|
||||
uniqueKeys.forEach((key) => {
|
||||
addArrayItemWithValue("VERTEX_API_KEYS", key); // VERTEX_API_KEYS are sensitive
|
||||
});
|
||||
|
||||
|
||||
// Ensure new sensitive inputs are masked
|
||||
const newKeyInputs = vertexApiKeyContainer.querySelectorAll(
|
||||
`.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
|
||||
@@ -1229,7 +1269,7 @@ function handleBulkAddVertexApiKeys() {
|
||||
input.dispatchEvent(focusoutEvent);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
closeModal(vertexApiKeyModal);
|
||||
showNotification(
|
||||
`添加/更新了 ${uniqueKeys.length} 个唯一 Vertex 密钥`,
|
||||
@@ -1237,9 +1277,9 @@ function handleBulkAddVertexApiKeys() {
|
||||
);
|
||||
vertexApiKeyBulkInput.value = "";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Handles the bulk deletion of Vertex API keys based on input from the modal.
|
||||
* Handles the bulk deletion of Vertex Express API keys based on input from the modal.
|
||||
*/
|
||||
function handleBulkDeleteVertexApiKeys() {
|
||||
const vertexApiKeyContainer = document.getElementById(
|
||||
@@ -1252,26 +1292,28 @@ function handleBulkDeleteVertexApiKeys() {
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const bulkText = bulkDeleteVertexApiKeyInput.value;
|
||||
if (!bulkText.trim()) {
|
||||
showNotification("请粘贴需要删除的 Vertex API 密钥", "warning");
|
||||
showNotification("请粘贴需要删除的 Vertex Express API 密钥", "warning");
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const keysToDelete = new Set(bulkText.match(VERTEX_API_KEY_REGEX) || []);
|
||||
|
||||
|
||||
if (keysToDelete.size === 0) {
|
||||
showNotification(
|
||||
"未在输入内容中提取到有效的 Vertex API 密钥格式",
|
||||
"未在输入内容中提取到有效的 Vertex Express API 密钥格式",
|
||||
"warning"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const keyItems = vertexApiKeyContainer.querySelectorAll(`.${ARRAY_ITEM_CLASS}`);
|
||||
|
||||
const keyItems = vertexApiKeyContainer.querySelectorAll(
|
||||
`.${ARRAY_ITEM_CLASS}`
|
||||
);
|
||||
let deleteCount = 0;
|
||||
|
||||
|
||||
keyItems.forEach((item) => {
|
||||
const input = item.querySelector(
|
||||
`.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
|
||||
@@ -1286,17 +1328,20 @@ function handleBulkDeleteVertexApiKeys() {
|
||||
deleteCount++;
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
closeModal(bulkDeleteVertexApiKeyModal);
|
||||
|
||||
|
||||
if (deleteCount > 0) {
|
||||
showNotification(`成功删除了 ${deleteCount} 个匹配的 Vertex 密钥`, "success");
|
||||
showNotification(
|
||||
`成功删除了 ${deleteCount} 个匹配的 Vertex 密钥`,
|
||||
"success"
|
||||
);
|
||||
} else {
|
||||
showNotification("列表中未找到您输入的任何 Vertex 密钥进行删除", "info");
|
||||
}
|
||||
bulkDeleteVertexApiKeyInput.value = "";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Switches the active configuration tab.
|
||||
* @param {string} tabId - The ID of the tab to switch to.
|
||||
@@ -1305,8 +1350,10 @@ function switchTab(tabId) {
|
||||
console.log(`Switching to tab: ${tabId}`);
|
||||
|
||||
// 定义选中态和未选中态的样式
|
||||
const activeStyle = "background-color: #3b82f6 !important; color: #ffffff !important; border: 2px solid #2563eb !important; box-shadow: 0 4px 12px -2px rgba(59, 130, 246, 0.4), 0 2px 6px -1px rgba(59, 130, 246, 0.2) !important; transform: translateY(-2px) !important; font-weight: 600 !important;";
|
||||
const inactiveStyle = "background-color: #f8fafc !important; color: #64748b !important; border: 2px solid #e2e8f0 !important; box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1) !important; font-weight: 500 !important; transform: none !important;";
|
||||
const activeStyle =
|
||||
"background-color: #3b82f6 !important; color: #ffffff !important; border: 2px solid #2563eb !important; box-shadow: 0 4px 12px -2px rgba(59, 130, 246, 0.4), 0 2px 6px -1px rgba(59, 130, 246, 0.2) !important; transform: translateY(-2px) !important; font-weight: 600 !important;";
|
||||
const inactiveStyle =
|
||||
"background-color: #f8fafc !important; color: #64748b !important; border: 2px solid #e2e8f0 !important; box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1) !important; font-weight: 500 !important; transform: none !important;";
|
||||
|
||||
// 更新标签按钮状态
|
||||
const tabButtons = document.querySelectorAll(".tab-btn");
|
||||
@@ -1421,7 +1468,7 @@ function addArrayItem(key) {
|
||||
const modelId = addArrayItemWithValue(key, newItemValue); // This adds the DOM element
|
||||
|
||||
if (key === "THINKING_MODELS" && modelId) {
|
||||
createAndAppendBudgetMapItem(newItemValue, 0, modelId); // Default budget 0
|
||||
createAndAppendBudgetMapItem(newItemValue, -1, modelId); // Default budget -1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1439,10 +1486,9 @@ function addArrayItemWithValue(key, value) {
|
||||
const isThinkingModel = key === "THINKING_MODELS";
|
||||
const isAllowedToken = key === "ALLOWED_TOKENS";
|
||||
const isVertexApiKey = key === "VERTEX_API_KEYS"; // 新增判断
|
||||
const isSensitive =
|
||||
key === "API_KEYS" || isAllowedToken || isVertexApiKey; // 更新敏感判断
|
||||
const isSensitive = key === "API_KEYS" || isAllowedToken || isVertexApiKey; // 更新敏感判断
|
||||
const modelId = isThinkingModel ? generateUUID() : null;
|
||||
|
||||
|
||||
const arrayItem = document.createElement("div");
|
||||
arrayItem.className = `${ARRAY_ITEM_CLASS} flex items-center mb-2 gap-2`;
|
||||
if (isThinkingModel) {
|
||||
@@ -1532,17 +1578,17 @@ function createAndAppendBudgetMapItem(mapKey, mapValue, modelId) {
|
||||
const valueInput = document.createElement("input");
|
||||
valueInput.type = "number";
|
||||
const intValue = parseInt(mapValue, 10);
|
||||
valueInput.value = isNaN(intValue) ? 0 : intValue;
|
||||
valueInput.value = isNaN(intValue) ? -1 : intValue;
|
||||
valueInput.placeholder = "预算 (整数)";
|
||||
valueInput.className = `${MAP_VALUE_INPUT_CLASS} w-24 px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50`;
|
||||
valueInput.min = 0;
|
||||
valueInput.max = 24576;
|
||||
valueInput.min = -1;
|
||||
valueInput.max = 32767;
|
||||
valueInput.addEventListener("input", function () {
|
||||
let val = this.value.replace(/[^0-9]/g, "");
|
||||
let val = this.value.replace(/[^0-9-]/g, "");
|
||||
if (val !== "") {
|
||||
val = parseInt(val, 10);
|
||||
if (val < 0) val = 0;
|
||||
if (val > 24576) val = 24576;
|
||||
if (val < -1) val = -1;
|
||||
if (val > 32767) val = 32767;
|
||||
}
|
||||
this.value = val; // Corrected variable name
|
||||
});
|
||||
@@ -1562,6 +1608,67 @@ function createAndAppendBudgetMapItem(mapKey, mapValue, modelId) {
|
||||
container.appendChild(mapItem);
|
||||
}
|
||||
|
||||
/**
 * Appends a blank custom header row.
 * Delegates to createAndAppendCustomHeaderItem with an empty name and value.
 */
function addCustomHeaderItem() {
  const blankKey = "";
  const blankValue = "";
  createAndAppendCustomHeaderItem(blankKey, blankValue);
}
|
||||
|
||||
/**
 * Builds one custom header row (name input, value input, remove button) and
 * appends it to the element with id "CUSTOM_HEADERS_container".
 *
 * If the container's only child is the placeholder element matching
 * ".text-gray-500.italic", the container is emptied first. When a row's
 * remove button leaves the container with no children, the placeholder
 * markup is written back.
 *
 * @param {string} key - Initial header name shown in the first input.
 * @param {string} value - Initial header value shown in the second input.
 */
function createAndAppendCustomHeaderItem(key, value) {
  const container = document.getElementById("CUSTOM_HEADERS_container");
  if (!container) {
    console.error(
      "Cannot add custom header: CUSTOM_HEADERS_container not found!"
    );
    return;
  }

  // Clear the placeholder only when it is the sole child of the container.
  const hint = container.querySelector(".text-gray-500.italic");
  const onlyHintPresent =
    Boolean(hint) &&
    container.children.length === 1 &&
    container.firstChild === hint;
  if (onlyHintPresent) {
    container.innerHTML = "";
  }

  // Small factory shared by the name and value inputs.
  const makeTextInput = (initialValue, hintText, classNames) => {
    const input = document.createElement("input");
    input.type = "text";
    input.value = initialValue;
    input.placeholder = hintText;
    input.className = classNames;
    return input;
  };

  const row = document.createElement("div");
  row.className = `${CUSTOM_HEADER_ITEM_CLASS} flex items-center mb-2 gap-2`;

  const nameField = makeTextInput(
    key,
    "Header Name",
    `${CUSTOM_HEADER_KEY_INPUT_CLASS} flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none bg-gray-100 text-gray-500`
  );
  const valueField = makeTextInput(
    value,
    "Header Value",
    `${CUSTOM_HEADER_VALUE_INPUT_CLASS} flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50`
  );

  const deleteButton = createRemoveButton();
  deleteButton.addEventListener("click", function () {
    row.remove();
    if (container.children.length !== 0) {
      return;
    }
    // Last row removed: restore the placeholder markup.
    container.innerHTML =
      '<div class="text-gray-500 text-sm italic">添加自定义请求头,例如 X-Api-Key: your-key</div>';
  });

  for (const child of [nameField, valueField, deleteButton]) {
    row.appendChild(child);
  }
  container.appendChild(row);
}
|
||||
|
||||
/**
|
||||
* Collects all data from the configuration form.
|
||||
* @returns {object} An object containing all configuration data.
|
||||
@@ -1632,12 +1739,32 @@ function collectFormData() {
|
||||
formData["THINKING_BUDGET_MAP"][keyInput.value.trim()] = isNaN(
|
||||
budgetValue
|
||||
)
|
||||
? 0
|
||||
? -1
|
||||
: budgetValue;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const customHeadersContainer = document.getElementById(
|
||||
"CUSTOM_HEADERS_container"
|
||||
);
|
||||
if (customHeadersContainer) {
|
||||
formData["CUSTOM_HEADERS"] = {};
|
||||
const customHeaderItems = customHeadersContainer.querySelectorAll(
|
||||
`.${CUSTOM_HEADER_ITEM_CLASS}`
|
||||
);
|
||||
customHeaderItems.forEach((item) => {
|
||||
const keyInput = item.querySelector(`.${CUSTOM_HEADER_KEY_INPUT_CLASS}`);
|
||||
const valueInput = item.querySelector(
|
||||
`.${CUSTOM_HEADER_VALUE_INPUT_CLASS}`
|
||||
);
|
||||
if (keyInput && valueInput && keyInput.value.trim() !== "") {
|
||||
formData["CUSTOM_HEADERS"][keyInput.value.trim()] =
|
||||
valueInput.value.trim();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (safetySettingsContainer) {
|
||||
formData["SAFETY_SETTINGS"] = [];
|
||||
const settingItems = safetySettingsContainer.querySelectorAll(
|
||||
@@ -2163,7 +2290,7 @@ function handleModelSelection(selectedModelId) {
|
||||
);
|
||||
if (currentModelHelperTarget.targetKey === "THINKING_MODELS" && modelId) {
|
||||
// Automatically add corresponding budget map item with default budget 0
|
||||
createAndAppendBudgetMapItem(selectedModelId, 0, modelId);
|
||||
createAndAppendBudgetMapItem(selectedModelId, -1, modelId);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -817,58 +817,11 @@ function toggleSection(header, sectionId) {
|
||||
}
|
||||
}
|
||||
|
||||
// 筛选有效密钥(根据失败次数阈值)并更新批量操作状态
|
||||
// filterValidKeys 函数已被 filterAndSearchValidKeys 替代,此函数保留为空或可移除
|
||||
function filterValidKeys() {
|
||||
const thresholdInput = document.getElementById("failCountThreshold");
|
||||
const validKeysList = document.getElementById("validKeys"); // Get the UL element
|
||||
if (!validKeysList) return; // Exit if the list doesn't exist
|
||||
|
||||
const validKeyItems = validKeysList.querySelectorAll("li[data-key]"); // Select li elements within the list
|
||||
// 读取阈值,如果输入无效或为空,则默认为0(不过滤)
|
||||
const threshold = parseInt(thresholdInput.value, 10);
|
||||
const filterThreshold = isNaN(threshold) || threshold < 0 ? 0 : threshold;
|
||||
let hasVisibleItems = false;
|
||||
|
||||
validKeyItems.forEach((item) => {
|
||||
// 确保只处理包含 data-fail-count 的 li 元素
|
||||
if (item.dataset.failCount !== undefined) {
|
||||
const failCount = parseInt(item.dataset.failCount, 10);
|
||||
// 如果失败次数大于等于阈值,则显示,否则隐藏
|
||||
if (failCount >= filterThreshold) {
|
||||
item.style.display = "flex"; // 使用 flex 因为 li 现在是 flex 容器
|
||||
hasVisibleItems = true;
|
||||
} else {
|
||||
item.style.display = "none"; // 隐藏
|
||||
// 如果隐藏了一个项,取消其选中状态
|
||||
const checkbox = item.querySelector(".key-checkbox");
|
||||
if (checkbox && checkbox.checked) {
|
||||
checkbox.checked = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 更新有效密钥的批量操作状态和全选复选框
|
||||
updateBatchActions("valid");
|
||||
|
||||
// 处理"暂无有效密钥"消息
|
||||
const noMatchMsgId = "no-valid-keys-msg";
|
||||
let noMatchMsg = validKeysList.querySelector(`#${noMatchMsgId}`);
|
||||
const initialKeyCount = validKeysList.querySelectorAll("li[data-key]").length; // 获取初始密钥数量
|
||||
|
||||
if (!hasVisibleItems && initialKeyCount > 0) {
|
||||
// 仅当初始有密钥但现在都不可见时显示
|
||||
if (!noMatchMsg) {
|
||||
noMatchMsg = document.createElement("li");
|
||||
noMatchMsg.id = noMatchMsgId;
|
||||
noMatchMsg.className = "text-center text-gray-500 py-4 col-span-full";
|
||||
noMatchMsg.textContent = "没有符合条件的有效密钥";
|
||||
validKeysList.appendChild(noMatchMsg);
|
||||
}
|
||||
noMatchMsg.style.display = "";
|
||||
} else if (noMatchMsg) {
|
||||
noMatchMsg.style.display = "none";
|
||||
}
|
||||
// This function is now handled by filterAndSearchValidKeys
|
||||
// Kept for now to avoid breaking any potential legacy calls, but should be removed later.
|
||||
filterAndSearchValidKeys();
|
||||
}
|
||||
|
||||
// --- Initialization Helper Functions ---
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1245,6 +1245,7 @@ endblock %} {% block head_extra_styles %}
|
||||
<option value="20">20</option>
|
||||
<option value="50">50</option>
|
||||
<option value="100">100</option>
|
||||
<option value="500">500</option>
|
||||
</select>
|
||||
<span class="text-sm select-none font-semibold" style="color: #1f2937 !important;">项</span>
|
||||
</div>
|
||||
|
||||
@@ -261,18 +261,20 @@ class PicGoUploader(ImageUploader):
|
||||
|
||||
class CloudFlareImgBedUploader(ImageUploader):
|
||||
"""CloudFlare图床上传器"""
|
||||
|
||||
def __init__(self, auth_code: str, api_url: str):
|
||||
|
||||
def __init__(self, auth_code: str, api_url: str, upload_folder: str = ""):
|
||||
"""
|
||||
初始化CloudFlare图床上传器
|
||||
|
||||
Args:
|
||||
auth_code: 认证码
|
||||
api_url: 上传API地址
|
||||
upload_folder: 上传文件夹路径(可选)
|
||||
"""
|
||||
self.auth_code = auth_code
|
||||
self.api_url = api_url
|
||||
|
||||
self.upload_folder = upload_folder
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
"""
|
||||
上传图片到CloudFlare图床
|
||||
@@ -288,12 +290,16 @@ class CloudFlareImgBedUploader(ImageUploader):
|
||||
UploadError: 上传失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
# 准备请求URL(添加认证码参数,如果存在)
|
||||
# 准备请求URL参数
|
||||
params = []
|
||||
if self.upload_folder:
|
||||
params.append(f"uploadFolder={self.upload_folder}")
|
||||
if self.auth_code:
|
||||
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
|
||||
else:
|
||||
request_url = f"{self.api_url}?uploadNameType=origin"
|
||||
|
||||
params.append(f"authCode={self.auth_code}")
|
||||
params.append("uploadNameType=origin")
|
||||
|
||||
request_url = f"{self.api_url}?{'&'.join(params)}"
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
"file": (filename, file)
|
||||
@@ -388,6 +394,7 @@ class ImageUploaderFactory:
|
||||
elif provider == "cloudflare_imgbed":
|
||||
return CloudFlareImgBedUploader(
|
||||
credentials["auth_code"],
|
||||
credentials["base_url"]
|
||||
credentials["base_url"],
|
||||
credentials.get("upload_folder", ""),
|
||||
)
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
Reference in New Issue
Block a user