mirror of
https://github.com/snailyp/gemini-balance.git
synced 2026-07-03 13:51:30 +08:00
Compare commits
138 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5512f7ff21 | ||
|
|
150824938c | ||
|
|
ccaea40281 | ||
|
|
9d8e77c9f7 | ||
|
|
19941f7f50 | ||
|
|
d6981c204a | ||
|
|
d386cc7180 | ||
|
|
bed3647424 | ||
|
|
95b5acad66 | ||
|
|
68b65814bc | ||
|
|
88f5b33018 | ||
|
|
8c62c8121d | ||
|
|
05762cb6a5 | ||
|
|
78f38cc981 | ||
|
|
79f47c315e | ||
|
|
708fb1604b | ||
|
|
7dbd3ad693 | ||
|
|
67dd1af583 | ||
|
|
e104a50cf4 | ||
|
|
6b9647813b | ||
|
|
f863e3065b | ||
|
|
1314e0ee09 | ||
|
|
81d92370ad | ||
|
|
5f6eba62cc | ||
|
|
a8a265c2a7 | ||
|
|
ee21e50305 | ||
|
|
611559d298 | ||
|
|
b0127e6fc2 | ||
|
|
1d15a21ce5 | ||
|
|
c206aa8e4a | ||
|
|
3f040b7075 | ||
|
|
1771555fe9 | ||
|
|
8711088ebc | ||
|
|
bb6c629aef | ||
|
|
4af17ce55d | ||
|
|
2001bfdcd9 | ||
|
|
669123f348 | ||
|
|
d06e418a61 | ||
|
|
fa6745454e | ||
|
|
1aa3d267bb | ||
|
|
e9601ca76c | ||
|
|
01312317a1 | ||
|
|
7827283d0a | ||
|
|
96c4b4fa50 | ||
|
|
892392742d | ||
|
|
380e6426ed | ||
|
|
d2906d89a6 | ||
|
|
13e1db7d69 | ||
|
|
40c9689eae | ||
|
|
548dcccf2f | ||
|
|
b52092a72b | ||
|
|
67efd067c6 | ||
|
|
fd39c2c9cb | ||
|
|
f58ae2b340 | ||
|
|
f51a4d20ad | ||
|
|
b89d3ea144 | ||
|
|
3d6b5063d5 | ||
|
|
a6558b4668 | ||
|
|
6f714649a7 | ||
|
|
ae775760dd | ||
|
|
d475ccdece | ||
|
|
4eed3a48db | ||
|
|
26f3dbd12b | ||
|
|
7af53de782 | ||
|
|
2270f6d998 | ||
|
|
9f5892a987 | ||
|
|
ccd4722a77 | ||
|
|
feb57d7cf2 | ||
|
|
e7394776af | ||
|
|
0fa9638dd5 | ||
|
|
9d4d6464bf | ||
|
|
f3d9cb2b85 | ||
|
|
6abda7d902 | ||
|
|
b25cf7d978 | ||
|
|
07481ca972 | ||
|
|
9c285e38ef | ||
|
|
ebfa1d247c | ||
|
|
cdb85ef9b7 | ||
|
|
7006522c13 | ||
|
|
530c958afc | ||
|
|
57d861b578 | ||
|
|
99664298b9 | ||
|
|
a6fe5a7022 | ||
|
|
1918dad602 | ||
|
|
69399c291e | ||
|
|
9ec33ce320 | ||
|
|
c35d3aff7d | ||
|
|
2a5744d1c4 | ||
|
|
825511506b | ||
|
|
5a98a701cb | ||
|
|
dd1fa35c73 | ||
|
|
fb572fa849 | ||
|
|
c0a473ed19 | ||
|
|
030641adc6 | ||
|
|
445ef49dc8 | ||
|
|
32d4c60541 | ||
|
|
23f865be07 | ||
|
|
5d55325c12 | ||
|
|
900330509a | ||
|
|
cfb682ae3c | ||
|
|
abae90b16d | ||
|
|
470fc37f26 | ||
|
|
7a7caef1a6 | ||
|
|
a6aecb5d89 | ||
|
|
4a004f9aa1 | ||
|
|
1a6feae23b | ||
|
|
af5b2fa2c9 | ||
|
|
eeec45274b | ||
|
|
2b48c853fe | ||
|
|
c47f696691 | ||
|
|
9a8e4c8e15 | ||
|
|
24aab9a658 | ||
|
|
afdaaffac5 | ||
|
|
fe721116e2 | ||
|
|
8e0a834daa | ||
|
|
c9fca1561c | ||
|
|
5eb2dfd822 | ||
|
|
0b837c3f80 | ||
|
|
a6cfc12443 | ||
|
|
f6d64dd850 | ||
|
|
eed62caa78 | ||
|
|
204d41d6f3 | ||
|
|
858df0548e | ||
|
|
b3da021803 | ||
|
|
d234f826f4 | ||
|
|
231b69ecf8 | ||
|
|
0a08913677 | ||
|
|
49d32813ea | ||
|
|
c5d57e97b1 | ||
|
|
da8f7539a1 | ||
|
|
64a68f1176 | ||
|
|
1199d7cc3c | ||
|
|
8a827d2acb | ||
|
|
0e8a943d7f | ||
|
|
4f62658440 | ||
|
|
6e7c3d5f6a | ||
|
|
d5062db9b6 | ||
|
|
a6ad006a49 |
34
.env.example
34
.env.example
@@ -14,12 +14,15 @@ AUTH_TOKEN=sk-123456
|
||||
VERTEX_API_KEYS=["AQ.Abxxxxxxxxxxxxxxxxxxx"]
|
||||
# For Vertex AI Platform Express API Base URL
|
||||
VERTEX_EXPRESS_BASE_URL=https://aiplatform.googleapis.com/v1beta1/publishers/google
|
||||
TEST_MODEL=gemini-1.5-flash
|
||||
THINKING_MODELS=["gemini-2.5-flash-preview-04-17"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
|
||||
TEST_MODEL=gemini-2.5-flash-lite
|
||||
THINKING_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
|
||||
THINKING_BUDGET_MAP={"gemini-2.5-flash": -1}
|
||||
IMAGE_MODELS=["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
|
||||
SEARCH_MODELS=["gemini-2.5-flash","gemini-2.5-pro"]
|
||||
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
|
||||
# 是否启用网址上下文,默认启用
|
||||
URL_CONTEXT_ENABLED=false
|
||||
URL_CONTEXT_MODELS=["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
|
||||
TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
SHOW_SEARCH_LINK=true
|
||||
SHOW_THINKING_PROCESS=true
|
||||
@@ -41,8 +44,17 @@ CREATE_IMAGE_MODEL=imagen-3.0-generate-002
|
||||
UPLOAD_PROVIDER=smms
|
||||
SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
||||
PICGO_API_KEY=xxxx
|
||||
PICGO_API_URL=https://www.picgo.net/api/1/upload
|
||||
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
|
||||
# 阿里云OSS配置
|
||||
OSS_ENDPOINT=oss-cn-shanghai.aliyuncs.com
|
||||
OSS_ENDPOINT_INNER=oss-cn-shanghai-internal.aliyuncs.com
|
||||
OSS_ACCESS_KEY=LTAI5txxxxxxxxxxxxxxxx
|
||||
OSS_ACCESS_KEY_SECRET=yXxxxxxxxxxxxxxxxxxxxxx
|
||||
OSS_BUCKET_NAME=your-bucket-name
|
||||
OSS_REGION=cn-shanghai
|
||||
##########################################################################
|
||||
#########################stream_optimizer 相关配置########################
|
||||
STREAM_OPTIMIZER_ENABLED=false
|
||||
@@ -55,6 +67,8 @@ STREAM_CHUNK_SIZE=5
|
||||
######################### 日志配置 #######################################
|
||||
# 日志级别 (debug, info, warning, error, critical),默认为 info
|
||||
LOG_LEVEL=info
|
||||
# 是否记录错误日志的请求体(可能包含敏感信息),默认 false
|
||||
ERROR_LOG_RECORD_REQUEST_BODY=false
|
||||
# 是否开启自动删除错误日志
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED=true
|
||||
# 自动删除多少天前的错误日志 (1, 7, 30)
|
||||
@@ -78,4 +92,12 @@ URL_NORMALIZATION_ENABLED=false
|
||||
# tts配置
|
||||
TTS_MODEL=gemini-2.5-flash-preview-tts
|
||||
TTS_VOICE_NAME=Zephyr
|
||||
TTS_SPEED=normal
|
||||
TTS_SPEED=normal
|
||||
#########################Files API 相关配置########################
|
||||
# 是否启用文件过期自动清理
|
||||
FILES_CLEANUP_ENABLED=true
|
||||
# 文件过期清理间隔(小时)
|
||||
FILES_CLEANUP_INTERVAL_HOURS=1
|
||||
# 是否启用用户文件隔离(每个用户只能看到自己上传的文件)
|
||||
FILES_USER_ISOLATION_ENABLED=true
|
||||
##########################################################################
|
||||
22
.github/workflows/release.yml
vendored
22
.github/workflows/release.yml
vendored
@@ -3,7 +3,7 @@ name: Publish Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*' # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
- "v*" # 当推送以 "v" 开头的标签时触发(如 v1.0.0, v2.1.0)
|
||||
|
||||
jobs:
|
||||
update-release-draft:
|
||||
@@ -15,8 +15,17 @@ jobs:
|
||||
# Step 1: 检出代码库
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# Step 2: 自动生成 Release
|
||||
# Step 2: 自动生成 Release Notes
|
||||
- name: Generate release notes
|
||||
id: changelog
|
||||
uses: mikepenz/release-changelog-builder-action@v4
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# Step 3: 自动生成 Release
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
@@ -25,15 +34,16 @@ jobs:
|
||||
with:
|
||||
tag_name: ${{ github.ref_name }}
|
||||
release_name: ${{ github.ref_name }}
|
||||
body: ${{ steps.changelog.outputs.changelog }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
|
||||
# Step 3: 可选,构建zip文件
|
||||
|
||||
# Step 4: 可选,构建zip文件
|
||||
- name: Create ZIP file
|
||||
run: |
|
||||
zip -r gemini-balance.zip . -x "*.git*" "*.github*" "*.env*" "logs/*" "tests/*"
|
||||
|
||||
# Step 4: 可选,上传构建文件
|
||||
# Step 5: 可选,上传构建文件
|
||||
- name: Upload Release Asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
@@ -41,5 +51,5 @@ jobs:
|
||||
with:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
asset_path: ./gemini-balance.zip # 替换为你的构建文件路径
|
||||
asset_name: gemini-balance.zip # 替换为你的文件名
|
||||
asset_name: gemini-balance.zip # 替换为你的文件名
|
||||
asset_content_type: application/zip
|
||||
|
||||
@@ -10,11 +10,7 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY ./app /app/app
|
||||
ENV API_KEYS='["your_api_key_1"]'
|
||||
ENV ALLOWED_TOKENS='["your_token_1"]'
|
||||
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
||||
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
||||
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
||||
ENV URL_NORMALIZATION_ENABLED=false
|
||||
ENV TZ='Asia/Shanghai'
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
407
README.md
407
README.md
@@ -2,31 +2,39 @@
|
||||
|
||||
# Gemini Balance - Gemini API Proxy and Load Balancer
|
||||
|
||||
> ⚠️ This project is licensed under the CC BY-NC 4.0 (Attribution-NonCommercial) license. Any form of commercial resale service is prohibited. See the LICENSE file for details.
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> I have never sold this service on any platform. If you encounter someone selling this service, they are definitely a reseller. Please be careful not to be deceived.
|
||||
<p align="center">
|
||||
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python"></a>
|
||||
<a href="https://fastapi.tiangolo.com/"><img src="https://img.shields.io/badge/FastAPI-0.100%2B-green.svg" alt="FastAPI"></a>
|
||||
<a href="https://www.uvicorn.org/"><img src="https://img.shields.io/badge/Uvicorn-running-purple.svg" alt="Uvicorn"></a>
|
||||
<a href="https://t.me/+soaHax5lyI0wZDVl"><img src="https://img.shields.io/badge/Telegram-Group-blue.svg?logo=telegram" alt="Telegram Group"></a>
|
||||
</p>
|
||||
|
||||
[](https://www.python.org/)
|
||||
[](https://fastapi.tiangolo.com/)
|
||||
[](https://www.uvicorn.org/)
|
||||
[](https://t.me/+soaHax5lyI0wZDVl)
|
||||
> ⚠️ **Important**: This project is licensed under the [CC BY-NC 4.0](LICENSE) license. **Any form of commercial resale service is prohibited**.
|
||||
> I have never sold this service on any platform. If you encounter someone selling this service, they are a reseller. Please do not be deceived.
|
||||
|
||||
> Telegram Group: <https://t.me/+soaHax5lyI0wZDVl>
|
||||
---
|
||||
|
||||
## Project Introduction
|
||||
## 📖 Project Introduction
|
||||
|
||||
Gemini Balance is an application built with Python FastAPI, designed to provide proxy and load balancing functions for the Google Gemini API. It allows you to manage multiple Gemini API Keys and implement key rotation, authentication, model filtering, and status monitoring through simple configuration. Additionally, the project integrates image generation and multiple image hosting upload functions, and supports proxying in the OpenAI API format.
|
||||
**Gemini Balance** is an application built with Python FastAPI, designed to provide proxy and load balancing functions for the Google Gemini API. It allows you to manage multiple Gemini API Keys and implement key rotation, authentication, model filtering, and status monitoring through simple configuration. Additionally, the project integrates image generation and multiple image hosting upload functions, and supports proxying in the OpenAI API format.
|
||||
|
||||
**Project Structure:**
|
||||
<details>
|
||||
<summary>📂 View Project Structure</summary>
|
||||
|
||||
```plaintext
|
||||
app/
|
||||
├── config/ # Configuration management
|
||||
├── core/ # Core application logic (FastAPI instance creation, middleware, etc.)
|
||||
├── database/ # Database models and connections
|
||||
├── domain/ # Business domain objects (optional)
|
||||
├── domain/ # Business domain objects
|
||||
├── exception/ # Custom exceptions
|
||||
├── handler/ # Request handlers (optional, or handled in router)
|
||||
├── handler/ # Request handlers
|
||||
├── log/ # Logging configuration
|
||||
├── main.py # Application entry point
|
||||
├── middleware/ # FastAPI middleware
|
||||
@@ -35,247 +43,242 @@ app/
|
||||
├── service/ # Business logic services (chat, Key management, statistics, etc.)
|
||||
├── static/ # Static files (CSS, JS)
|
||||
├── templates/ # HTML templates (e.g., Key status page)
|
||||
├── utils/ # Utility functions
|
||||
└── utils/ # Utility functions
|
||||
```
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## ✨ Feature Highlights
|
||||
|
||||
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling, improving availability and concurrency.
|
||||
* **Visual Configuration Takes Effect Immediately**: Configurations modified through the admin backend take effect without restarting the service. Remember to click save for changes to apply.
|
||||
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling.
|
||||
* **Visual Configuration**: Configurations modified through the admin backend take effect immediately without restarting.
|
||||

|
||||
* **Dual Protocol API Compatibility**: Supports forwarding CHAT API requests in both Gemini and OpenAI formats.
|
||||
|
||||
```plaintext
|
||||
openai baseurl `http://localhost:8000(/hf)/v1`
|
||||
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
|
||||
```
|
||||
|
||||
* **Supports Image-Text Chat and Image Modification**: `IMAGE_MODELS` configures which models can perform image-text chat and image editing. When actually calling, use the `configured_model-image` model name to use this feature.
|
||||
* **Dual Protocol API Compatibility**: Supports both Gemini and OpenAI CHAT API formats.
|
||||
* OpenAI Base URL: `http://localhost:8000(/hf)/v1`
|
||||
* Gemini Base URL: `http://localhost:8000(/gemini)/v1beta`
|
||||
* **Image-Text Chat & Modification**: Configure models with `IMAGE_MODELS` to support image-text chat and editing. Use the `configured_model-image` model name to invoke.
|
||||

|
||||

|
||||
* **Supports Web Search**: Supports web search. `SEARCH_MODELS` configures which models can perform web searches. When actually calling, use the `configured_model-search` model name to use this feature.
|
||||
* **Web Search**: Configure models with `SEARCH_MODELS` to support web search. Use the `configured_model-search` model name to invoke.
|
||||

|
||||
* **Key Status Monitoring**: Provides a `/keys_status` page (requires authentication) to view the status and usage of each Key in real-time.
|
||||
* **Key Status Monitoring**: Provides a `/keys_status` page (authentication required) for real-time monitoring.
|
||||

|
||||
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
|
||||
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
|
||||

|
||||

|
||||

|
||||
* **Support for Custom Gemini Proxy**: Supports custom Gemini proxies, such as those built on Deno or Cloudflare.
|
||||
* **OpenAI Image Generation API Compatibility**: Adapts the `imagen-3.0-generate-002` model interface to be compatible with the OpenAI image generation API, supporting client calls.
|
||||
* **Flexible Key Addition**: Flexible way to add keys using regex matching for `gemini_key`, with key deduplication.
|
||||
* **Flexible Key Addition**: Add keys in batches using the `gemini_key` regex, with automatic deduplication.
|
||||

|
||||
* **OpenAI Format Embeddings API Compatibility**: Perfectly adapts to the OpenAI format `embeddings` interface, usable for local document vectorization.
|
||||
* **Streamlined Response Optimization**: Optional stream output optimizer (`STREAM_OPTIMIZER_ENABLED`) to improve the experience of long-text stream responses.
|
||||
* **Failure Retry and Key Management**: Automatically handles API request failures, retries (`MAX_RETRIES`), automatically disables Keys after too many failures (`MAX_FAILURES`), and periodically checks for recovery (`CHECK_INTERVAL_HOURS`).
|
||||
* **Docker Support**: Supports AMD and ARM architecture Docker deployments. You can also build your own Docker image.
|
||||
> Image address: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **Automatic Model List Maintenance**: Supports fetching OpenAI and Gemini model lists, perfectly compatible with NewAPI's automatic model list fetching, no manual entry required.
|
||||
* **Support for Removing Unused Models**: Too many default models are provided, many of which are not used. You can filter them out using `FILTERED_MODELS`.
|
||||
* **Proxy Support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, convenient for use in special network environments. Supports batch adding proxies.
|
||||
* **Failure Retry & Auto-Disable**: Automatically retries failed API requests (`MAX_RETRIES`) and disables keys after excessive failures (`MAX_FAILURES`).
|
||||
* **Comprehensive API Compatibility**:
|
||||
* **Embeddings API**: Fully compatible with the OpenAI `embeddings` API format.
|
||||
* **Image Generation API**: Adapts the `imagen-3.0-generate-002` model to the OpenAI image generation API format.
|
||||
* **Automatic Model List Maintenance**: Automatically fetches and syncs the latest model lists from Gemini and OpenAI.
|
||||
* **Proxy Support**: Supports HTTP/SOCKS5 proxies (`PROXIES`).
|
||||
* **Docker Support**: Provides Docker images for both AMD and ARM architectures.
|
||||
* Image Address: `ghcr.io/snailyp/gemini-balance:latest`
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Build Docker Yourself (Recommended)
|
||||
|
||||
#### a) Build with Dockerfile
|
||||
|
||||
1. **Build Image**:
|
||||
### Option 1: Docker Compose (Recommended)
|
||||
|
||||
1. **Get `docker-compose.yml`**:
|
||||
Download the `docker-compose.yml` file from the project repository.
|
||||
2. **Prepare `.env` file**:
|
||||
Copy `.env.example` to `.env` and configure it. Ensure `DATABASE_TYPE` is set to `mysql` and fill in the `MYSQL_*` details.
|
||||
3. **Start Services**:
|
||||
In the directory containing `docker-compose.yml` and `.env`, run:
|
||||
```bash
|
||||
docker build -t gemini-balance .
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
2. **Run Container**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env gemini-balance
|
||||
```
|
||||
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host.
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables.
|
||||
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist
|
||||
>
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
|
||||
> ```
|
||||
>
|
||||
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
|
||||
|
||||
#### b) Deploy with an Existing Docker Image
|
||||
|
||||
1. **Pull Image**:
|
||||
### Option 2: Docker Command
|
||||
|
||||
1. **Pull Image**:
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
2. **Run Container**:
|
||||
|
||||
2. **Prepare `.env` file**:
|
||||
Copy `.env.example` to `.env` and configure it.
|
||||
3. **Run Container**:
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
docker run -d -p 8000:8000 --name gemini-balance \
|
||||
-v ./data:/app/data \
|
||||
--env-file .env \
|
||||
ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
* `-d`: Detached mode.
|
||||
* `-p 8000:8000`: Map container port 8000 to host.
|
||||
* `-v ./data:/app/data`: Mount volume for persistent data.
|
||||
* `--env-file .env`: Load environment variables.
|
||||
|
||||
* `-d`: Run in detached mode.
|
||||
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host (adjust as needed).
|
||||
* `--env-file .env`: Use the `.env` file to set environment variables (ensure the `.env` file exists in the directory where the command is executed).
|
||||
|
||||
> Note: If using an SQLite database, you need to mount a data volume to persist
|
||||
>
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
|
||||
> ```
|
||||
>
|
||||
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
|
||||
|
||||
### Run Locally (Suitable for Development and Testing)
|
||||
|
||||
If you want to run the source code directly locally for development or testing, follow these steps:
|
||||
|
||||
1. **Ensure Prerequisites are Met**:
|
||||
* Clone the repository locally.
|
||||
* Install Python 3.9 or higher.
|
||||
* Create and configure the `.env` file in the project root directory (refer to the "Configure Environment Variables" section above).
|
||||
* Install project dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **Start Application**:
|
||||
Run the following command in the project root directory:
|
||||
### Option 3: Local Development
|
||||
|
||||
1. **Clone and Install**:
|
||||
```bash
|
||||
git clone https://github.com/snailyp/gemini-balance.git
|
||||
cd gemini-balance
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
2. **Configure Environment**:
|
||||
Copy `.env.example` to `.env` and configure it.
|
||||
3. **Start Application**:
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
Access the application at `http://localhost:8000`.
|
||||
|
||||
* `app.main:app`: Specifies the location of the FastAPI application instance (the `app` object in the `main.py` file within the `app` module).
|
||||
* `--host 0.0.0.0`: Makes the application accessible from any IP address on the local network.
|
||||
* `--port 8000`: Specifies the port number the application listens on (you can change this as needed).
|
||||
* `--reload`: Enables automatic reloading. When you modify the code, the service will automatically restart, which is very suitable for development environments (remove this option in production environments).
|
||||
|
||||
3. **Access Application**:
|
||||
After the application starts, you can access `http://localhost:8000` (or the host and port you specified) through a browser or API tool.
|
||||
|
||||
### Complete Configuration List
|
||||
|
||||
| Configuration Item | Description | Default Value |
|
||||
| :----------------------------- | :-------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Database Configuration** | | |
|
||||
| `DATABASE_TYPE` | Optional, database type, supports `mysql` or `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | Optional, required when using `sqlite`, SQLite database file path | `default_db` |
|
||||
| `MYSQL_HOST` | Required when using `mysql`, MySQL database host address | `localhost` |
|
||||
| `MYSQL_SOCKET` | Optional, MySQL database socket address | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | Required when using `mysql`, MySQL database port | `3306` |
|
||||
| `MYSQL_USER` | Required when using `mysql`, MySQL database username | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | Required when using `mysql`, MySQL database password | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | Required when using `mysql`, MySQL database name | `defaultdb` |
|
||||
| **API Related Configuration** | | |
|
||||
| `API_KEYS` | Required, list of Gemini API keys for load balancing | `["your-gemini-api-key-1", "your-gemini-api-key-2"]` |
|
||||
| `ALLOWED_TOKENS` | Required, list of tokens allowed to access | `["your-access-token-1", "your-access-token-2"]` |
|
||||
| `AUTH_TOKEN` | Optional, super admin token with all permissions, defaults to the first of `ALLOWED_TOKENS` if not set | `sk-123456` |
|
||||
| `TEST_MODEL` | Optional, model name used to test if a key is usable | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | Optional, list of models that support drawing functions | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | Optional, list of models that support search functions | `["gemini-2.0-flash-exp"]` |
|
||||
| `FILTERED_MODELS` | Optional, list of disabled models | `["gemini-1.0-pro-vision-latest", ...]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | Optional, whether to enable the code execution tool | `false` |
|
||||
| `SHOW_SEARCH_LINK` | Optional, whether to display search result links in the response | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | Optional, whether to display the model's thinking process | `true` |
|
||||
| `THINKING_MODELS` | Optional, list of models that support thinking functions | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | Optional, thinking function budget mapping (model_name:budget_value) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | Optional, whether to enable intelligent URL routing mapping | `false` |
|
||||
| `BASE_URL` | Optional, Gemini API base URL, no modification needed by default | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | Optional, number of times a single key is allowed to fail | `3` |
|
||||
| `MAX_RETRIES` | Optional, maximum number of retries for failed API requests | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | Optional, time interval (hours) to check if a disabled Key has recovered | `1` |
|
||||
| `TIMEZONE` | Optional, timezone used by the application | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | Optional, request timeout (seconds) | `300` |
|
||||
| `PROXIES` | Optional, list of proxy servers (e.g., `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
|
||||
| `LOG_LEVEL` | Optional, log level, e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | Optional, whether to enable automatic deletion of error logs | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | Optional, automatically delete error logs older than this many days (e.g., 1, 7, 30) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Optional, whether to enable automatic deletion of request logs | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | Optional, automatically delete request logs older than this many days (e.g., 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | Optional, safety settings (JSON string format), used to configure content safety thresholds. Example values may need adjustment based on actual model support. | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **TTS Related** | | |
|
||||
| `TTS_MODEL` | Optional, TTS model name | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | Optional, TTS voice name | `Zephyr` |
|
||||
| `TTS_SPEED` | Optional, TTS speed | `normal` |
|
||||
| **Image Generation Related** | | |
|
||||
| `PAID_KEY` | Optional, paid API Key for advanced features like image generation | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | Optional, image generation model | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | Optional, image upload provider: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `SMMS_SECRET_TOKEN` | Optional, API Token for SM.MS image hosting | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | Optional, API Key for [PicoGo](https://www.picgo.net/) image hosting | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | Optional, [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) image hosting upload address | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE` | Optional, authentication key for CloudFlare image hosting | `your-cloudflare-imgber-auth-code` |
|
||||
| **Stream Optimizer Related** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | Optional, whether to enable stream output optimization | `false` |
|
||||
| `STREAM_MIN_DELAY` | Optional, minimum delay for stream output | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | Optional, maximum delay for stream output | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD` | Optional, short text threshold | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | Optional, long text threshold | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | Optional, stream output chunk size | `5` |
|
||||
| **Fake Stream Related** | | |
|
||||
| `FAKE_STREAM_ENABLED` | Optional, whether to enable fake streaming for models or scenarios that don't support streaming | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | Optional, interval in seconds for sending heartbeat empty data during fake streaming | `5` |
|
||||
---
|
||||
|
||||
## ⚙️ API Endpoints
|
||||
|
||||
The following are the main API endpoints provided by the service:
|
||||
### Gemini API Format (`/gemini/v1beta`)
|
||||
|
||||
### Gemini API Related (`(/gemini)/v1beta`)
|
||||
This endpoint is directly forwarded to official Gemini API format endpoint, without advanced features.
|
||||
|
||||
* `GET /models`: List available Gemini models.
|
||||
* `POST /models/{model_name}:generateContent`: Generate content using the specified Gemini model.
|
||||
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation using the specified Gemini model.
|
||||
* `GET /models`: List available Gemini models.
|
||||
* `POST /models/{model_name}:generateContent`: Generate content.
|
||||
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation.
|
||||
|
||||
### OpenAI API Related
|
||||
### OpenAI API Format
|
||||
|
||||
* `GET (/hf)/v1/models`: List available models (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/chat/completions`: Perform chat completion (uses Gemini format underneath, supports streaming).
|
||||
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses Gemini format underneath).
|
||||
* `POST (/hf)/v1/images/generations`: Generate images (uses Gemini format underneath).
|
||||
* `GET /openai/v1/models`: List available models (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/chat/completions`: Perform chat completion (uses OpenAI format underneath, supports streaming, can prevent truncation, and is faster).
|
||||
* `POST /openai/v1/embeddings`: Create text embeddings (uses OpenAI format underneath).
|
||||
* `POST /openai/v1/images/generations`: Generate images (uses OpenAI format underneath).
|
||||
#### Hugging Face (HF) Compatible
|
||||
|
||||
If you want to use advanced features, like fake streaming, please use this endpoint.
|
||||
|
||||
* `GET /hf/v1/models`: List models.
|
||||
* `POST /hf/v1/chat/completions`: Chat completion.
|
||||
* `POST /hf/v1/embeddings`: Create text embeddings.
|
||||
* `POST /hf/v1/images/generations`: Generate images.
|
||||
|
||||
#### Standard OpenAI
|
||||
|
||||
This endpoint is directly forwarded to official OpenAI Compatible API format endpoint, without advanced features.
|
||||
|
||||
* `GET /openai/v1/models`: List models.
|
||||
* `POST /openai/v1/chat/completions`: Chat completion (Recommended).
|
||||
* `POST /openai/v1/embeddings`: Create text embeddings.
|
||||
* `POST /openai/v1/images/generations`: Generate images.
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>📋 View Full Configuration List</summary>
|
||||
|
||||
| Configuration Item | Description | Default Value |
|
||||
| :--- | :--- | :--- |
|
||||
| **Database** | | |
|
||||
| `DATABASE_TYPE` | `mysql` or `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | Path for SQLite database file | `default_db` |
|
||||
| `MYSQL_HOST` | MySQL host address | `localhost` |
|
||||
| `MYSQL_SOCKET` | MySQL socket address | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | MySQL port | `3306` |
|
||||
| `MYSQL_USER` | MySQL username | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | MySQL password | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | MySQL database name | `defaultdb` |
|
||||
| **API** | | |
|
||||
| `API_KEYS` | **Required**, list of Gemini API keys | `[]` |
|
||||
| `ALLOWED_TOKENS` | **Required**, list of access tokens | `[]` |
|
||||
| `AUTH_TOKEN` | Super admin token, defaults to the first of `ALLOWED_TOKENS` | `sk-123456` |
|
||||
| `ADMIN_SESSION_EXPIRE` | Admin session expiration time in seconds (5 minutes to 24 hours) | `3600` |
|
||||
| `TEST_MODEL` | Model for testing key validity | `gemini-2.5-flash-lite` |
|
||||
| `IMAGE_MODELS` | Models supporting image generation | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
|
||||
| `SEARCH_MODELS` | Models supporting web search | `["gemini-2.5-flash","gemini-2.5-pro"]` |
|
||||
| `FILTERED_MODELS` | Disabled models | `[]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | Enable code execution tool | `false` |
|
||||
| `SHOW_SEARCH_LINK` | Display search result links in response | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | Display model's thinking process | `true` |
|
||||
| `THINKING_MODELS` | Models supporting thinking process | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | Budget map for thinking function (model:budget) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | Enable smart URL routing | `false` |
|
||||
| `URL_CONTEXT_ENABLED` | Enable URL context understanding | `false` |
|
||||
| `URL_CONTEXT_MODELS` | Models supporting URL context | `[]` |
|
||||
| `BASE_URL` | Gemini API base URL | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | Max failures allowed per key | `3` |
|
||||
| `MAX_RETRIES` | Max retries for failed API requests | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | Interval (hours) to re-check disabled keys | `1` |
|
||||
| `TIMEZONE` | Application timezone | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | Request timeout (seconds) | `300` |
|
||||
| `PROXIES` | List of proxy servers | `[]` |
|
||||
| **Logging & Security** | | |
|
||||
| `LOG_LEVEL` | Log level: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
|
||||
| `ERROR_LOG_RECORD_REQUEST_BODY` | Record request body in error logs (may contain sensitive information) | `false` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | Auto-delete error logs | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | Error log retention period (days) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Auto-delete request logs | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | Request log retention period (days) | `30` |
|
||||
| `SAFETY_SETTINGS` | Content safety thresholds (JSON string) | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, ...]` |
|
||||
| **TTS** | | |
|
||||
| `TTS_MODEL` | TTS model name | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | TTS voice name | `Zephyr` |
|
||||
| `TTS_SPEED` | TTS speed | `normal` |
|
||||
| **Image Generation** | | |
|
||||
| `PAID_KEY` | Paid API Key for advanced features | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | Image generation model | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | Image upload provider: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
|
||||
| `OSS_ENDPOINT` | Aliyun OSS public endpoint | `oss-cn-shanghai.aliyuncs.com` |
|
||||
| `OSS_ENDPOINT_INNER` | Aliyun OSS internal endpoint (intra-VPC) | `oss-cn-shanghai-internal.aliyuncs.com` |
|
||||
| `OSS_ACCESS_KEY` | Aliyun AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
|
||||
| `OSS_ACCESS_KEY_SECRET` | Aliyun AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
|
||||
| `OSS_BUCKET_NAME` | Aliyun OSS bucket name | `your-bucket-name` |
|
||||
| `OSS_REGION` | Aliyun OSS region | `cn-shanghai` |
|
||||
| `SMMS_SECRET_TOKEN` | SM.MS API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | PicoGo API Key | `your-picogo-apikey` |
|
||||
| `PICGO_API_URL` | PicoGo API Server URL | `https://www.picgo.net/api/1/upload` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | CloudFlare ImgBed upload URL | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare ImgBed auth key | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare ImgBed upload folder | `""` |
|
||||
| **Stream Optimizer** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | Enable stream output optimization | `false` |
|
||||
| `STREAM_MIN_DELAY` | Minimum stream output delay | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | Maximum stream output delay | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD`| Short text threshold | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | Long text threshold | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | Stream output chunk size | `5` |
|
||||
| **Fake Stream** | | |
|
||||
| `FAKE_STREAM_ENABLED` | Enable fake streaming | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | Heartbeat interval for fake streaming (seconds) | `5` |
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Pull Requests or Issues are welcome.
|
||||
|
||||
## 🎉 Special Thanks
|
||||
|
||||
Special thanks to the following projects and platforms for providing image hosting services for this project:
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) open source project
|
||||
|
||||
## 🙏 Thanks to Contributors
|
||||
|
||||
Thanks to all developers who contributed to this project!
|
||||
|
||||
[](https://github.com/snailyp/gemini-balance/graphs/contributors)
|
||||
|
||||
## Thanks to Our Supporters
|
||||
|
||||
A special shout-out to DigitalOcean for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
|
||||
[](https://m.do.co/c/b249dd7f3b4c)
|
||||
|
||||
CDN acceleration and security protection for this project are sponsored by Tencent EdgeOne.
|
||||
[](https://edgeone.ai/?from=github)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
[](https://star-history.com/#snailyp/gemini-balance&Date)
|
||||
|
||||
## 🎉 Special Thanks
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed)
|
||||
|
||||
## 🙏 Our Supporters
|
||||
|
||||
A special shout-out to [DigitalOcean](https://m.do.co/c/b249dd7f3b4c) for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
|
||||
|
||||
<a href="https://m.do.co/c/b249dd7f3b4c">
|
||||
<img src="files/dataocean.svg" alt="DigitalOcean Logo" width="200"/>
|
||||
</a>
|
||||
|
||||
CDN acceleration and security protection for this project are sponsored by [Tencent EdgeOne](https://edgeone.ai/?from=github).
|
||||
|
||||
<a href="https://edgeone.ai/?from=github">
|
||||
<img src="https://edgeone.ai/media/34fe3a45-492d-4ea4-ae5d-ea1087ca7b4b.png" alt="EdgeOne Logo" width="200"/>
|
||||
</a>
|
||||
|
||||
## 💖 Friendly Projects
|
||||
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine: AI-driven hot event timeline generation tool
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - AI-driven hot event timeline generation tool.
|
||||
|
||||
## 🎁 Project Support
|
||||
|
||||
@@ -283,4 +286,4 @@ If you find this project helpful, consider supporting me via [Afdian](https://af
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the CC BY-NC 4.0 (Attribution-NonCommercial) license. Any form of commercial resale service is prohibited. See the LICENSE file for details.
|
||||
This project is licensed under the [CC BY-NC 4.0](LICENSE) (Attribution-NonCommercial) license.
|
||||
|
||||
421
README_ZH.md
421
README_ZH.md
@@ -1,29 +1,38 @@
|
||||
# Gemini Balance - Gemini API 代理和负载均衡器
|
||||
|
||||
> ⚠️ 本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/13692" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> 本人从未在各个平台售卖服务,如有遇到售卖此服务者,那一定是倒卖狗,大家切记不要上当受骗。
|
||||
<p align="center">
|
||||
<a href="https://www.python.org/"><img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python"></a>
|
||||
<a href="https://fastapi.tiangolo.com/"><img src="https://img.shields.io/badge/FastAPI-0.100%2B-green.svg" alt="FastAPI"></a>
|
||||
<a href="https://www.uvicorn.org/"><img src="https://img.shields.io/badge/Uvicorn-running-purple.svg" alt="Uvicorn"></a>
|
||||
<a href="https://t.me/+soaHax5lyI0wZDVl"><img src="https://img.shields.io/badge/Telegram-Group-blue.svg?logo=telegram" alt="Telegram Group"></a>
|
||||
</p>
|
||||
|
||||
[](https://www.python.org/)
|
||||
[](https://fastapi.tiangolo.com/)
|
||||
[](https://www.uvicorn.org/)
|
||||
[](https://t.me/+soaHax5lyI0wZDVl)
|
||||
> 交流群:https://t.me/+soaHax5lyI0wZDVl
|
||||
> ⚠️ **重要声明**: 本项目采用 [CC BY-NC 4.0](LICENSE) 协议,**禁止任何形式的商业倒卖服务**。
|
||||
> 本人从未在任何平台售卖服务,如遇售卖,均为倒卖行为,请勿上当受骗。
|
||||
|
||||
## 项目简介
|
||||
---
|
||||
|
||||
Gemini Balance 是一个基于 Python FastAPI 构建的应用程序,旨在提供 Google Gemini API 的代理和负载均衡功能。它允许您管理多个 Gemini API Key,并通过简单的配置实现 Key 的轮询、认证、模型过滤和状态监控。此外,项目还集成了图像生成和多种图床上传功能,并支持 OpenAI API 格式的代理。
|
||||
## 📖 项目简介
|
||||
|
||||
**项目结构:**
|
||||
**Gemini Balance** 是一个基于 Python FastAPI 构建的应用程序,旨在提供 Google Gemini API 的代理和负载均衡功能。它允许您管理多个 Gemini API Key,并通过简单的配置实现 Key 的轮询、认证、模型过滤和状态监控。此外,项目还集成了图像生成和多种图床上传功能,并支持 OpenAI API 格式的代理。
|
||||
|
||||
<details>
|
||||
<summary>📂 查看项目结构</summary>
|
||||
|
||||
```plaintext
|
||||
app/
|
||||
├── config/ # 配置管理
|
||||
├── core/ # 核心应用逻辑 (FastAPI 实例创建, 中间件等)
|
||||
├── database/ # 数据库模型和连接
|
||||
├── domain/ # 业务领域对象 (可选)
|
||||
├── domain/ # 业务领域对象
|
||||
├── exception/ # 自定义异常
|
||||
├── handler/ # 请求处理器 (可选, 或在 router 中处理)
|
||||
├── handler/ # 请求处理器
|
||||
├── log/ # 日志配置
|
||||
├── main.py # 应用入口
|
||||
├── middleware/ # FastAPI 中间件
|
||||
@@ -32,225 +41,214 @@ app/
|
||||
├── service/ # 业务逻辑服务 (聊天, Key 管理, 统计等)
|
||||
├── static/ # 静态文件 (CSS, JS)
|
||||
├── templates/ # HTML 模板 (如 Key 状态页)
|
||||
├── utils/ # 工具函数
|
||||
└── utils/ # 工具函数
|
||||
```
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## ✨ 功能亮点
|
||||
|
||||
* **多 Key 负载均衡**: 支持配置多个 Gemini API Key (`API_KEYS`),自动按顺序轮询使用,提高可用性和并发能力。
|
||||
* **可视化配置即时生效**: 通过管理后台修改配置后,无需重启服务即可生效,切记要点击保存才会生效。
|
||||

|
||||
* **双协议API 兼容**: 同时支持 Gemini 和 OpenAI 格式的 CHAT API 请求转发。
|
||||
* **多 Key 负载均衡**: 支持配置多个 Gemini API Key (`API_KEYS`),自动按顺序轮询使用,提高可用性和并发能力。
|
||||
* **可视化配置即时生效**: 通过管理后台修改配置后,无需重启服务即可生效。
|
||||

|
||||
* **双协议 API 兼容**: 同时支持 Gemini 和 OpenAI 格式的 CHAT API 请求转发。
|
||||
* OpenAI Base URL: `http://localhost:8000(/hf)/v1`
|
||||
* Gemini Base URL: `http://localhost:8000(/gemini)/v1beta`
|
||||
* **图文对话与修图**: 通过 `IMAGE_MODELS` 配置支持图文对话和修图功能的模型,调用时使用 `配置模型-image` 模型名。
|
||||

|
||||

|
||||
* **联网搜索**: 通过 `SEARCH_MODELS` 配置支持联网搜索的模型,调用时使用 `配置模型-search` 模型名。
|
||||

|
||||
* **Key 状态监控**: 提供 `/keys_status` 页面(需要认证),实时查看各 Key 的状态和使用情况。
|
||||

|
||||
* **详细日志记录**: 提供详细的错误日志,方便排查问题。
|
||||

|
||||

|
||||

|
||||
* **灵活的密钥添加**: 支持通过正则表达式 `gemini_key` 批量添加密钥,并自动去重。
|
||||

|
||||
* **失败重试与自动禁用**: 自动处理 API 请求失败,进行重试 (`MAX_RETRIES`),并在 Key 失效次数过多时自动禁用 (`MAX_FAILURES`),定时检查恢复 (`CHECK_INTERVAL_HOURS`)。
|
||||
* **全面的 API 兼容**:
|
||||
* **Embeddings 接口**: 完美适配 OpenAI 格式的 `embeddings` 接口。
|
||||
* **画图接口**: 将 `imagen-3.0-generate-002` 模型接口改造为 OpenAI 画图接口格式。
|
||||
* **模型列表自动维护**: 自动获取并同步 Gemini 和 OpenAI 的最新模型列表,兼容 New API。
|
||||
* **代理支持**: 支持配置 HTTP/SOCKS5 代理 (`PROXIES`),方便在特殊网络环境下使用。
|
||||
* **Docker 支持**: 提供 AMD 和 ARM 架构的 Docker 镜像,方便快速部署。
|
||||
* 镜像地址: `ghcr.io/snailyp/gemini-balance:latest`
|
||||
|
||||
```palintext
|
||||
openai baseurl `http://localhost:8000(/hf)/v1`
|
||||
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
|
||||
```
|
||||
|
||||
* **支持图文对话和修改图片**: `IMAGE_MODELS`配置哪个模型可以图文对话和修图的功能,实际调用的时候,用 `配置模型-image`这个模型名对话使用该功能。
|
||||

|
||||

|
||||
* **支持联网搜索**: 支持联网搜索,`SEARCH_MODELS`配置哪些模型可以联网搜索,实际调用的时候,用 `配置模型-search`这个模型名对话使用该功能
|
||||

|
||||
* **Key 状态监控**: 提供 `/keys_status` 页面(需要认证),实时查看各 Key 的状态和使用情况。
|
||||

|
||||
* **详细的日志记录**: 提供详细的错误日志,方便排查。
|
||||

|
||||

|
||||

|
||||
* **支持自定义gemini代理**: 支持自定义gemini代理,比如自行在deno或者cloudflare上搭建gemini代理
|
||||
* **openai画图接口兼容**: 将`imagen-3.0-generate-002`模型接口改造成openai画图接口,支持客户端调用。
|
||||
* **灵活的添加密钥方式**: 灵活的添加密钥方式,采用正则匹配`gemini_key`,密钥去重
|
||||

|
||||
* **兼容openai格式embeddings接口**:完美适配openai格式的`embeddings`接口,可用于本地文档向量化。
|
||||
* **流式响应优化**: 可选的流式输出优化器 (`STREAM_OPTIMIZER_ENABLED`),改善长文本流式响应的体验。
|
||||
* **失败重试与 Key 管理**: 自动处理 API 请求失败,进行重试 (`MAX_RETRIES`),并在 Key 失效次数过多时自动禁用 (`MAX_FAILURES`),定时检查恢复 (`CHECK_INTERVAL_HOURS`)。
|
||||
* **Docker 支持**: 支持AMD,ARM架构的docker部署,也可自行构建docker镜像。
|
||||
>镜像地址: docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
* **模型列表自动维护**: 支持openai和gemini模型列表获取,与newapi自动获取模型列表完美兼容,无需手动填写。
|
||||
* **支持移除不使用的模型**: 默认提供的模型太多,很多用不上,可以通过`FILTERED_MODELS`过滤掉。
|
||||
* **代理支持**: 支持配置 HTTP/SOCKS5 代理服务器 (`PROXIES`),用于访问 Gemini API,方便在特殊网络环境下使用。支持批量添加代理。
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 自行构建 Docker (推荐)
|
||||
### 方式一:使用 Docker Compose (推荐)
|
||||
|
||||
#### a) dockerfile构建
|
||||
|
||||
1. **构建镜像**:
|
||||
这是最推荐的部署方式,可以一键启动应用和数据库。
|
||||
|
||||
1. **下载 `docker-compose.yml`**:
|
||||
从项目仓库获取 `docker-compose.yml` 文件。
|
||||
2. **准备 `.env` 文件**:
|
||||
从 `.env.example` 复制一份并重命名为 `.env`,然后根据需求修改配置。特别注意,`DATABASE_TYPE` 应设置为 `mysql`,并填写 `MYSQL_*` 相关配置。
|
||||
3. **启动服务**:
|
||||
在 `docker-compose.yml` 和 `.env` 文件所在的目录下,运行以下命令:
|
||||
```bash
|
||||
docker build -t gemini-balance .
|
||||
docker-compose up -d
|
||||
```
|
||||
该命令会以后台模式启动 `gemini-balance` 应用和 `mysql` 数据库。
|
||||
|
||||
2. **运行容器**:
|
||||
### 方式二:使用 Docker 命令
|
||||
|
||||
1. **拉取镜像**:
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env gemini-balance
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
2. **准备 `.env` 文件**:
|
||||
从 `.env.example` 复制一份并重命名为 `.env`,然后根据需求修改配置。
|
||||
3. **运行容器**:
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --name gemini-balance \
|
||||
-v ./data:/app/data \
|
||||
--env-file .env \
|
||||
ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机。
|
||||
* `-v ./data:/app/data`: 挂载数据卷以持久化 SQLite 数据和日志。
|
||||
* `--env-file .env`: 加载环境变量配置文件。
|
||||
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量。
|
||||
|
||||
> 注意:如果使用 SQLite 数据库,需要挂载数据卷以持久化数据:
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
|
||||
> ```
|
||||
> 其中 `/path/to/data` 是主机上的数据存储路径,`/app/data` 是容器内的数据目录。
|
||||
|
||||
#### b) 用现有的docker镜像部署
|
||||
|
||||
1. **拉取镜像**:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
2. **运行容器**:
|
||||
|
||||
```bash
|
||||
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
|
||||
```
|
||||
|
||||
* `-d`: 后台运行。
|
||||
* `-p 8000:8000`: 将容器的 8000 端口映射到主机的 8000 端口 (根据需要调整)。
|
||||
* `--env-file .env`: 使用 `.env` 文件设置环境变量 (确保 `.env` 文件存在于执行命令的目录)。
|
||||
|
||||
> 注意:如果使用 SQLite 数据库,需要挂载数据卷以持久化数据:
|
||||
> ```bash
|
||||
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
|
||||
> ```
|
||||
> 其中 `/path/to/data` 是主机上的数据存储路径,`/app/data` 是容器内的数据目录。
|
||||
|
||||
### 本地运行 (适用于开发和测试)
|
||||
|
||||
如果您想在本地直接运行源代码进行开发或测试,请按照以下步骤操作:
|
||||
|
||||
1. **确保已完成准备工作**:
|
||||
* 克隆仓库到本地。
|
||||
* 安装 Python 3.9 或更高版本。
|
||||
* 在项目根目录下创建并配置好 `.env` 文件 (参考前面的"配置环境变量"部分)。
|
||||
* 安装项目依赖:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **启动应用**:
|
||||
在项目根目录下运行以下命令:
|
||||
### 方式三:本地运行 (适用于开发)
|
||||
|
||||
1. **克隆仓库并安装依赖**:
|
||||
```bash
|
||||
git clone https://github.com/snailyp/gemini-balance.git
|
||||
cd gemini-balance
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
2. **配置环境变量**:
|
||||
从 `.env.example` 复制一份并重命名为 `.env`,然后根据需求修改配置。
|
||||
3. **启动应用**:
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
应用启动后,访问 `http://localhost:8000`。
|
||||
|
||||
* `app.main:app`: 指定 FastAPI 应用实例的位置 (`app` 模块中的 `main.py` 文件里的 `app` 对象)。
|
||||
* `--host 0.0.0.0`: 使应用可以从本地网络中的任何 IP 地址访问。
|
||||
* `--port 8000`: 指定应用监听的端口号 (您可以根据需要修改)。
|
||||
* `--reload`: 启用自动重载功能。当您修改代码时,服务会自动重启,非常适合开发环境 (生产环境请移除此选项)。
|
||||
|
||||
3. **访问应用**:
|
||||
应用启动后,您可以通过浏览器或 API 工具访问 `http://localhost:8000` (或您指定的主机和端口)。
|
||||
|
||||
### 完整配置项列表
|
||||
|
||||
| 配置项 | 说明 | 默认值 |
|
||||
| :--------------------------- | :------------------------------------------------------- | :---------------------------------------------------- |
|
||||
| **数据库配置** | | |
|
||||
| `DATABASE_TYPE` | 可选,数据库类型,支持 `mysql` 或 `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | 可选,当使用 `sqlite` 时必填,SQLite 数据库文件路径 | `default_db` |
|
||||
| `MYSQL_HOST` | 当使用 `mysql` 时必填,MySQL 数据库主机地址 | `localhost` |
|
||||
| `MYSQL_SOCKET` | 可选,MySQL 数据库 socket 地址 | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | 当使用 `mysql` 时必填,MySQL 数据库端口 | `3306` |
|
||||
| `MYSQL_USER` | 当使用 `mysql` 时必填,MySQL 数据库用户名 | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | 当使用 `mysql` 时必填,MySQL 数据库密码 | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | 当使用 `mysql` 时必填,MySQL 数据库名称 | `defaultdb` |
|
||||
| **API 相关配置** | | |
|
||||
| `API_KEYS` | 必填,Gemini API 密钥列表,用于负载均衡 | `["your-gemini-api-key-1", "your-gemini-api-key-2"]` |
|
||||
| `ALLOWED_TOKENS` | 必填,允许访问的 Token 列表 | `["your-access-token-1", "your-access-token-2"]` |
|
||||
| `AUTH_TOKEN` | 可选,超级管理员token,具有所有权限,不填默认使用 ALLOWED_TOKENS 的第一个 | `sk-123456` |
|
||||
| `TEST_MODEL` | 可选,用于测试密钥是否可用的模型名 | `gemini-1.5-flash` |
|
||||
| `IMAGE_MODELS` | 可选,支持绘图功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `SEARCH_MODELS` | 可选,支持搜索功能的模型列表 | `["gemini-2.0-flash-exp"]` |
|
||||
| `FILTERED_MODELS` | 可选,被禁用的模型列表 | `["gemini-1.0-pro-vision-latest", ...]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | 可选,是否启用代码执行工具 | `false` |
|
||||
| `SHOW_SEARCH_LINK` | 可选,是否在响应中显示搜索结果链接 | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | 可选,是否显示模型思考过程 | `true` |
|
||||
| `THINKING_MODELS` | 可选,支持思考功能的模型列表 | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | 可选,思考功能预算映射 (模型名:预算值) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | 可选,是否启用智能路由映射功能 | `false` |
|
||||
| `BASE_URL` | 可选,Gemini API 基础 URL,默认无需修改 | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | 可选,允许单个key失败的次数 | `3` |
|
||||
| `MAX_RETRIES` | 可选,API 请求失败时的最大重试次数 | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | 可选,检查禁用 Key 是否恢复的时间间隔 (小时) | `1` |
|
||||
| `TIMEZONE` | 可选,应用程序使用的时区 | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | 可选,请求超时时间 (秒) | `300` |
|
||||
| `PROXIES` | 可选,代理服务器列表 (例如 `http://user:pass@host:port`, `socks5://host:port`) | `[]` |
|
||||
| `LOG_LEVEL` | 可选,日志级别,例如 DEBUG, INFO, WARNING, ERROR, CRITICAL | `INFO` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 可选,是否开启自动删除错误日志 | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 可选,自动删除多少天前的错误日志 (例如 1, 7, 30) | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 可选,是否开启自动删除请求日志 | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | 可选,自动删除多少天前的请求日志 (例如 1, 7, 30) | `30` |
|
||||
| `SAFETY_SETTINGS` | 可选,安全设置 (JSON 字符串格式),用于配置内容安全阈值。示例值可能需要根据实际模型支持情况调整。 | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
|
||||
| **TTS 相关** | | |
|
||||
| `TTS_MODEL` | 可选,TTS 模型名称 | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | 可选,TTS 语音名称 | `Zephyr` |
|
||||
| `TTS_SPEED` | 可选,TTS 语速 | `normal` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 可选,付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 可选,图片生成模型 | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | 可选,图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed` | `smms` |
|
||||
| `SMMS_SECRET_TOKEN` | 可选,SM.MS图床的API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | 可选,[PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | 可选,[CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| 可选,CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| **流式优化器相关** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | 可选,是否启用流式输出优化 | `false` |
|
||||
| `STREAM_MIN_DELAY` | 可选,流式输出最小延迟 | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | 可选,流式输出最大延迟 | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD`| 可选,短文本阈值 | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | 可选,长文本阈值 | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | 可选,流式输出块大小 | `5` |
|
||||
| **伪流式 (Fake Stream) 相关** | | |
|
||||
| `FAKE_STREAM_ENABLED` | 可选,是否启用伪流式传输,用于不支持流式的模型或场景 | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | 可选,伪流式传输时发送心跳空数据的间隔秒数 | `5` |
|
||||
---
|
||||
|
||||
## ⚙️ API 端点
|
||||
|
||||
以下是服务提供的主要 API 端点:
|
||||
### Gemini API 格式 (`/gemini/v1beta`)
|
||||
|
||||
### Gemini API 相关 (`(/gemini)/v1beta`)
|
||||
此端点将请求直接转发到官方 Gemini API 格式的端点,不包含高级功能。
|
||||
|
||||
* `GET /models`: 列出可用的 Gemini 模型。
|
||||
* `POST /models/{model_name}:generateContent`: 使用指定的 Gemini 模型生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 使用指定的 Gemini 模型流式生成内容。
|
||||
* `GET /models`: 列出可用的 Gemini 模型。
|
||||
* `POST /models/{model_name}:generateContent`: 生成内容。
|
||||
* `POST /models/{model_name}:streamGenerateContent`: 流式生成内容。
|
||||
|
||||
### OpenAI API 相关
|
||||
### OpenAI API 格式
|
||||
|
||||
* `GET (/hf)/v1/models`: 列出可用的模型 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/chat/completions`: 进行聊天补全 (底层用的gemini格式, 支持流式传输)。
|
||||
* `POST (/hf)/v1/embeddings`: 创建文本嵌入 (底层用的gemini格式)。
|
||||
* `POST (/hf)/v1/images/generations`: 生成图像 (底层用的gemini格式)。
|
||||
* `GET /openai/v1/models`: 列出可用的模型 (底层用的openai格式)。
|
||||
* `POST /openai/v1/chat/completions`: 进行聊天补全 (底层用的openai格式, 支持流式传输, 可防止截断,速度也快)。
|
||||
* `POST /openai/v1/embeddings`: 创建文本嵌入 (底层用的openai格式)。
|
||||
* `POST /openai/v1/images/generations`: 生成图像 (底层用的openai格式)。
|
||||
#### 兼容 huggingface (HF) 格式
|
||||
|
||||
如果您需要使用高级功能(例如假流式输出),请使用此端点。
|
||||
|
||||
* `GET /hf/v1/models`: 列出模型。
|
||||
* `POST /hf/v1/chat/completions`: 聊天补全。
|
||||
* `POST /hf/v1/embeddings`: 创建文本嵌入。
|
||||
* `POST /hf/v1/images/generations`: 生成图像。
|
||||
|
||||
#### 标准 OpenAI 格式
|
||||
|
||||
此端点直接转发至官方的 OpenAI 兼容 API 格式端点,不包含高级功能。
|
||||
|
||||
* `GET /openai/v1/models`: 列出模型。
|
||||
* `POST /openai/v1/chat/completions`: 聊天补全 (推荐,速度更快,防截断)。
|
||||
* `POST /openai/v1/embeddings`: 创建文本嵌入。
|
||||
* `POST /openai/v1/images/generations`: 生成图像。
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>📋 查看完整配置项列表</summary>
|
||||
|
||||
| 配置项 | 说明 | 默认值 |
|
||||
| :--- | :--- | :--- |
|
||||
| **数据库配置** | | |
|
||||
| `DATABASE_TYPE` | 数据库类型: `mysql` 或 `sqlite` | `mysql` |
|
||||
| `SQLITE_DATABASE` | 当使用 `sqlite` 时必填,SQLite 数据库文件路径 | `default_db` |
|
||||
| `MYSQL_HOST` | 当使用 `mysql` 时必填,MySQL 数据库主机地址 | `localhost` |
|
||||
| `MYSQL_SOCKET` | 可选,MySQL 数据库 socket 地址 | `/var/run/mysqld/mysqld.sock` |
|
||||
| `MYSQL_PORT` | 当使用 `mysql` 时必填,MySQL 数据库端口 | `3306` |
|
||||
| `MYSQL_USER` | 当使用 `mysql` 时必填,MySQL 数据库用户名 | `your_db_user` |
|
||||
| `MYSQL_PASSWORD` | 当使用 `mysql` 时必填,MySQL 数据库密码 | `your_db_password` |
|
||||
| `MYSQL_DATABASE` | 当使用 `mysql` 时必填,MySQL 数据库名称 | `defaultdb` |
|
||||
| **API 相关配置** | | |
|
||||
| `API_KEYS` | **必填**, Gemini API 密钥列表,用于负载均衡 | `[]` |
|
||||
| `ALLOWED_TOKENS` | **必填**, 允许访问的 Token 列表 | `[]` |
|
||||
| `AUTH_TOKEN` | 超级管理员 Token,不填则使用 `ALLOWED_TOKENS` 的第一个 | `sk-123456` |
|
||||
| `TEST_MODEL` | 用于测试密钥可用性的模型 | `gemini-2.5-flash-lite` |
|
||||
| `IMAGE_MODELS` | 支持绘图功能的模型列表 | `["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]` |
|
||||
| `SEARCH_MODELS` | 支持搜索功能的模型列表 | `["gemini-2.5-flash","gemini-2.5-pro"]` |
|
||||
| `FILTERED_MODELS` | 被禁用的模型列表 | `[]` |
|
||||
| `TOOLS_CODE_EXECUTION_ENABLED` | 是否启用代码执行工具 | `false` |
|
||||
| `SHOW_SEARCH_LINK` | 是否在响应中显示搜索结果链接 | `true` |
|
||||
| `SHOW_THINKING_PROCESS` | 是否显示模型思考过程 | `true` |
|
||||
| `THINKING_MODELS` | 支持思考功能的模型列表 | `[]` |
|
||||
| `THINKING_BUDGET_MAP` | 思考功能预算映射 (模型名:预算值) | `{}` |
|
||||
| `URL_NORMALIZATION_ENABLED` | 是否启用智能路由映射功能 | `false` |
|
||||
| `URL_CONTEXT_ENABLED` | 是否启用URL上下文理解功能 | `false` |
|
||||
| `URL_CONTEXT_MODELS` | 支持URL上下文理解功能的模型列表 | `[]` |
|
||||
| `BASE_URL` | Gemini API 基础 URL | `https://generativelanguage.googleapis.com/v1beta` |
|
||||
| `MAX_FAILURES` | 单个 Key 允许的最大失败次数 | `3` |
|
||||
| `MAX_RETRIES` | API 请求失败时的最大重试次数 | `3` |
|
||||
| `CHECK_INTERVAL_HOURS` | 禁用 Key 恢复检查间隔 (小时) | `1` |
|
||||
| `TIMEZONE` | 应用程序使用的时区 | `Asia/Shanghai` |
|
||||
| `TIME_OUT` | 请求超时时间 (秒) | `300` |
|
||||
| `PROXIES` | 代理服务器列表 (例如 `http://user:pass@host:port`) | `[]` |
|
||||
| **日志与安全** | | |
|
||||
| `LOG_LEVEL` | 日志级别: `DEBUG`, `INFO`, `WARNING`, `ERROR` | `INFO` |
|
||||
| `ERROR_LOG_RECORD_REQUEST_BODY` | 是否记录错误日志的请求体(可能包含敏感信息) | `false` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_ENABLED` | 是否自动删除错误日志 | `true` |
|
||||
| `AUTO_DELETE_ERROR_LOGS_DAYS` | 错误日志保留天数 | `7` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| 是否自动删除请求日志 | `false` |
|
||||
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | 请求日志保留天数 | `30` |
|
||||
| `SAFETY_SETTINGS` | 内容安全阈值 (JSON 字符串) | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, ...]` |
|
||||
| **TTS 相关** | | |
|
||||
| `TTS_MODEL` | TTS 模型名称 | `gemini-2.5-flash-preview-tts` |
|
||||
| `TTS_VOICE_NAME` | TTS 语音名称 | `Zephyr` |
|
||||
| `TTS_SPEED` | TTS 语速 | `normal` |
|
||||
| **图像生成相关** | | |
|
||||
| `PAID_KEY` | 付费版API Key,用于图片生成等高级功能 | `your-paid-api-key` |
|
||||
| `CREATE_IMAGE_MODEL` | 图片生成模型 | `imagen-3.0-generate-002` |
|
||||
| `UPLOAD_PROVIDER` | 图片上传提供商: `smms`, `picgo`, `cloudflare_imgbed`, `aliyun_oss` | `smms` |
|
||||
| `OSS_ENDPOINT` | 阿里云 OSS 公网 Endpoint | `oss-cn-shanghai.aliyuncs.com` |
|
||||
| `OSS_ENDPOINT_INNER` | 阿里云 OSS 内网 Endpoint(同 VPC 内网访问) | `oss-cn-shanghai-internal.aliyuncs.com` |
|
||||
| `OSS_ACCESS_KEY` | 阿里云 AccessKey ID | `LTAI5txxxxxxxxxxxxxxxx` |
|
||||
| `OSS_ACCESS_KEY_SECRET` | 阿里云 AccessKey Secret | `yXxxxxxxxxxxxxxxxxxxxxx` |
|
||||
| `OSS_BUCKET_NAME` | 阿里云 OSS Bucket 名称 | `your-bucket-name` |
|
||||
| `OSS_REGION` | 阿里云 OSS 区域 Region | `cn-shanghai` |
|
||||
| `SMMS_SECRET_TOKEN` | SM.MS图床的API Token | `your-smms-token` |
|
||||
| `PICGO_API_KEY` | [PicoGo](https://www.picgo.net/)图床的API Key | `your-picogo-apikey` |
|
||||
| `PICGO_API_URL` | [PicoGo](https://www.picgo.net/)图床的API服务器地址 | `https://www.picgo.net/api/1/upload` |
|
||||
| `CLOUDFLARE_IMGBED_URL` | [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) 图床上传地址 | `https://xxxxxxx.pages.dev/upload` |
|
||||
| `CLOUDFLARE_IMGBED_AUTH_CODE`| CloudFlare图床的鉴权key | `your-cloudflare-imgber-auth-code` |
|
||||
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| CloudFlare图床的上传文件夹路径 | `""` |
|
||||
| **流式优化器相关** | | |
|
||||
| `STREAM_OPTIMIZER_ENABLED` | 是否启用流式输出优化 | `false` |
|
||||
| `STREAM_MIN_DELAY` | 流式输出最小延迟 | `0.016` |
|
||||
| `STREAM_MAX_DELAY` | 流式输出最大延迟 | `0.024` |
|
||||
| `STREAM_SHORT_TEXT_THRESHOLD`| 短文本阈值 | `10` |
|
||||
| `STREAM_LONG_TEXT_THRESHOLD` | 长文本阈值 | `50` |
|
||||
| `STREAM_CHUNK_SIZE` | 流式输出块大小 | `5` |
|
||||
| **伪流式 (Fake Stream) 相关** | | |
|
||||
| `FAKE_STREAM_ENABLED` | 是否启用伪流式传输 | `false` |
|
||||
| `FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS` | 伪流式传输时发送心跳空数据的间隔秒数 | `5` |
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 🤝 贡献
|
||||
|
||||
欢迎提交 Pull Request 或 Issue。
|
||||
|
||||
## 🎉 特别鸣谢
|
||||
|
||||
特别鸣谢以下项目和平台为本项目提供图床服务:
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) 开源项目
|
||||
|
||||
## 🙏 感谢贡献者
|
||||
|
||||
感谢所有为本项目做出贡献的开发者!
|
||||
欢迎通过提交 Pull Request 或 Issue 来为项目做出贡献。
|
||||
|
||||
[](https://github.com/snailyp/gemini-balance/graphs/contributors)
|
||||
|
||||
@@ -258,9 +256,15 @@ app/
|
||||
|
||||
[](https://star-history.com/#snailyp/gemini-balance&Date)
|
||||
|
||||
## 🎉 特别鸣谢
|
||||
|
||||
* [PicGo](https://www.picgo.net/)
|
||||
* [SM.MS](https://smms.app/)
|
||||
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed)
|
||||
|
||||
## 💖 友情项目
|
||||
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine一线:AI驱动的热点事件时间轴生成工具
|
||||
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - AI 驱动的热点事件时间轴生成工具。
|
||||
|
||||
## 🎁 项目支持
|
||||
|
||||
@@ -268,4 +272,19 @@ app/
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 CC BY-NC 4.0(署名-非商业性使用)协议,禁止任何形式的商业倒卖服务,详见 LICENSE 文件。
|
||||
本项目采用 [CC BY-NC 4.0](LICENSE)(署名-非商业性使用)协议。
|
||||
|
||||
|
||||
## 赞助商
|
||||
|
||||
特别感谢 [DigitalOcean](https://m.do.co/c/b249dd7f3b4c) 为本项目提供稳定可靠的云基础设施支持。
|
||||
|
||||
<a href="https://m.do.co/c/b249dd7f3b4c">
|
||||
<img src="files/dataocean.svg" alt="DigitalOcean Logo" width="200"/>
|
||||
</a>
|
||||
|
||||
本项目的 CDN 加速和安全防护由 [Tencent EdgeOne](https://edgeone.ai/?from=github) 赞助。
|
||||
|
||||
<a href="https://edgeone.ai/?from=github">
|
||||
<img src="https://edgeone.ai/media/34fe3a45-492d-4ea4-ae5d-ea1087ca7b4b.png" alt="EdgeOne Logo" width="200"/>
|
||||
</a>
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any, Dict, List, Type, get_args, get_origin
|
||||
|
||||
from pydantic import ValidationError, ValidationInfo, field_validator
|
||||
from pydantic import Field, ValidationError, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from sqlalchemy import insert, select, update
|
||||
|
||||
@@ -51,8 +51,8 @@ class Settings(BaseSettings):
|
||||
return v
|
||||
|
||||
# API相关配置
|
||||
API_KEYS: List[str]
|
||||
ALLOWED_TOKENS: List[str]
|
||||
API_KEYS: List[str] = []
|
||||
ALLOWED_TOKENS: List[str] = []
|
||||
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
||||
AUTH_TOKEN: str = ""
|
||||
MAX_FAILURES: int = 3
|
||||
@@ -62,16 +62,30 @@ class Settings(BaseSettings):
|
||||
PROXIES: List[str] = []
|
||||
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
|
||||
VERTEX_API_KEYS: List[str] = []
|
||||
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
||||
VERTEX_EXPRESS_BASE_URL: str = (
|
||||
"https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
||||
)
|
||||
|
||||
# 智能路由配置
|
||||
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
|
||||
|
||||
# 自定义 Headers
|
||||
CUSTOM_HEADERS: Dict[str, str] = {}
|
||||
|
||||
# 模型相关配置
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
||||
SEARCH_MODELS: List[str] = ["gemini-2.5-flash", "gemini-2.5-pro"]
|
||||
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp", "gemini-2.5-flash-image-preview"]
|
||||
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
||||
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
||||
# 是否启用网址上下文
|
||||
URL_CONTEXT_ENABLED: bool = False
|
||||
URL_CONTEXT_MODELS: List[str] = [
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-live-001",
|
||||
]
|
||||
SHOW_SEARCH_LINK: bool = True
|
||||
SHOW_THINKING_PROCESS: bool = True
|
||||
THINKING_MODELS: List[str] = []
|
||||
@@ -88,8 +102,17 @@ class Settings(BaseSettings):
|
||||
UPLOAD_PROVIDER: str = "smms"
|
||||
SMMS_SECRET_TOKEN: str = ""
|
||||
PICGO_API_KEY: str = ""
|
||||
PICGO_API_URL: str = "https://www.picgo.net/api/1/upload"
|
||||
CLOUDFLARE_IMGBED_URL: str = ""
|
||||
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
||||
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
|
||||
# 阿里云OSS配置
|
||||
OSS_ENDPOINT: str = ""
|
||||
OSS_ENDPOINT_INNER: str = ""
|
||||
OSS_ACCESS_KEY: str = ""
|
||||
OSS_ACCESS_KEY_SECRET: str = ""
|
||||
OSS_BUCKET_NAME: str = ""
|
||||
OSS_REGION: str = ""
|
||||
|
||||
# 流式输出优化器配置
|
||||
STREAM_OPTIMIZER_ENABLED: bool = False
|
||||
@@ -113,12 +136,25 @@ class Settings(BaseSettings):
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = "INFO"
|
||||
ERROR_LOG_RECORD_REQUEST_BODY: bool = False
|
||||
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
|
||||
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
|
||||
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
|
||||
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
|
||||
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
|
||||
|
||||
# Files API
|
||||
FILES_CLEANUP_ENABLED: bool = True
|
||||
FILES_CLEANUP_INTERVAL_HOURS: int = 1
|
||||
FILES_USER_ISOLATION_ENABLED: bool = True
|
||||
|
||||
# Admin Session Configuration
|
||||
ADMIN_SESSION_EXPIRE: int = Field(
|
||||
default=3600,
|
||||
ge=300,
|
||||
le=86400,
|
||||
description="Admin session expiration time in seconds (5 minutes to 24 hours)",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -137,86 +173,112 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
||||
|
||||
logger = get_config_logger()
|
||||
try:
|
||||
# 处理 List[str]
|
||||
if target_type == List[str]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
origin_type = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
|
||||
# 处理 List 类型
|
||||
if origin_type is list:
|
||||
# 处理 List[str]
|
||||
if args and args[0] == str:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
except json.JSONDecodeError:
|
||||
return [
|
||||
item.strip() for item in db_value.split(",") if item.strip()
|
||||
]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
logger.warning(
|
||||
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
||||
)
|
||||
return [item.strip() for item in db_value.split(",") if item.strip()]
|
||||
# 处理 Dict[str, float]
|
||||
elif target_type == Dict[str, float]:
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(
|
||||
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
||||
)
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif args and get_origin(args[0]) is dict:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
valid = all(
|
||||
isinstance(item, dict)
|
||||
and all(isinstance(k, str) for k in item.keys())
|
||||
and all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 List[Dict[str, str]]
|
||||
elif target_type == List[Dict[str, str]]:
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, list):
|
||||
# 验证列表中的每个元素是否为字典,并且键和值都是字符串
|
||||
valid = all(
|
||||
isinstance(item, dict)
|
||||
and all(isinstance(k, str) for k in item.keys())
|
||||
and all(isinstance(v, str) for v in item.values())
|
||||
for item in parsed
|
||||
)
|
||||
if valid:
|
||||
return parsed
|
||||
return []
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
||||
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
||||
)
|
||||
return []
|
||||
# 处理 Dict 类型
|
||||
elif origin_type is dict:
|
||||
# 处理 Dict[str, str]
|
||||
if args and args == (str, str):
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): str(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 Dict[str, float]
|
||||
elif args and args == (str, float):
|
||||
parsed_dict = {}
|
||||
try:
|
||||
parsed = json.loads(db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
||||
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
||||
logger.warning(
|
||||
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
||||
)
|
||||
try:
|
||||
corrected_db_value = db_value.replace("'", '"')
|
||||
parsed = json.loads(corrected_db_value)
|
||||
if isinstance(parsed, dict):
|
||||
parsed_dict = {
|
||||
str(k): float(v) for k, v in parsed.items()
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
||||
)
|
||||
return parsed_dict
|
||||
# 处理 bool
|
||||
elif target_type == bool:
|
||||
return db_value.lower() in ("true", "1", "yes", "on")
|
||||
@@ -305,18 +367,12 @@ async def sync_initial_settings():
|
||||
if parsed_db_value != memory_value:
|
||||
# 检查类型是否匹配,以防解析函数返回了不兼容的类型
|
||||
type_match = False
|
||||
if target_type == List[str] and isinstance(
|
||||
parsed_db_value, list
|
||||
):
|
||||
type_match = True
|
||||
elif target_type == Dict[str, float] and isinstance(
|
||||
parsed_db_value, dict
|
||||
):
|
||||
type_match = True
|
||||
elif target_type not in (
|
||||
List[str],
|
||||
Dict[str, float],
|
||||
) and isinstance(parsed_db_value, target_type):
|
||||
origin_type = get_origin(target_type)
|
||||
if origin_type: # It's a generic type
|
||||
if isinstance(parsed_db_value, origin_type):
|
||||
type_match = True
|
||||
# It's a non-generic type, or a specific generic we want to handle
|
||||
elif isinstance(parsed_db_value, target_type):
|
||||
type_match = True
|
||||
|
||||
if type_match:
|
||||
@@ -370,9 +426,7 @@ async def sync_initial_settings():
|
||||
|
||||
# 序列化值为字符串或 JSON 字符串
|
||||
if isinstance(value, (list, dict)):
|
||||
db_value = json.dumps(
|
||||
value, ensure_ascii=False
|
||||
)
|
||||
db_value = json.dumps(value, ensure_ascii=False)
|
||||
elif isinstance(value, bool):
|
||||
db_value = str(value).lower()
|
||||
elif value is None:
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.config.config import settings, sync_initial_settings
|
||||
from app.database.connection import connect_to_db, disconnect_from_db
|
||||
from app.database.initialization import initialize_database
|
||||
from app.exception.exceptions import setup_exception_handlers
|
||||
from app.log.logger import get_application_logger
|
||||
from app.log.logger import get_application_logger, setup_access_logging
|
||||
from app.middleware.middleware import setup_middlewares
|
||||
from app.router.routes import setup_routers
|
||||
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
||||
@@ -150,4 +150,7 @@ def create_app() -> FastAPI:
|
||||
# 配置路由
|
||||
setup_routers(app)
|
||||
|
||||
# 配置访问日志API密钥隐藏
|
||||
setup_access_logging()
|
||||
|
||||
return app
|
||||
|
||||
@@ -9,25 +9,25 @@ MAX_RETRIES = 3 # 最大重试次数
|
||||
|
||||
# 模型相关常量
|
||||
SUPPORTED_ROLES = ["user", "model", "system"]
|
||||
DEFAULT_MODEL = "gemini-1.5-flash"
|
||||
DEFAULT_MODEL = "gemini-2.5-flash-lite"
|
||||
DEFAULT_TEMPERATURE = 0.7
|
||||
DEFAULT_MAX_TOKENS = 8192
|
||||
DEFAULT_TOP_P = 0.9
|
||||
DEFAULT_TOP_K = 40
|
||||
DEFAULT_FILTER_MODELS = [
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001"
|
||||
]
|
||||
"gemini-1.0-pro-vision-latest",
|
||||
"gemini-pro-vision",
|
||||
"chat-bison-001",
|
||||
"text-bison-001",
|
||||
"embedding-gecko-001",
|
||||
]
|
||||
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
||||
|
||||
# 图像生成相关常量
|
||||
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
||||
|
||||
# 上传提供商
|
||||
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
||||
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed", "aliyun_oss"]
|
||||
DEFAULT_UPLOAD_PROVIDER = "smms"
|
||||
|
||||
# 流式输出相关常量
|
||||
@@ -38,14 +38,14 @@ DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
||||
DEFAULT_STREAM_CHUNK_SIZE = 5
|
||||
|
||||
# 正则表达式模式
|
||||
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
||||
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
||||
IMAGE_URL_PATTERN = r"!\[(.*?)\]\((.*?)\)"
|
||||
DATA_URL_PATTERN = r"data:([^;]+);base64,(.+)"
|
||||
|
||||
# Audio/Video Settings
|
||||
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
||||
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
||||
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
||||
|
||||
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
||||
AUDIO_FORMAT_TO_MIMETYPE = {
|
||||
@@ -63,17 +63,50 @@ VIDEO_FORMAT_TO_MIMETYPE = {
|
||||
}
|
||||
|
||||
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
||||
]
|
||||
|
||||
DEFAULT_SAFETY_SETTINGS = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
TTS_VOICE_NAMES = [
|
||||
"Zephyr",
|
||||
"Puck",
|
||||
"Charon",
|
||||
"Kore",
|
||||
"Fenrir",
|
||||
"Leda",
|
||||
"Orus",
|
||||
"Aoede",
|
||||
"Callirrhoe",
|
||||
"Autonoe",
|
||||
"Enceladus",
|
||||
"Iapetus",
|
||||
"Umbriel",
|
||||
"Algieba",
|
||||
"Despina",
|
||||
"Erinome",
|
||||
"Algenib",
|
||||
"Rasalgethi",
|
||||
"Laomedeia",
|
||||
"Achernar",
|
||||
"Alnilam",
|
||||
"Schedar",
|
||||
"Gacrux",
|
||||
"Pulcherrima",
|
||||
"Achird",
|
||||
"Zubenelgenubi",
|
||||
"Vindemiatrix",
|
||||
"Sadachbia",
|
||||
"Sadaltager",
|
||||
"Sulafat",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
数据库模型模块
|
||||
"""
|
||||
import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
|
||||
import enum
|
||||
|
||||
from app.database.connection import Base
|
||||
|
||||
@@ -60,3 +61,69 @@ class RequestLog(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
|
||||
|
||||
|
||||
class FileState(enum.Enum):
|
||||
"""文件状态枚举"""
|
||||
PROCESSING = "PROCESSING"
|
||||
ACTIVE = "ACTIVE"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
"""
|
||||
文件记录表,用于存储上传到 Gemini 的文件信息
|
||||
"""
|
||||
__tablename__ = "t_file_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
# 文件基本信息
|
||||
name = Column(String(255), unique=True, nullable=False, comment="文件名称,格式: files/{file_id}")
|
||||
display_name = Column(String(255), nullable=True, comment="用户上传时的原始文件名")
|
||||
mime_type = Column(String(100), nullable=False, comment="MIME 类型")
|
||||
size_bytes = Column(BigInteger, nullable=False, comment="文件大小(字节)")
|
||||
sha256_hash = Column(String(255), nullable=True, comment="文件的 SHA256 哈希值")
|
||||
|
||||
# 状态信息
|
||||
state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="文件状态")
|
||||
|
||||
# 时间戳
|
||||
create_time = Column(DateTime, nullable=False, comment="创建时间")
|
||||
update_time = Column(DateTime, nullable=False, comment="更新时间")
|
||||
expiration_time = Column(DateTime, nullable=False, comment="过期时间")
|
||||
|
||||
# API 相关
|
||||
uri = Column(String(500), nullable=False, comment="文件访问 URI")
|
||||
api_key = Column(String(100), nullable=False, comment="上传时使用的 API Key")
|
||||
upload_url = Column(Text, nullable=True, comment="临时上传 URL(用于分块上传)")
|
||||
|
||||
# 额外信息
|
||||
user_token = Column(String(100), nullable=True, comment="上传用户的 token")
|
||||
upload_completed = Column(DateTime, nullable=True, comment="上传完成时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典格式,用于 API 响应"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"displayName": self.display_name,
|
||||
"mimeType": self.mime_type,
|
||||
"sizeBytes": str(self.size_bytes),
|
||||
"createTime": self.create_time.isoformat() + "Z",
|
||||
"updateTime": self.update_time.isoformat() + "Z",
|
||||
"expirationTime": self.expiration_time.isoformat() + "Z",
|
||||
"sha256Hash": self.sha256_hash,
|
||||
"uri": self.uri,
|
||||
"state": self.state.value if self.state else "PROCESSING"
|
||||
}
|
||||
|
||||
def is_expired(self):
|
||||
"""检查文件是否已过期"""
|
||||
# 确保比较时都是 timezone-aware
|
||||
expiration_time = self.expiration_time
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
|
||||
return datetime.datetime.now(datetime.timezone.utc) > expiration_time
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""
|
||||
数据库服务模块
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from sqlalchemy import asc, delete, desc, func, insert, select, update
|
||||
|
||||
from app.database.connection import database
|
||||
from app.database.models import Settings, ErrorLog, RequestLog
|
||||
from app.database.models import ErrorLog, FileRecord, FileState, RequestLog, Settings
|
||||
from app.log.logger import get_database_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_database_logger()
|
||||
|
||||
@@ -15,7 +20,7 @@ logger = get_database_logger()
|
||||
async def get_all_settings() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有设置
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 设置列表
|
||||
"""
|
||||
@@ -31,10 +36,10 @@ async def get_all_settings() -> List[Dict[str, Any]]:
|
||||
async def get_setting(key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定键的设置
|
||||
|
||||
|
||||
Args:
|
||||
key: 设置键名
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 设置信息,如果不存在则返回None
|
||||
"""
|
||||
@@ -47,22 +52,24 @@ async def get_setting(key: str) -> Optional[Dict[str, Any]]:
|
||||
raise
|
||||
|
||||
|
||||
async def update_setting(key: str, value: str, description: Optional[str] = None) -> bool:
|
||||
async def update_setting(
|
||||
key: str, value: str, description: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
更新设置
|
||||
|
||||
|
||||
Args:
|
||||
key: 设置键名
|
||||
value: 设置值
|
||||
description: 设置描述
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
# 检查设置是否存在
|
||||
setting = await get_setting(key)
|
||||
|
||||
|
||||
if setting:
|
||||
# 更新设置
|
||||
query = (
|
||||
@@ -71,7 +78,7 @@ async def update_setting(key: str, value: str, description: Optional[str] = None
|
||||
.values(
|
||||
value=value,
|
||||
description=description if description else setting["description"],
|
||||
updated_at=datetime.now()
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
)
|
||||
await database.execute(query)
|
||||
@@ -79,15 +86,12 @@ async def update_setting(key: str, value: str, description: Optional[str] = None
|
||||
return True
|
||||
else:
|
||||
# 插入设置
|
||||
query = (
|
||||
insert(Settings)
|
||||
.values(
|
||||
key=key,
|
||||
value=value,
|
||||
description=description,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
query = insert(Settings).values(
|
||||
key=key,
|
||||
value=value,
|
||||
description=description,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
await database.execute(query)
|
||||
logger.info(f"Inserted setting: {key}")
|
||||
@@ -103,47 +107,48 @@ async def add_error_log(
|
||||
error_type: Optional[str] = None,
|
||||
error_log: Optional[str] = None,
|
||||
error_code: Optional[int] = None,
|
||||
request_msg: Optional[Union[Dict[str, Any], str]] = None
|
||||
request_msg: Optional[Union[Dict[str, Any], str]] = None,
|
||||
request_datetime: Optional[datetime] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
添加错误日志
|
||||
|
||||
|
||||
Args:
|
||||
gemini_key: Gemini API密钥
|
||||
error_log: 错误日志
|
||||
error_code: 错误代码 (例如 HTTP 状态码)
|
||||
request_msg: 请求消息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否添加成功
|
||||
"""
|
||||
try:
|
||||
# 如果request_msg是字典,则转换为JSON字符串
|
||||
if isinstance(request_msg, dict):
|
||||
request_msg_json = request_msg
|
||||
elif isinstance(request_msg, str):
|
||||
try:
|
||||
request_msg_json = json.loads(request_msg)
|
||||
except json.JSONDecodeError:
|
||||
request_msg_json = {"message": request_msg}
|
||||
else:
|
||||
if request_msg is None:
|
||||
request_msg_json = None
|
||||
|
||||
else:
|
||||
# 如果request_msg是字典,则转换为JSON字符串
|
||||
if isinstance(request_msg, dict):
|
||||
request_msg_json = request_msg
|
||||
elif isinstance(request_msg, str):
|
||||
try:
|
||||
request_msg_json = json.loads(request_msg)
|
||||
except json.JSONDecodeError:
|
||||
request_msg_json = {"message": request_msg}
|
||||
else:
|
||||
request_msg_json = None
|
||||
|
||||
# 插入错误日志
|
||||
query = (
|
||||
insert(ErrorLog)
|
||||
.values(
|
||||
gemini_key=gemini_key,
|
||||
error_type=error_type,
|
||||
error_log=error_log,
|
||||
model_name=model_name,
|
||||
error_code=error_code,
|
||||
request_msg=request_msg_json,
|
||||
request_time=datetime.now()
|
||||
)
|
||||
query = insert(ErrorLog).values(
|
||||
gemini_key=gemini_key,
|
||||
error_type=error_type,
|
||||
error_log=error_log,
|
||||
model_name=model_name,
|
||||
error_code=error_code,
|
||||
request_msg=request_msg_json,
|
||||
request_time=(request_datetime if request_datetime else datetime.now()),
|
||||
)
|
||||
await database.execute(query)
|
||||
logger.info(f"Added error log for key: {gemini_key}")
|
||||
logger.info(f"Added error log for key: {redact_key_for_logging(gemini_key)}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add error log: {str(e)}")
|
||||
@@ -158,8 +163,8 @@ async def get_error_logs(
|
||||
error_code_search: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
sort_by: str = 'id',
|
||||
sort_order: str = 'desc'
|
||||
sort_by: str = "id",
|
||||
sort_order: str = "desc",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取错误日志,支持搜索、日期过滤和排序
|
||||
@@ -186,15 +191,15 @@ async def get_error_logs(
|
||||
ErrorLog.error_type,
|
||||
ErrorLog.error_log,
|
||||
ErrorLog.error_code,
|
||||
ErrorLog.request_time
|
||||
ErrorLog.request_time,
|
||||
)
|
||||
|
||||
|
||||
if key_search:
|
||||
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
||||
if error_search:
|
||||
query = query.where(
|
||||
(ErrorLog.error_type.ilike(f"%{error_search}%")) |
|
||||
(ErrorLog.error_log.ilike(f"%{error_search}%"))
|
||||
(ErrorLog.error_type.ilike(f"%{error_search}%"))
|
||||
| (ErrorLog.error_log.ilike(f"%{error_search}%"))
|
||||
)
|
||||
if start_date:
|
||||
query = query.where(ErrorLog.request_time >= start_date)
|
||||
@@ -205,10 +210,12 @@ async def get_error_logs(
|
||||
error_code_int = int(error_code_search)
|
||||
query = query.where(ErrorLog.error_code == error_code_int)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter.")
|
||||
logger.warning(
|
||||
f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter."
|
||||
)
|
||||
|
||||
sort_column = getattr(ErrorLog, sort_by, ErrorLog.id)
|
||||
if sort_order.lower() == 'asc':
|
||||
if sort_order.lower() == "asc":
|
||||
query = query.order_by(asc(sort_column))
|
||||
else:
|
||||
query = query.order_by(desc(sort_column))
|
||||
@@ -227,7 +234,7 @@ async def get_error_logs_count(
|
||||
error_search: Optional[str] = None,
|
||||
error_code_search: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""
|
||||
获取符合条件的错误日志总数
|
||||
@@ -249,8 +256,8 @@ async def get_error_logs_count(
|
||||
query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
|
||||
if error_search:
|
||||
query = query.where(
|
||||
(ErrorLog.error_type.ilike(f"%{error_search}%")) |
|
||||
(ErrorLog.error_log.ilike(f"%{error_search}%"))
|
||||
(ErrorLog.error_type.ilike(f"%{error_search}%"))
|
||||
| (ErrorLog.error_log.ilike(f"%{error_search}%"))
|
||||
)
|
||||
if start_date:
|
||||
query = query.where(ErrorLog.request_time >= start_date)
|
||||
@@ -261,8 +268,9 @@ async def get_error_logs_count(
|
||||
error_code_int = int(error_code_search)
|
||||
query = query.where(ErrorLog.error_code == error_code_int)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter.")
|
||||
|
||||
logger.warning(
|
||||
f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter."
|
||||
)
|
||||
|
||||
count_result = await database.fetch_one(query)
|
||||
return count_result[0] if count_result else 0
|
||||
@@ -288,12 +296,14 @@ async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
if result:
|
||||
# 将 request_msg (JSONB) 转换为字符串以便在 API 中返回
|
||||
log_dict = dict(result)
|
||||
if 'request_msg' in log_dict and log_dict['request_msg'] is not None:
|
||||
if "request_msg" in log_dict and log_dict["request_msg"] is not None:
|
||||
# 确保即使是 None 或非 JSON 数据也能处理
|
||||
try:
|
||||
log_dict['request_msg'] = json.dumps(log_dict['request_msg'], ensure_ascii=False, indent=2)
|
||||
log_dict["request_msg"] = json.dumps(
|
||||
log_dict["request_msg"], ensure_ascii=False, indent=2
|
||||
)
|
||||
except TypeError:
|
||||
log_dict['request_msg'] = str(log_dict['request_msg'])
|
||||
log_dict["request_msg"] = str(log_dict["request_msg"])
|
||||
return log_dict
|
||||
else:
|
||||
return None
|
||||
@@ -302,6 +312,78 @@ async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
|
||||
raise
|
||||
|
||||
|
||||
# 新增函数:通过 gemini_key / error_code / 时间窗口 查找最接近的错误日志
|
||||
async def find_error_log_by_info(
|
||||
gemini_key: str,
|
||||
timestamp: datetime,
|
||||
status_code: Optional[int] = None,
|
||||
window_seconds: int = 1,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
在给定时间窗口内,根据 gemini_key(精确匹配)及可选的 status_code 查找最接近 timestamp 的错误日志。
|
||||
|
||||
假设错误日志的 error_code 存储的是 HTTP 状态码或等价错误码。
|
||||
|
||||
Args:
|
||||
gemini_key: 完整的 Gemini key 字符串。
|
||||
timestamp: 目标时间(UTC 或本地,与存储一致)。
|
||||
status_code: 可选的错误码,若提供则优先匹配该错误码。
|
||||
window_seconds: 允许的时间偏差窗口,单位秒,默认为 1 秒。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 最匹配的一条错误日志的完整详情(字段与 get_error_log_details 一致),若未找到则返回 None。
|
||||
"""
|
||||
try:
|
||||
start_time = timestamp - timedelta(seconds=window_seconds)
|
||||
end_time = timestamp + timedelta(seconds=window_seconds)
|
||||
|
||||
base_query = select(ErrorLog).where(
|
||||
ErrorLog.gemini_key == gemini_key,
|
||||
ErrorLog.request_time >= start_time,
|
||||
ErrorLog.request_time <= end_time,
|
||||
)
|
||||
|
||||
# 若提供了状态码,先尝试按状态码过滤
|
||||
if status_code is not None:
|
||||
query = base_query.where(ErrorLog.error_code == status_code).order_by(
|
||||
ErrorLog.request_time.desc()
|
||||
)
|
||||
candidates = await database.fetch_all(query)
|
||||
if not candidates:
|
||||
# 回退:不按状态码,仅按时间窗口
|
||||
query2 = base_query.order_by(ErrorLog.request_time.desc())
|
||||
candidates = await database.fetch_all(query2)
|
||||
else:
|
||||
query = base_query.order_by(ErrorLog.request_time.desc())
|
||||
candidates = await database.fetch_all(query)
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# 在 Python 中选择与 timestamp 最接近的一条
|
||||
def _to_dict(row: Any) -> Dict[str, Any]:
|
||||
d = dict(row)
|
||||
if "request_msg" in d and d["request_msg"] is not None:
|
||||
try:
|
||||
d["request_msg"] = json.dumps(
|
||||
d["request_msg"], ensure_ascii=False, indent=2
|
||||
)
|
||||
except TypeError:
|
||||
d["request_msg"] = str(d["request_msg"])
|
||||
return d
|
||||
|
||||
best = min(
|
||||
candidates,
|
||||
key=lambda r: abs((r["request_time"] - timestamp).total_seconds()),
|
||||
)
|
||||
return _to_dict(best)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to find error log by info (key=***{gemini_key[-4:] if gemini_key else ''}, code={status_code}, ts={timestamp}, window={window_seconds}s): {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
"""
|
||||
根据提供的 ID 列表批量删除错误日志 (异步)。
|
||||
@@ -326,12 +408,15 @@ async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
# 注意:databases 的 execute 不返回 rowcount,所以我们不能直接返回删除的数量
|
||||
# 返回 log_ids 的长度作为尝试删除的数量,或者返回 0/1 表示操作尝试
|
||||
logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}")
|
||||
return len(log_ids) # 返回尝试删除的数量
|
||||
return len(log_ids) # 返回尝试删除的数量
|
||||
except Exception as e:
|
||||
# 数据库连接或执行错误
|
||||
logger.error(f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def delete_error_log_by_id(log_id: int) -> bool:
|
||||
"""
|
||||
根据 ID 删除单个错误日志 (异步)。
|
||||
@@ -348,7 +433,9 @@ async def delete_error_log_by_id(log_id: int) -> bool:
|
||||
exists = await database.fetch_one(check_query)
|
||||
|
||||
if not exists:
|
||||
logger.warning(f"Attempted to delete non-existent error log with ID: {log_id}")
|
||||
logger.warning(
|
||||
f"Attempted to delete non-existent error log with ID: {log_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# 执行删除
|
||||
@@ -359,35 +446,57 @@ async def delete_error_log_by_id(log_id: int) -> bool:
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
|
||||
async def delete_all_error_logs() -> int:
|
||||
"""
|
||||
删除所有错误日志条目。
|
||||
|
||||
分批删除所有错误日志,以避免大数据量下的超时和性能问题。
|
||||
|
||||
Returns:
|
||||
int: 被删除的错误日志数量。
|
||||
int: 被删除的错误日志总数。
|
||||
"""
|
||||
total_deleted_count = 0
|
||||
# SQLite 对 SQL 参数数量有上限(常见为 999),IN 子句中过多参数会报错
|
||||
# 统一使用 500,兼容 SQLite/MySQL,必要时可在配置中暴露该值
|
||||
batch_size = 200
|
||||
|
||||
try:
|
||||
# 1. 获取删除前的总数
|
||||
count_query = select(func.count()).select_from(ErrorLog)
|
||||
total_to_delete = await database.fetch_val(count_query)
|
||||
|
||||
if total_to_delete == 0:
|
||||
logger.info("No error logs found to delete.")
|
||||
return 0
|
||||
|
||||
# 2. 执行删除操作
|
||||
delete_query = delete(ErrorLog)
|
||||
await database.execute(delete_query)
|
||||
|
||||
logger.info(f"Successfully deleted all {total_to_delete} error logs.")
|
||||
return total_to_delete
|
||||
while True:
|
||||
# 1) 读取一批待删除的ID,仅选择ID列以提升效率
|
||||
id_query = select(ErrorLog.id).order_by(ErrorLog.id).limit(batch_size)
|
||||
rows = await database.fetch_all(id_query)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
ids = [row["id"] for row in rows]
|
||||
|
||||
# 2) 按ID批量删除
|
||||
delete_query = delete(ErrorLog).where(ErrorLog.id.in_(ids))
|
||||
await database.execute(delete_query)
|
||||
|
||||
deleted_in_batch = len(ids)
|
||||
total_deleted_count += deleted_in_batch
|
||||
|
||||
logger.debug(f"Deleted a batch of {deleted_in_batch} error logs.")
|
||||
|
||||
# 若不足一个批次,说明已删除完成
|
||||
if deleted_in_batch < batch_size:
|
||||
break
|
||||
|
||||
# 3) 将控制权交还事件循环,缓解长时间占用
|
||||
await asyncio.sleep(0)
|
||||
|
||||
logger.info(
|
||||
f"Successfully deleted all error logs in batches. Total deleted: {total_deleted_count}"
|
||||
)
|
||||
return total_deleted_count
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to delete all error logs in batches: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
|
||||
# 新增函数:添加请求日志
|
||||
async def add_request_log(
|
||||
model_name: Optional[str],
|
||||
@@ -395,7 +504,7 @@ async def add_request_log(
|
||||
is_success: bool,
|
||||
status_code: Optional[int] = None,
|
||||
latency_ms: Optional[int] = None,
|
||||
request_time: Optional[datetime] = None
|
||||
request_time: Optional[datetime] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
添加 API 请求日志
|
||||
@@ -420,10 +529,277 @@ async def add_request_log(
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
await database.execute(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add request log: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 文件记录相关函数 ====================
|
||||
|
||||
|
||||
async def create_file_record(
|
||||
name: str,
|
||||
mime_type: str,
|
||||
size_bytes: int,
|
||||
api_key: str,
|
||||
uri: str,
|
||||
create_time: datetime,
|
||||
update_time: datetime,
|
||||
expiration_time: datetime,
|
||||
state: FileState = FileState.PROCESSING,
|
||||
display_name: Optional[str] = None,
|
||||
sha256_hash: Optional[str] = None,
|
||||
upload_url: Optional[str] = None,
|
||||
user_token: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
mime_type: MIME 类型
|
||||
size_bytes: 文件大小(字节)
|
||||
api_key: 上传时使用的 API Key
|
||||
uri: 文件访问 URI
|
||||
create_time: 创建时间
|
||||
update_time: 更新时间
|
||||
expiration_time: 过期时间
|
||||
display_name: 显示名称
|
||||
sha256_hash: SHA256 哈希值
|
||||
upload_url: 临时上传 URL
|
||||
user_token: 上传用户的 token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 创建的文件记录
|
||||
"""
|
||||
try:
|
||||
query = insert(FileRecord).values(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
mime_type=mime_type,
|
||||
size_bytes=size_bytes,
|
||||
sha256_hash=sha256_hash,
|
||||
state=state,
|
||||
create_time=create_time,
|
||||
update_time=update_time,
|
||||
expiration_time=expiration_time,
|
||||
uri=uri,
|
||||
api_key=api_key,
|
||||
upload_url=upload_url,
|
||||
user_token=user_token,
|
||||
)
|
||||
await database.execute(query)
|
||||
|
||||
# 返回创建的记录
|
||||
return await get_file_record_by_name(name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据文件名获取文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称(格式: files/{file_id})
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 文件记录,如果不存在则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord).where(FileRecord.name == name)
|
||||
result = await database.fetch_one(query)
|
||||
return dict(result) if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file record by name {name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def update_file_record_state(
|
||||
file_name: str,
|
||||
state: FileState,
|
||||
update_time: Optional[datetime] = None,
|
||||
upload_completed: Optional[datetime] = None,
|
||||
sha256_hash: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
更新文件记录状态
|
||||
|
||||
Args:
|
||||
file_name: 文件名
|
||||
state: 新状态
|
||||
update_time: 更新时间
|
||||
upload_completed: 上传完成时间
|
||||
sha256_hash: SHA256 哈希值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
try:
|
||||
values = {"state": state}
|
||||
if update_time:
|
||||
values["update_time"] = update_time
|
||||
if upload_completed:
|
||||
values["upload_completed"] = upload_completed
|
||||
if sha256_hash:
|
||||
values["sha256_hash"] = sha256_hash
|
||||
|
||||
query = update(FileRecord).where(FileRecord.name == file_name).values(**values)
|
||||
result = await database.execute(query)
|
||||
|
||||
if result:
|
||||
logger.info(f"Updated file record state for {file_name} to {state}")
|
||||
return True
|
||||
|
||||
logger.warning(f"File record not found for update: {file_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file record state: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def list_file_records(
|
||||
user_token: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None,
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
列出文件记录
|
||||
|
||||
Args:
|
||||
user_token: 用户 token(如果提供,只返回该用户的文件)
|
||||
api_key: API Key(如果提供,只返回使用该 key 的文件)
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记(偏移量)
|
||||
|
||||
Returns:
|
||||
tuple[List[Dict[str, Any]], Optional[str]]: (文件列表, 下一页标记)
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"list_file_records called with page_size={page_size}, page_token={page_token}"
|
||||
)
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
if user_token:
|
||||
query = query.where(FileRecord.user_token == user_token)
|
||||
if api_key:
|
||||
query = query.where(FileRecord.api_key == api_key)
|
||||
|
||||
# 使用偏移量进行分页
|
||||
offset = 0
|
||||
if page_token:
|
||||
try:
|
||||
offset = int(page_token)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid page token: {page_token}")
|
||||
offset = 0
|
||||
|
||||
# 按ID升序排列,使用 OFFSET 和 LIMIT
|
||||
query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
logger.debug(f"Query returned {len(results)} records")
|
||||
if results:
|
||||
logger.debug(
|
||||
f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}"
|
||||
)
|
||||
|
||||
# 处理分页
|
||||
has_next = len(results) > page_size
|
||||
if has_next:
|
||||
results = results[:page_size]
|
||||
# 下一页的偏移量是当前偏移量加上本页返回的记录数
|
||||
next_offset = offset + page_size
|
||||
next_page_token = str(next_offset)
|
||||
logger.debug(
|
||||
f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}"
|
||||
)
|
||||
else:
|
||||
next_page_token = None
|
||||
logger.debug(f"No next page, returning {len(results)} results")
|
||||
|
||||
return [dict(row) for row in results], next_page_token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def delete_file_record(name: str) -> bool:
|
||||
"""
|
||||
删除文件记录
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
query = delete(FileRecord).where(FileRecord.name == name)
|
||||
await database.execute(query)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file record: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_expired_file_records() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
删除已过期的文件记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 删除的记录列表
|
||||
"""
|
||||
try:
|
||||
# 先获取要删除的记录
|
||||
query = select(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
||||
)
|
||||
expired_records = await database.fetch_all(query)
|
||||
|
||||
if not expired_records:
|
||||
return []
|
||||
|
||||
# 执行删除
|
||||
delete_query = delete(FileRecord).where(
|
||||
FileRecord.expiration_time <= datetime.now(timezone.utc)
|
||||
)
|
||||
await database.execute(delete_query)
|
||||
|
||||
logger.info(f"Deleted {len(expired_records)} expired file records")
|
||||
return [dict(record) for record in expired_records]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete expired file records: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_api_key(name: str) -> Optional[str]:
|
||||
"""
|
||||
获取文件对应的 API Key
|
||||
|
||||
Args:
|
||||
name: 文件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: API Key,如果文件不存在或已过期则返回 None
|
||||
"""
|
||||
try:
|
||||
query = select(FileRecord.api_key).where(
|
||||
(FileRecord.name == name)
|
||||
& (FileRecord.expiration_time > datetime.now(timezone.utc))
|
||||
)
|
||||
result = await database.fetch_one(query)
|
||||
return result["api_key"] if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file API key: {str(e)}")
|
||||
raise
|
||||
|
||||
69
app/domain/file_models.py
Normal file
69
app/domain/file_models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Files API 相关的领域模型
|
||||
"""
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
"""文件上传配置"""
|
||||
mime_type: Optional[str] = Field(None, description="MIME 类型")
|
||||
display_name: Optional[str] = Field(None, description="显示名称,最多40个字符")
|
||||
|
||||
|
||||
class CreateFileRequest(BaseModel):
|
||||
"""创建文件请求(用于初始化上传)"""
|
||||
file: Optional[Dict[str, Any]] = Field(None, description="文件元数据")
|
||||
|
||||
|
||||
class FileMetadata(BaseModel):
|
||||
"""文件元数据响应"""
|
||||
name: str = Field(..., description="文件名称,格式: files/{file_id}")
|
||||
displayName: Optional[str] = Field(None, description="显示名称")
|
||||
mimeType: str = Field(..., description="MIME 类型")
|
||||
sizeBytes: str = Field(..., description="文件大小(字节)")
|
||||
createTime: str = Field(..., description="创建时间 (RFC3339)")
|
||||
updateTime: str = Field(..., description="更新时间 (RFC3339)")
|
||||
expirationTime: str = Field(..., description="过期时间 (RFC3339)")
|
||||
sha256Hash: Optional[str] = Field(None, description="SHA256 哈希值")
|
||||
uri: str = Field(..., description="文件访问 URI")
|
||||
state: str = Field(..., description="文件状态")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() + "Z"
|
||||
}
|
||||
|
||||
|
||||
class ListFilesRequest(BaseModel):
|
||||
"""列出文件请求参数"""
|
||||
pageSize: Optional[int] = Field(10, ge=1, le=100, description="每页大小")
|
||||
pageToken: Optional[str] = Field(None, description="分页标记")
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
"""列出文件响应"""
|
||||
files: List[FileMetadata] = Field(default_factory=list, description="文件列表")
|
||||
nextPageToken: Optional[str] = Field(None, description="下一页标记")
|
||||
|
||||
|
||||
class UploadInitResponse(BaseModel):
|
||||
"""上传初始化响应(内部使用)"""
|
||||
file_metadata: FileMetadata
|
||||
upload_url: str
|
||||
|
||||
|
||||
class FileKeyMapping(BaseModel):
|
||||
"""文件与 API Key 的映射关系(内部使用)"""
|
||||
file_name: str
|
||||
api_key: str
|
||||
user_token: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
"""删除文件响应"""
|
||||
success: bool = Field(..., description="是否删除成功")
|
||||
message: Optional[str] = Field(None, description="消息")
|
||||
@@ -41,6 +41,9 @@ class GenerationConfig(BaseModel):
|
||||
responseLogprobs: Optional[bool] = None
|
||||
logprobs: Optional[int] = None
|
||||
thinkingConfig: Optional[Dict[str, Any]] = None
|
||||
# TTS相关字段
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SystemInstruction(BaseModel):
|
||||
@@ -77,3 +80,36 @@ class ResetSelectedKeysRequest(BaseModel):
|
||||
|
||||
class VerifySelectedKeysRequest(BaseModel):
|
||||
keys: List[str]
|
||||
|
||||
|
||||
class GeminiEmbedContent(BaseModel):
|
||||
"""嵌入内容模型"""
|
||||
|
||||
parts: List[Dict[str, str]]
|
||||
|
||||
|
||||
class GeminiEmbedRequest(BaseModel):
|
||||
"""单一嵌入请求模型"""
|
||||
|
||||
content: GeminiEmbedContent
|
||||
taskType: Optional[
|
||||
Literal[
|
||||
"TASK_TYPE_UNSPECIFIED",
|
||||
"RETRIEVAL_QUERY",
|
||||
"RETRIEVAL_DOCUMENT",
|
||||
"SEMANTIC_SIMILARITY",
|
||||
"CLASSIFICATION",
|
||||
"CLUSTERING",
|
||||
"QUESTION_ANSWERING",
|
||||
"FACT_VERIFICATION",
|
||||
"CODE_RETRIEVAL_QUERY",
|
||||
]
|
||||
] = None
|
||||
title: Optional[str] = None
|
||||
outputDimensionality: Optional[int] = None
|
||||
|
||||
|
||||
class GeminiBatchEmbedRequest(BaseModel):
|
||||
"""批量嵌入请求模型"""
|
||||
|
||||
requests: List[GeminiEmbedRequest]
|
||||
|
||||
@@ -12,6 +12,7 @@ class ChatRequest(BaseModel):
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = DEFAULT_TOP_P
|
||||
top_k: Optional[int] = DEFAULT_TOP_K
|
||||
n: Optional[int] = 1
|
||||
stop: Optional[Union[List[str],str]] = None
|
||||
reasoning_effort: Optional[str] = None
|
||||
tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
|
||||
|
||||
@@ -131,10 +131,5 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
||||
logger.exception(f"Unhandled Exception: {str(exc)}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": "internal_server_error",
|
||||
"message": "An unexpected error occurred",
|
||||
}
|
||||
},
|
||||
content=str(exc),
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ class MessageConverter(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
self, messages: List[Dict[str, Any]], model: str
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
pass
|
||||
|
||||
@@ -84,7 +84,7 @@ def _convert_image_to_base64(url: str) -> str:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
|
||||
|
||||
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
def _process_text_with_image(text: str, model: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理可能包含图片URL的文本,提取图片并转换为base64
|
||||
|
||||
@@ -94,17 +94,31 @@ def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 包含文本和图片的部分列表
|
||||
"""
|
||||
# 如果模型名中没有包含image,当作普通文本处理
|
||||
if "image" not in model:
|
||||
return [{"text": text}]
|
||||
parts = []
|
||||
img_url_match = re.search(IMAGE_URL_PATTERN, text)
|
||||
if img_url_match:
|
||||
# 提取URL
|
||||
img_url = img_url_match.group(2)
|
||||
# 将URL对应的图片转换为base64
|
||||
# 先判断是否是base64url如果是,直接用,不过不是,再将URL对应的图片转换为base64
|
||||
try:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
base64_url_match = re.search(DATA_URL_PATTERN, img_url)
|
||||
if base64_url_match:
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mimeType": base64_url_match.group(1),
|
||||
"data": base64_url_match.group(2),
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
base64_data = _convert_image_to_base64(img_url)
|
||||
parts.append(
|
||||
{"inline_data": {"mimeType": "image/png", "data": base64_data}}
|
||||
)
|
||||
except Exception:
|
||||
# 如果转换失败,回退到文本模式
|
||||
parts.append({"text": text})
|
||||
@@ -145,7 +159,7 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
raise
|
||||
|
||||
def convert(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
self, messages: List[Dict[str, Any]], model: str
|
||||
) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
converted_messages = []
|
||||
system_instruction_parts = []
|
||||
@@ -296,7 +310,7 @@ class OpenAIMessageConverter(MessageConverter):
|
||||
elif (
|
||||
"content" in msg and isinstance(msg["content"], str) and msg["content"]
|
||||
):
|
||||
parts.extend(_process_text_with_image(msg["content"]))
|
||||
parts.extend(_process_text_with_image(msg["content"], model))
|
||||
elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
|
||||
# Keep existing tool call processing
|
||||
for tool_call in msg["tool_calls"]:
|
||||
|
||||
@@ -8,8 +8,12 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.utils.helpers import is_image_upload_configured
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
class ResponseHandler(ABC):
|
||||
"""响应处理器基类"""
|
||||
@@ -29,7 +33,11 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
self.thinking_status = False
|
||||
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
usage_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if stream:
|
||||
return _handle_gemini_stream_response(response, model, stream)
|
||||
@@ -37,52 +45,86 @@ class GeminiResponseHandler(ResponseHandler):
|
||||
|
||||
|
||||
def _handle_openai_stream_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
finish_reason: str,
|
||||
usage_metadata: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, _ = _extract_result(
|
||||
response, model, stream=True, gemini_format=False
|
||||
)
|
||||
if not text and not tool_calls:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {"content": text, "role": "assistant"}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
choices = []
|
||||
candidates = response.get("candidates", [])
|
||||
|
||||
for candidate in candidates:
|
||||
index = candidate.get("index", 0)
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
{"candidates": [candidate]}, model, stream=True, gemini_format=False
|
||||
)
|
||||
|
||||
if not text and not tool_calls and not reasoning_content:
|
||||
delta = {}
|
||||
else:
|
||||
delta = {
|
||||
"content": text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"role": "assistant",
|
||||
}
|
||||
if tool_calls:
|
||||
delta["tool_calls"] = tool_calls
|
||||
|
||||
choice = {"index": index, "delta": delta, "finish_reason": finish_reason}
|
||||
choices.append(choice)
|
||||
|
||||
template_chunk = {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
|
||||
"choices": choices,
|
||||
}
|
||||
if usage_metadata:
|
||||
template_chunk["usage"] = {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)}
|
||||
template_chunk["usage"] = {
|
||||
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
||||
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
||||
}
|
||||
return template_chunk
|
||||
|
||||
|
||||
def _handle_openai_normal_response(
|
||||
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
|
||||
response: Dict[str, Any],
|
||||
model: str,
|
||||
finish_reason: str,
|
||||
usage_metadata: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, _ = _extract_result(
|
||||
response, model, stream=False, gemini_format=False
|
||||
)
|
||||
choices = []
|
||||
candidates = response.get("candidates", [])
|
||||
|
||||
for i, candidate in enumerate(candidates):
|
||||
text, reasoning_content, tool_calls, _ = _extract_result(
|
||||
{"candidates": [candidate]}, model, stream=False, gemini_format=False
|
||||
)
|
||||
choice = {
|
||||
"index": i,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
choices.append(choice)
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": usage_metadata.get("promptTokenCount", 0), "completion_tokens": usage_metadata.get("candidatesTokenCount",0), "total_tokens": usage_metadata.get("totalTokenCount", 0)},
|
||||
"choices": choices,
|
||||
"usage": {
|
||||
"prompt_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
||||
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -103,8 +145,12 @@ class OpenAIResponseHandler(ResponseHandler):
|
||||
usage_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if stream:
|
||||
return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
|
||||
return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)
|
||||
return _handle_openai_stream_response(
|
||||
response, model, finish_reason, usage_metadata
|
||||
)
|
||||
return _handle_openai_normal_response(
|
||||
response, model, finish_reason, usage_metadata
|
||||
)
|
||||
|
||||
def handle_image_chat_response(
|
||||
self, image_str: str, model: str, stream=False, finish_reason="stop"
|
||||
@@ -156,19 +202,24 @@ def _extract_result(
|
||||
model: str,
|
||||
stream: bool = False,
|
||||
gemini_format: bool = False,
|
||||
) -> tuple[str, List[Dict[str, Any]], Optional[bool]]:
|
||||
text, tool_calls = "", []
|
||||
thought = None
|
||||
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
|
||||
text, reasoning_content, tool_calls, thought = "", "", [], None
|
||||
|
||||
if stream:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
if not parts:
|
||||
return "", [], None
|
||||
logger.warning("No parts found in stream response")
|
||||
return "", None, [], None
|
||||
|
||||
if "text" in parts[0]:
|
||||
text = parts[0].get("text")
|
||||
if "thought" in parts[0]:
|
||||
if not gemini_format and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content = text
|
||||
text = ""
|
||||
thought = parts[0].get("thought")
|
||||
elif "executableCode" in parts[0]:
|
||||
text = _format_code_block(parts[0]["executableCode"])
|
||||
@@ -187,40 +238,51 @@ def _extract_result(
|
||||
else:
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
if "thinking" in model:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = (
|
||||
"> thinking\n\n"
|
||||
+ candidate["content"]["parts"][0]["text"]
|
||||
+ "\n\n---\n> output\n\n"
|
||||
+ candidate["content"]["parts"][1]["text"]
|
||||
)
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
if len(candidate["content"]["parts"]) == 2:
|
||||
text = candidate["content"]["parts"][1]["text"]
|
||||
else:
|
||||
text = candidate["content"]["parts"][0]["text"]
|
||||
else:
|
||||
text = ""
|
||||
if "parts" in candidate["content"]:
|
||||
for part in candidate["content"]["parts"]:
|
||||
text, reasoning_content = "", ""
|
||||
|
||||
# 使用安全的访问方式
|
||||
content = candidate.get("content", {})
|
||||
|
||||
if content and isinstance(content, dict):
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts:
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
if "thought" in part and settings.SHOW_THINKING_PROCESS:
|
||||
reasoning_content += part["text"]
|
||||
else:
|
||||
text += part["text"]
|
||||
if "thought" in part and thought is None:
|
||||
thought = part.get("thought")
|
||||
elif "inlineData" in part:
|
||||
text += _extract_image_data(part)
|
||||
else:
|
||||
logger.warning(f"No parts found in content for model: {model}")
|
||||
else:
|
||||
logger.error(f"Invalid content structure for model: {model}")
|
||||
|
||||
text = _add_search_link_text(model, candidate, text)
|
||||
tool_calls = _extract_tool_calls(
|
||||
candidate["content"]["parts"], gemini_format
|
||||
)
|
||||
|
||||
# 安全地获取 parts 用于工具调用提取
|
||||
parts = candidate.get("content", {}).get("parts", [])
|
||||
tool_calls = _extract_tool_calls(parts, gemini_format)
|
||||
else:
|
||||
logger.warning(f"No candidates found in response for model: {model}")
|
||||
text = "暂无返回"
|
||||
return text, tool_calls, thought
|
||||
|
||||
return text, reasoning_content, tool_calls, thought
|
||||
|
||||
|
||||
def _has_inline_image_part(response: Dict[str, Any]) -> bool:
|
||||
try:
|
||||
for c in response.get("candidates", []):
|
||||
for p in c.get("content", {}).get("parts", []):
|
||||
if isinstance(p, dict) and ("inlineData" in p):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def _extract_image_data(part: dict) -> str:
|
||||
@@ -231,24 +293,41 @@ def _extract_image_data(part: dict) -> str:
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "picgo":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.PICGO_API_KEY,
|
||||
api_url=settings.PICGO_API_URL
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
access_key=settings.OSS_ACCESS_KEY,
|
||||
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
|
||||
bucket_name=settings.OSS_BUCKET_NAME,
|
||||
endpoint=settings.OSS_ENDPOINT,
|
||||
region=settings.OSS_REGION,
|
||||
use_internal=False
|
||||
)
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
base64_data = part["inlineData"]["data"]
|
||||
mime_type = part["inlineData"]["mimeType"]
|
||||
# 将base64_data转成bytes数组
|
||||
# Return empty string if no uploader is configured
|
||||
if not is_image_upload_configured(settings):
|
||||
return f"\n\n\n\n"
|
||||
bytes_data = base64.b64decode(base64_data)
|
||||
upload_response = image_uploader.upload(bytes_data, filename)
|
||||
if upload_response.success:
|
||||
text = f"\n\n\n\n"
|
||||
else:
|
||||
text = ""
|
||||
text = f"\n\n\n\n"
|
||||
return text
|
||||
|
||||
|
||||
@@ -260,8 +339,8 @@ def _extract_tool_calls(
|
||||
return []
|
||||
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
|
||||
tool_calls = list()
|
||||
|
||||
for i in range(len(parts)):
|
||||
part = parts[i]
|
||||
if not part or not isinstance(part, dict):
|
||||
@@ -293,7 +372,11 @@ def _extract_tool_calls(
|
||||
def _handle_gemini_stream_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, thought = _extract_result(
|
||||
# Early return raw Gemini response if no uploader configured and contains inline images
|
||||
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
|
||||
return response
|
||||
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
if tool_calls:
|
||||
@@ -310,16 +393,22 @@ def _handle_gemini_stream_response(
|
||||
def _handle_gemini_normal_response(
|
||||
response: Dict[str, Any], model: str, stream: bool
|
||||
) -> Dict[str, Any]:
|
||||
text, tool_calls, thought = _extract_result(
|
||||
# Early return raw Gemini response if no uploader configured and contains inline images
|
||||
if not is_image_upload_configured(settings) and _has_inline_image_part(response):
|
||||
return response
|
||||
|
||||
text, reasoning_content, tool_calls, thought = _extract_result(
|
||||
response, model, stream=stream, gemini_format=True
|
||||
)
|
||||
parts = []
|
||||
if tool_calls:
|
||||
content = {"parts": tool_calls, "role": "model"}
|
||||
parts = tool_calls
|
||||
else:
|
||||
part = {"text": text}
|
||||
if thought is not None:
|
||||
part["thought"] = thought
|
||||
content = {"parts": [part], "role": "model"}
|
||||
parts.append({"text": reasoning_content, "thought": thought})
|
||||
part = {"text": text}
|
||||
parts.append(part)
|
||||
content = {"parts": parts, "role": "model"}
|
||||
response["candidates"][0]["content"] = content
|
||||
return response
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Callable, TypeVar
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_retry_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = get_retry_logger()
|
||||
@@ -37,7 +38,7 @@ class RetryHandler:
|
||||
new_key = await key_manager.handle_api_failure(old_key, retries)
|
||||
if new_key:
|
||||
kwargs[self.key_arg] = new_key
|
||||
logger.info(f"Switched to new API key: {new_key}")
|
||||
logger.info(f"Switched to new API key: {redact_key_for_logging(new_key)}")
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import platform
|
||||
import re
|
||||
import sys
|
||||
from typing import Dict, Optional
|
||||
|
||||
@@ -12,6 +13,7 @@ COLORS = {
|
||||
"CRITICAL": "\033[1;31m", # 红色加粗
|
||||
}
|
||||
|
||||
|
||||
# Windows系统启用ANSI支持
|
||||
if platform.system() == "Windows":
|
||||
import ctypes
|
||||
@@ -35,6 +37,73 @@ class ColoredFormatter(logging.Formatter):
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class AccessLogFormatter(logging.Formatter):
|
||||
"""
|
||||
Custom access log formatter that redacts API keys in URLs
|
||||
"""
|
||||
|
||||
# API key patterns to match in URLs
|
||||
API_KEY_PATTERNS = [
|
||||
r"\bAIza[0-9A-Za-z_-]{35}", # Google API keys (like Gemini)
|
||||
r"\bsk-[0-9A-Za-z_-]{20,}", # OpenAI and general sk- prefixed keys
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Compile regex patterns for better performance
|
||||
self.compiled_patterns = [
|
||||
re.compile(pattern) for pattern in self.API_KEY_PATTERNS
|
||||
]
|
||||
|
||||
def format(self, record):
|
||||
# Format the record normally first
|
||||
formatted_msg = super().format(record)
|
||||
|
||||
# Redact API keys in the formatted message
|
||||
return self._redact_api_keys_in_message(formatted_msg)
|
||||
|
||||
def _redact_api_keys_in_message(self, message: str) -> str:
|
||||
"""
|
||||
Replace API keys in log message with redacted versions
|
||||
"""
|
||||
try:
|
||||
for pattern in self.compiled_patterns:
|
||||
|
||||
def replace_key(match):
|
||||
key = match.group(0)
|
||||
return redact_key_for_logging(key)
|
||||
|
||||
message = pattern.sub(replace_key, message)
|
||||
|
||||
return message
|
||||
except Exception as e:
|
||||
# Log the error but don't expose the original message in case it contains keys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error redacting API keys in access log: {e}")
|
||||
return "[LOG_REDACTION_ERROR]"
|
||||
|
||||
|
||||
def redact_key_for_logging(key: str) -> str:
|
||||
"""
|
||||
Redacts API key for secure logging by showing only first and last 6 characters.
|
||||
|
||||
Args:
|
||||
key: API key to redact
|
||||
|
||||
Returns:
|
||||
str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases
|
||||
"""
|
||||
if not key:
|
||||
return key
|
||||
|
||||
if len(key) <= 12:
|
||||
return f"{key[:3]}...{key[-3:]}"
|
||||
else:
|
||||
return f"{key[:6]}...{key[-6:]}"
|
||||
|
||||
|
||||
# 日志格式 - 使用 fileloc 并设置固定宽度 (例如 30)
|
||||
FORMATTER = ColoredFormatter(
|
||||
"%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
|
||||
@@ -228,6 +297,53 @@ def get_request_log_logger():
|
||||
return Logger.setup_logger("request_log")
|
||||
|
||||
|
||||
def get_files_logger():
|
||||
return Logger.setup_logger("files")
|
||||
|
||||
|
||||
def get_vertex_express_logger():
|
||||
return Logger.setup_logger("vertex_express")
|
||||
|
||||
|
||||
def get_gemini_embedding_logger():
|
||||
return Logger.setup_logger("gemini_embedding")
|
||||
|
||||
|
||||
def setup_access_logging():
|
||||
"""
|
||||
Configure uvicorn access logging with API key redaction
|
||||
|
||||
This function sets up a custom access log formatter that automatically
|
||||
redacts API keys in HTTP access logs. It works by:
|
||||
|
||||
1. Intercepting uvicorn's access log messages
|
||||
2. Using regex patterns to find API keys in URLs
|
||||
3. Replacing them with redacted versions (first6...last6)
|
||||
|
||||
Supported API key formats:
|
||||
- Google/Gemini API keys: AIza[35 chars]
|
||||
- OpenAI API keys: sk-[48 chars]
|
||||
- General sk- prefixed keys: sk-[20+ chars]
|
||||
|
||||
Usage:
|
||||
- Automatically called in main.py when running with uvicorn
|
||||
- For production deployment with gunicorn, ensure this is called in startup
|
||||
"""
|
||||
# Get the uvicorn access logger
|
||||
access_logger = logging.getLogger("uvicorn.access")
|
||||
|
||||
# Remove existing handlers to avoid duplicate logs
|
||||
for handler in access_logger.handlers[:]:
|
||||
access_logger.removeHandler(handler)
|
||||
|
||||
# Create new handler with our custom formatter that includes timestamp and log level
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
access_formatter = AccessLogFormatter("%(asctime)s | %(levelname)-8s | %(message)s")
|
||||
handler.setFormatter(access_formatter)
|
||||
|
||||
# Add the handler to uvicorn access logger
|
||||
access_logger.addHandler(handler)
|
||||
access_logger.setLevel(logging.INFO)
|
||||
access_logger.propagate = False
|
||||
|
||||
return access_logger
|
||||
|
||||
@@ -34,6 +34,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
and not request.url.path.startswith("/openai")
|
||||
and not request.url.path.startswith("/api/version/check")
|
||||
and not request.url.path.startswith("/vertex-express")
|
||||
and not request.url.path.startswith("/upload")
|
||||
):
|
||||
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
|
||||
@@ -11,6 +11,8 @@ from pydantic import BaseModel, Field
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import Logger, get_config_routes_logger
|
||||
from app.service.config.config_service import ConfigService
|
||||
from app.service.proxy.proxy_check_service import get_proxy_check_service, ProxyCheckResult
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
@@ -63,10 +65,10 @@ class DeleteKeysRequest(BaseModel):
|
||||
async def delete_single_key(key_to_delete: str, request: Request):
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized attempt to delete key: {key_to_delete}")
|
||||
logger.warning(f"Unauthorized attempt to delete key: {redact_key_for_logging(key_to_delete)}")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
try:
|
||||
logger.info(f"Attempting to delete key: {key_to_delete}")
|
||||
logger.info(f"Attempting to delete key: {redact_key_for_logging(key_to_delete)}")
|
||||
result = await ConfigService.delete_key(key_to_delete)
|
||||
if not result.get("success"):
|
||||
raise HTTPException(
|
||||
@@ -79,7 +81,7 @@ async def delete_single_key(key_to_delete: str, request: Request):
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting key '{key_to_delete}': {e}", exc_info=True)
|
||||
logger.error(f"Error deleting key '{redact_key_for_logging(key_to_delete)}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting key: {str(e)}")
|
||||
|
||||
|
||||
@@ -131,3 +133,93 @@ async def get_ui_models(request: Request):
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while fetching UI models: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
class ProxyCheckRequest(BaseModel):
|
||||
"""Proxy check request"""
|
||||
proxy: str = Field(..., description="Proxy address to check")
|
||||
use_cache: bool = Field(True, description="Whether to use cached results")
|
||||
|
||||
|
||||
class ProxyBatchCheckRequest(BaseModel):
|
||||
"""Batch proxy check request"""
|
||||
proxies: List[str] = Field(..., description="List of proxy addresses to check")
|
||||
use_cache: bool = Field(True, description="Whether to use cached results")
|
||||
max_concurrent: int = Field(5, description="Maximum concurrent check count", ge=1, le=10)
|
||||
|
||||
|
||||
@router.post("/proxy/check", response_model=ProxyCheckResult)
|
||||
async def check_single_proxy(proxy_request: ProxyCheckRequest, request: Request):
|
||||
"""Check if a single proxy is available"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to proxy check")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
try:
|
||||
logger.info(f"Checking single proxy: {proxy_request.proxy}")
|
||||
proxy_service = get_proxy_check_service()
|
||||
result = await proxy_service.check_single_proxy(
|
||||
proxy_request.proxy,
|
||||
proxy_request.use_cache
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Proxy check failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Proxy check failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/proxy/check-all", response_model=List[ProxyCheckResult])
|
||||
async def check_all_proxies(batch_request: ProxyBatchCheckRequest, request: Request):
|
||||
"""Check multiple proxies availability"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to batch proxy check")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
try:
|
||||
logger.info(f"Batch checking {len(batch_request.proxies)} proxies")
|
||||
proxy_service = get_proxy_check_service()
|
||||
results = await proxy_service.check_multiple_proxies(
|
||||
batch_request.proxies,
|
||||
batch_request.use_cache,
|
||||
batch_request.max_concurrent
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Batch proxy check failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Batch proxy check failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/proxy/cache-stats")
|
||||
async def get_proxy_cache_stats(request: Request):
|
||||
"""Get proxy check cache statistics"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to proxy cache stats")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
try:
|
||||
proxy_service = get_proxy_check_service()
|
||||
stats = proxy_service.get_cache_stats()
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"Get proxy cache stats failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Get cache stats failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/proxy/clear-cache")
|
||||
async def clear_proxy_cache(request: Request):
|
||||
"""Clear proxy check cache"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to clear proxy cache")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
try:
|
||||
proxy_service = get_proxy_check_service()
|
||||
proxy_service.clear_cache()
|
||||
return {"success": True, "message": "Proxy check cache cleared"}
|
||||
except Exception as e:
|
||||
logger.error(f"Clear proxy cache failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Clear cache failed: {str(e)}")
|
||||
|
||||
@@ -120,6 +120,7 @@ class ErrorLogDetailResponse(BaseModel):
|
||||
request_msg: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
request_time: Optional[datetime] = None
|
||||
error_code: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
|
||||
@@ -151,6 +152,43 @@ async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=
|
||||
)
|
||||
|
||||
|
||||
@router.get("/errors/lookup", response_model=ErrorLogDetailResponse)
|
||||
async def lookup_error_log_by_info(
|
||||
request: Request,
|
||||
gemini_key: str = Query(..., description="完整的 Gemini key"),
|
||||
timestamp: datetime = Query(..., description="请求时间 (ISO8601)"),
|
||||
status_code: Optional[int] = Query(None, description="错误码 (可选)"),
|
||||
window_seconds: int = Query(
|
||||
100, ge=1, le=300, description="时间窗口(秒), 默认100秒"
|
||||
),
|
||||
):
|
||||
"""
|
||||
通过 key / 错误码 / 时间窗口 查找最匹配的一条错误日志详情。
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to lookup error log by info")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
detail = await error_log_service.process_find_error_log_by_info(
|
||||
gemini_key=gemini_key,
|
||||
timestamp=timestamp,
|
||||
status_code=status_code,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
if not detail:
|
||||
raise HTTPException(status_code=404, detail="No matching error log found")
|
||||
return ErrorLogDetailResponse(**detail)
|
||||
except HTTPException as http_exc:
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to lookup error log by info for key=***{gemini_key[-4:] if gemini_key else ''}: {str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_logs_bulk_api(
|
||||
request: Request, payload: Dict[str, List[int]] = Body(...)
|
||||
@@ -192,10 +230,10 @@ async def delete_all_error_logs_api(request: Request):
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to delete all error logs")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
|
||||
try:
|
||||
deleted_count = await error_log_service.process_delete_all_error_logs()
|
||||
logger.info(f"Successfully deleted all {deleted_count} error logs.")
|
||||
await error_log_service.process_delete_all_error_logs()
|
||||
logger.info("Successfully deleted all error logs.")
|
||||
# No body needed for 204 response
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except Exception as e:
|
||||
@@ -203,8 +241,8 @@ async def delete_all_error_logs_api(request: Request):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error during deletion of all logs"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
"""
|
||||
@@ -214,7 +252,7 @@ async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}")
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
|
||||
try:
|
||||
success = await error_log_service.process_delete_error_log_by_id(log_id)
|
||||
if not success:
|
||||
|
||||
296
app/router/files_routes.py
Normal file
296
app/router/files_routes.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Files API 路由
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request, Query, Depends, Header, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.domain.file_models import (
|
||||
FileMetadata,
|
||||
ListFilesResponse,
|
||||
DeleteFileResponse
|
||||
)
|
||||
from app.log.logger import get_files_logger
|
||||
from app.core.security import SecurityService
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.service.files.file_upload_handler import get_upload_handler
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
router = APIRouter()
|
||||
security_service = SecurityService()
|
||||
|
||||
|
||||
@router.post("/upload/v1beta/files")
|
||||
async def upload_file_init(
|
||||
request: Request,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
||||
x_goog_upload_protocol: Optional[str] = Header(None),
|
||||
x_goog_upload_command: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
||||
):
|
||||
"""初始化文件上传"""
|
||||
logger.debug(f"Upload file request: {request.method=}, {request.url=}, {auth_token=}, {x_goog_upload_protocol=}, {x_goog_upload_command=}, {x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}")
|
||||
|
||||
# 檢查是否是實際的上傳請求(有 upload_id)
|
||||
if request.query_params.get("upload_id") and x_goog_upload_command in ["upload", "upload, finalize"]:
|
||||
logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.")
|
||||
return await handle_upload(
|
||||
upload_path="v1beta/files",
|
||||
request=request,
|
||||
key=request.query_params.get("key"),
|
||||
auth_token=auth_token
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 获取请求体
|
||||
body = await request.body()
|
||||
|
||||
# 构建请求主机 URL
|
||||
request_host = f"{request.url.scheme}://{request.url.netloc}"
|
||||
logger.info(f"Request host: {request_host}")
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"x-goog-upload-protocol": x_goog_upload_protocol or "resumable",
|
||||
"x-goog-upload-command": x_goog_upload_command or "start",
|
||||
}
|
||||
|
||||
if x_goog_upload_header_content_length:
|
||||
headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length
|
||||
if x_goog_upload_header_content_type:
|
||||
headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type
|
||||
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
response_data, response_headers = await files_service.initialize_upload(
|
||||
headers=headers,
|
||||
body=body,
|
||||
user_token=user_token,
|
||||
request_host=request_host # 傳遞請求主機
|
||||
)
|
||||
|
||||
logger.info(f"Upload initialization response: {response_data}")
|
||||
logger.info(f"Upload initialization response headers: {response_headers}")
|
||||
|
||||
logger.info(f"Upload initialization response headers: {response_data}")
|
||||
# 返回响应
|
||||
return JSONResponse(
|
||||
content=response_data,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Upload initialization failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in upload initialization: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1beta/files")
|
||||
async def list_files(
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小", alias="pageSize"),
|
||||
page_token: Optional[str] = Query(None, description="分页标记", alias="pageToken"),
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> ListFilesResponse:
|
||||
"""列出文件"""
|
||||
logger.debug(f"List files: {page_size=}, {page_token=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token(如果启用用户隔离)
|
||||
user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
return await files_service.list_files(
|
||||
page_size=page_size,
|
||||
page_token=page_token,
|
||||
user_token=user_token
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"List files failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in list files: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1beta/files/{file_id:path}")
|
||||
async def get_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> FileMetadata:
|
||||
"""获取文件信息"""
|
||||
logger.debug(f"Get file request: {file_id=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
return await files_service.get_file(f"files/{file_id}", user_token)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Get file failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get file: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/v1beta/files/{file_id:path}")
|
||||
async def delete_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> DeleteFileResponse:
|
||||
"""删除文件"""
|
||||
logger.info(f"Delete file: {file_id=}, {auth_token=}")
|
||||
try:
|
||||
# 使用认证 token 作为 user_token
|
||||
user_token = auth_token
|
||||
# 调用服务
|
||||
files_service = await get_files_service()
|
||||
success = await files_service.delete_file(f"files/{file_id}", user_token)
|
||||
|
||||
return DeleteFileResponse(
|
||||
success=success,
|
||||
message="File deleted successfully" if success else "Failed to delete file"
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Delete file failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in delete file: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# 处理上传请求的通配符路由
|
||||
@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"])
|
||||
async def handle_upload(
|
||||
upload_path: str,
|
||||
request: Request,
|
||||
key: Optional[str] = Query(None), # 從查詢參數獲取 key
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
):
|
||||
"""处理文件上传请求"""
|
||||
try:
|
||||
logger.info(f"Handling upload request: {request.method} {upload_path}, key={redact_key_for_logging(key)}")
|
||||
|
||||
# 從查詢參數獲取 upload_id
|
||||
upload_id = request.query_params.get("upload_id")
|
||||
if not upload_id:
|
||||
raise HTTPException(status_code=400, detail="Missing upload_id")
|
||||
|
||||
# 從 session 獲取真實的 API key
|
||||
files_service = await get_files_service()
|
||||
session_info = await files_service.get_upload_session(upload_id)
|
||||
if not session_info:
|
||||
logger.error(f"No session found for upload_id: {upload_id}")
|
||||
raise HTTPException(status_code=404, detail="Upload session not found")
|
||||
|
||||
real_api_key = session_info["api_key"]
|
||||
original_upload_url = session_info["upload_url"]
|
||||
|
||||
# 使用真實的 API key 構建完整的 Google 上傳 URL
|
||||
# 保留原始 URL 的所有參數,但使用真實的 API key
|
||||
upload_url = original_upload_url
|
||||
logger.info(f"Using real API key for upload: {redact_key_for_logging(real_api_key)}")
|
||||
|
||||
# 代理上传请求
|
||||
upload_handler = get_upload_handler()
|
||||
return await upload_handler.proxy_upload_request(
|
||||
request=request,
|
||||
upload_url=upload_url,
|
||||
files_service=files_service
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
logger.error(f"Upload handling failed: {e.detail}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": e.detail}},
|
||||
status_code=e.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in upload handling: {str(e)}")
|
||||
return JSONResponse(
|
||||
content={"error": {"message": "Internal server error"}},
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# 为兼容性添加 /gemini 前缀的路由
|
||||
@router.post("/gemini/upload/v1beta/files")
|
||||
async def gemini_upload_file_init(
|
||||
request: Request,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
|
||||
x_goog_upload_protocol: Optional[str] = Header(None),
|
||||
x_goog_upload_command: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_length: Optional[str] = Header(None),
|
||||
x_goog_upload_header_content_type: Optional[str] = Header(None),
|
||||
):
|
||||
"""初始化文件上传(Gemini 前缀)"""
|
||||
return await upload_file_init(
|
||||
request,
|
||||
auth_token,
|
||||
x_goog_upload_protocol,
|
||||
x_goog_upload_command,
|
||||
x_goog_upload_header_content_length,
|
||||
x_goog_upload_header_content_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files")
|
||||
async def gemini_list_files(
|
||||
page_size: int = Query(10, ge=1, le=100, alias="pageSize"),
|
||||
page_token: Optional[str] = Query(None, alias="pageToken"),
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> ListFilesResponse:
|
||||
"""列出文件(Gemini 前缀)"""
|
||||
return await list_files(page_size, page_token, auth_token)
|
||||
|
||||
|
||||
@router.get("/gemini/v1beta/files/{file_id:path}")
|
||||
async def gemini_get_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> FileMetadata:
|
||||
"""获取文件信息(Gemini 前缀)"""
|
||||
return await get_file(file_id, auth_token)
|
||||
|
||||
|
||||
@router.delete("/gemini/v1beta/files/{file_id:path}")
|
||||
async def gemini_delete_file(
|
||||
file_id: str,
|
||||
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
|
||||
) -> DeleteFileResponse:
|
||||
"""删除文件(Gemini 前缀)"""
|
||||
return await delete_file(file_id, auth_token)
|
||||
@@ -1,17 +1,29 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
||||
from app.domain.gemini_models import (
|
||||
GeminiBatchEmbedRequest,
|
||||
GeminiContent,
|
||||
GeminiEmbedRequest,
|
||||
GeminiRequest,
|
||||
ResetSelectedKeysRequest,
|
||||
VerifySelectedKeysRequest,
|
||||
)
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.embedding.gemini_embedding_service import GeminiEmbeddingService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
from app.service.tts.native.tts_routes import get_tts_chat_service
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
||||
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
||||
@@ -36,11 +48,16 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
return GeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
async def get_embedding_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
"""获取Gemini嵌入服务实例"""
|
||||
return GeminiEmbeddingService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
@router_v1beta.get("/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
@@ -48,22 +65,32 @@ async def list_models(
|
||||
logger.info("Handling Gemini models list request")
|
||||
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
raise HTTPException(
|
||||
status_code=503, detail="No valid API keys available to fetch models."
|
||||
)
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to fetch base models list."
|
||||
)
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
model_mapping = {
|
||||
x.get("name", "").split("/", maxsplit=1)[-1]: x
|
||||
for x in models_json.get("models", [])
|
||||
}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
logger.warning(
|
||||
f"Base model '{base_name}' not found for derived model '{suffix}'."
|
||||
)
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
@@ -77,7 +104,7 @@ async def list_models(
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
@@ -89,7 +116,8 @@ async def list_models(
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching Gemini models list",
|
||||
) from e
|
||||
|
||||
|
||||
@@ -99,25 +127,60 @@ async def list_models(
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Content generation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini content generation request for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
|
||||
# 检测是否为原生Gemini TTS请求
|
||||
is_native_tts = False
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
is_native_tts = True
|
||||
logger.info("Detected native Gemini TTS request")
|
||||
logger.info(f"TTS responseModalities: {response_modalities}")
|
||||
logger.info(f"TTS speechConfig: {speech_config}")
|
||||
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
# 所有原生TTS请求都使用TTS增强服务
|
||||
if is_native_tts:
|
||||
try:
|
||||
logger.info("Using native TTS enhanced service")
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
response = await tts_service.generate_content(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Native TTS processing failed, falling back to standard service: {e}"
|
||||
)
|
||||
|
||||
# 使用标准服务处理所有其他请求(非TTS)
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -128,41 +191,167 @@ async def generate_content(
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Streaming request initiation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini streaming content generation for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
raw_stream = chat_service.stream_generate_content(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_stream:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:countTokens")
|
||||
@router_v1beta.post("/models/{model_name}:countTokens")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def count_tokens(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini token 计数请求。"""
|
||||
operation_name = "gemini_count_tokens"
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Token counting failed"
|
||||
):
|
||||
logger.info(f"Handling Gemini token count request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await chat_service.count_tokens(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:embedContent")
|
||||
@router_v1beta.post("/models/{model_name}:embedContent")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def embed_content(
|
||||
model_name: str,
|
||||
request: GeminiEmbedRequest,
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
|
||||
):
|
||||
"""处理 Gemini 单一嵌入请求"""
|
||||
operation_name = "gemini_embed_content"
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Embedding content generation failed"
|
||||
):
|
||||
logger.info(f"Handling Gemini embedding request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await embedding_service.embed_content(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/models/{model_name}:batchEmbedContents")
|
||||
@router_v1beta.post("/models/{model_name}:batchEmbedContents")
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def batch_embed_contents(
|
||||
model_name: str,
|
||||
request: GeminiBatchEmbedRequest,
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
embedding_service: GeminiEmbeddingService = Depends(get_embedding_service),
|
||||
):
|
||||
"""处理 Gemini 批量嵌入请求"""
|
||||
operation_name = "gemini_batch_embed_contents"
|
||||
async with handle_route_errors(
|
||||
logger,
|
||||
operation_name,
|
||||
failure_message="Batch embedding content generation failed",
|
||||
):
|
||||
logger.info(f"Handling Gemini batch embedding request for model: {model_name}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await embedding_service.batch_embed_contents(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/reset-all-fail-counts")
|
||||
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
async def reset_all_key_fail_counts(
|
||||
key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
||||
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
|
||||
logger.info(f"Received reset request with key_type: {key_type}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取分类后的密钥
|
||||
keys_by_status = await key_manager.get_keys_by_status()
|
||||
valid_keys = keys_by_status.get("valid_keys", {})
|
||||
invalid_keys = keys_by_status.get("invalid_keys", {})
|
||||
|
||||
|
||||
# 根据类型选择要重置的密钥
|
||||
keys_to_reset = []
|
||||
if key_type == "valid":
|
||||
@@ -174,35 +363,45 @@ async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManage
|
||||
else:
|
||||
# 重置所有密钥
|
||||
await key_manager.reset_failure_counts()
|
||||
return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
|
||||
|
||||
return JSONResponse(
|
||||
{"success": True, "message": "所有密钥的失败计数已重置"}
|
||||
)
|
||||
|
||||
# 批量重置指定类型的密钥
|
||||
for key in keys_to_reset:
|
||||
await key_manager.reset_key_failure_count(key)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"{key_type}密钥的失败计数已重置",
|
||||
"reset_count": len(keys_to_reset)
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"{key_type}密钥的失败计数已重置",
|
||||
"reset_count": len(keys_to_reset),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset key failure counts: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
|
||||
|
||||
|
||||
return JSONResponse(
|
||||
{"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-selected-fail-counts")
|
||||
async def reset_selected_key_fail_counts(
|
||||
request: ResetSelectedKeysRequest,
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""批量重置选定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
||||
keys_to_reset = request.keys
|
||||
key_type = request.key_type
|
||||
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
|
||||
logger.info(
|
||||
f"Received reset request for {len(keys_to_reset)} selected {key_type} keys."
|
||||
)
|
||||
|
||||
if not keys_to_reset:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": "没有提供需要重置的密钥"}, status_code=400
|
||||
)
|
||||
|
||||
reset_count = 0
|
||||
errors = []
|
||||
@@ -214,53 +413,79 @@ async def reset_selected_key_fail_counts(
|
||||
if result:
|
||||
reset_count += 1
|
||||
else:
|
||||
logger.warning(f"Key not found during selective reset: {key}")
|
||||
logger.warning(
|
||||
f"Key not found during selective reset: {redact_key_for_logging(key)}"
|
||||
)
|
||||
except Exception as key_error:
|
||||
logger.error(f"Error resetting key {key}: {str(key_error)}")
|
||||
logger.error(
|
||||
f"Error resetting key {redact_key_for_logging(key)}: {str(key_error)}"
|
||||
)
|
||||
errors.append(f"Key {key}: {str(key_error)}")
|
||||
|
||||
if errors:
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse({
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count
|
||||
}, status_code=status_code)
|
||||
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
||||
final_success = reset_count > 0
|
||||
status_code = 207 if final_success and errors else 500
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": final_success,
|
||||
"message": error_message,
|
||||
"reset_count": reset_count,
|
||||
},
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
||||
"reset_count": reset_count,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
||||
logger.error(
|
||||
f"Failed to process reset selected key failure counts request: {str(e)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": f"批量重置处理失败: {str(e)}"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-fail-count/{api_key}")
|
||||
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
|
||||
async def reset_key_fail_count(
|
||||
api_key: str, key_manager: KeyManager = Depends(get_key_manager)
|
||||
):
|
||||
"""重置指定Gemini API密钥的失败计数"""
|
||||
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
|
||||
logger.info(f"Resetting failure count for API key: {api_key}")
|
||||
|
||||
logger.info(
|
||||
f"Resetting failure count for API key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await key_manager.reset_key_failure_count(api_key)
|
||||
if result:
|
||||
return JSONResponse({"success": True, "message": "失败计数已重置"})
|
||||
return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": "未找到指定密钥"}, status_code=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset key failure count: {str(e)}")
|
||||
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": f"重置失败: {str(e)}"}, status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/verify-key/{api_key}")
|
||||
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
|
||||
async def verify_key(
|
||||
api_key: str,
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""验证Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
|
||||
logger.info("Verifying API key validity")
|
||||
|
||||
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[
|
||||
@@ -269,41 +494,47 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
|
||||
parts=[{"text": "hi"}],
|
||||
)
|
||||
],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10},
|
||||
)
|
||||
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
settings.TEST_MODEL, gemini_request, api_key
|
||||
)
|
||||
|
||||
|
||||
if response:
|
||||
return JSONResponse({"status": "valid"})
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return JSONResponse({"status": "valid"})
|
||||
except Exception as e:
|
||||
logger.error(f"Key verification failed: {str(e)}")
|
||||
|
||||
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Verification exception for key: {api_key}, incrementing failure count")
|
||||
|
||||
return JSONResponse({"status": "invalid", "error": str(e)})
|
||||
logger.warning(
|
||||
f"Verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
|
||||
)
|
||||
|
||||
return JSONResponse({"status": "invalid", "error": e.args[1]})
|
||||
|
||||
|
||||
@router.post("/verify-selected-keys")
|
||||
async def verify_selected_keys(
|
||||
request: VerifySelectedKeysRequest,
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""批量验证选定Gemini API密钥的有效性"""
|
||||
logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
|
||||
keys_to_verify = request.keys
|
||||
logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")
|
||||
logger.info(
|
||||
f"Received verification request for {len(keys_to_verify)} selected keys."
|
||||
)
|
||||
|
||||
if not keys_to_verify:
|
||||
return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)
|
||||
return JSONResponse(
|
||||
{"success": False, "message": "没有提供需要验证的密钥"}, status_code=400
|
||||
)
|
||||
|
||||
successful_keys = []
|
||||
failed_keys = {}
|
||||
@@ -314,26 +545,36 @@ async def verify_selected_keys(
|
||||
try:
|
||||
gemini_request = GeminiRequest(
|
||||
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
|
||||
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
|
||||
generation_config={
|
||||
"temperature": 0.7,
|
||||
"topP": 1.0,
|
||||
"maxOutputTokens": 10,
|
||||
},
|
||||
)
|
||||
await chat_service.generate_content(
|
||||
settings.TEST_MODEL,
|
||||
gemini_request,
|
||||
api_key
|
||||
settings.TEST_MODEL, gemini_request, api_key
|
||||
)
|
||||
successful_keys.append(api_key)
|
||||
# 如果密钥验证成功,则重置其失败计数
|
||||
await key_manager.reset_key_failure_count(api_key)
|
||||
return api_key, "valid", None
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.warning(f"Key verification failed for {api_key}: {error_message}")
|
||||
error_message = e.args[1]
|
||||
logger.warning(
|
||||
f"Key verification failed for {redact_key_for_logging(api_key)}: {error_message}"
|
||||
)
|
||||
async with key_manager.failure_count_lock:
|
||||
if api_key in key_manager.key_failure_counts:
|
||||
key_manager.key_failure_counts[api_key] += 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
|
||||
logger.warning(
|
||||
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, incrementing failure count"
|
||||
)
|
||||
else:
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
|
||||
failed_keys[api_key] = error_message
|
||||
key_manager.key_failure_counts[api_key] = 1
|
||||
logger.warning(
|
||||
f"Bulk verification exception for key: {redact_key_for_logging(api_key)}, initializing failure count to 1"
|
||||
)
|
||||
failed_keys[api_key] = {"error_message": e.args[1], "error_code": e.args[0]}
|
||||
return api_key, "invalid", error_message
|
||||
|
||||
tasks = [_verify_single_key(key) for key in keys_to_verify]
|
||||
@@ -341,34 +582,37 @@ async def verify_selected_keys(
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"An unexpected error occurred during bulk verification task: {result}")
|
||||
elif result:
|
||||
if not isinstance(result, Exception) and result:
|
||||
key, status, error = result
|
||||
elif isinstance(result, Exception):
|
||||
logger.error(f"Task execution error during bulk verification: {result}")
|
||||
logger.error(
|
||||
f"An unexpected error occurred during bulk verification task: {result}"
|
||||
)
|
||||
|
||||
valid_count = len(successful_keys)
|
||||
invalid_count = len(failed_keys)
|
||||
logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")
|
||||
logger.info(
|
||||
f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}"
|
||||
)
|
||||
|
||||
if failed_keys:
|
||||
message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": failed_keys,
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": invalid_count,
|
||||
}
|
||||
)
|
||||
else:
|
||||
message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": {},
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": 0
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"successful_keys": successful_keys,
|
||||
"failed_keys": {},
|
||||
"valid_count": valid_count,
|
||||
"invalid_count": 0,
|
||||
}
|
||||
)
|
||||
|
||||
83
app/router/key_routes.py
Normal file
83
app/router/key_routes.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.core.security import verify_auth_token
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/api/keys")
|
||||
async def get_keys_paginated(
|
||||
request: Request,
|
||||
page: int = 1,
|
||||
limit: int = 10,
|
||||
search: str = None,
|
||||
fail_count_threshold: int = None,
|
||||
status: str = "all", # 'valid', 'invalid', 'all'
|
||||
key_manager: KeyManager = Depends(get_key_manager_instance),
|
||||
):
|
||||
"""
|
||||
Get paginated, filtered, and searched keys.
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
|
||||
|
||||
all_keys_with_status = await key_manager.get_all_keys_with_fail_count()
|
||||
|
||||
# Filter by status
|
||||
if status == "valid":
|
||||
keys_to_filter = all_keys_with_status["valid_keys"]
|
||||
elif status == "invalid":
|
||||
keys_to_filter = all_keys_with_status["invalid_keys"]
|
||||
else:
|
||||
# Combine both for 'all' status, which might be useful for a unified view if ever needed
|
||||
keys_to_filter = {**all_keys_with_status["valid_keys"], **all_keys_with_status["invalid_keys"]}
|
||||
|
||||
|
||||
# Further filtering (search and fail_count_threshold)
|
||||
filtered_keys = {}
|
||||
for key, fail_count in keys_to_filter.items():
|
||||
search_match = True
|
||||
if search:
|
||||
search_match = search.lower() in key.lower()
|
||||
|
||||
fail_count_match = True
|
||||
if fail_count_threshold is not None:
|
||||
fail_count_match = fail_count >= fail_count_threshold
|
||||
|
||||
if search_match and fail_count_match:
|
||||
filtered_keys[key] = fail_count
|
||||
|
||||
# Pagination
|
||||
keys_list = list(filtered_keys.items())
|
||||
total_items = len(keys_list)
|
||||
start_index = (page - 1) * limit
|
||||
end_index = start_index + limit
|
||||
paginated_keys = dict(keys_list[start_index:end_index])
|
||||
|
||||
return {
|
||||
"keys": paginated_keys,
|
||||
"total_items": total_items,
|
||||
"total_pages": (total_items + limit - 1) // limit,
|
||||
"current_page": page,
|
||||
}
|
||||
|
||||
@router.get("/api/keys/all")
|
||||
async def get_all_keys(
|
||||
request: Request,
|
||||
key_manager: KeyManager = Depends(get_key_manager_instance),
|
||||
):
|
||||
"""
|
||||
Get all keys (both valid and invalid) for bulk operations.
|
||||
"""
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
|
||||
|
||||
all_keys_with_status = await key_manager.get_all_keys_with_fail_count()
|
||||
|
||||
return {
|
||||
"valid_keys": list(all_keys_with_status["valid_keys"].keys()),
|
||||
"invalid_keys": list(all_keys_with_status["invalid_keys"].keys()),
|
||||
"total_count": len(all_keys_with_status["valid_keys"]) + len(all_keys_with_status["invalid_keys"])
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
@@ -8,18 +8,21 @@ from app.domain.openai_models import (
|
||||
EmbeddingRequest,
|
||||
ImageGenerationRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
|
||||
|
||||
from app.service.openai_compatiable.openai_compatiable_service import (
|
||||
OpenAICompatiableService,
|
||||
)
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
security_service = SecurityService()
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
return await get_key_manager_instance()
|
||||
|
||||
@@ -37,7 +40,7 @@ async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager))
|
||||
|
||||
@router.get("/openai/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
@@ -45,8 +48,9 @@ async def list_models(
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
return await openai_service.get_models(api_key)
|
||||
|
||||
|
||||
@@ -54,7 +58,7 @@ async def list_models(
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
@@ -69,28 +73,56 @@ async def chat_completion(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
|
||||
|
||||
raw_response = None
|
||||
if is_image_chat:
|
||||
response = await openai_service.create_image_chat_completion(request, current_api_key)
|
||||
return response
|
||||
raw_response = await openai_service.create_image_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
else:
|
||||
response = await openai_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
raw_response = await openai_service.create_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
if request.stream:
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_response.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_response, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_response:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
else:
|
||||
return raw_response
|
||||
|
||||
|
||||
@router.post("/openai/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
"""处理图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
request.model = settings.CREATE_IMAGE_MODEL
|
||||
return await openai_service.generate_images(request)
|
||||
|
||||
@@ -98,7 +130,7 @@ async def generate_image(
|
||||
@router.post("/openai/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
openai_service: OpenAICompatiableService = Depends(get_openai_service),
|
||||
):
|
||||
@@ -107,7 +139,8 @@ async def embedding(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
return await openai_service.create_embeddings(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import SecurityService
|
||||
@@ -9,15 +9,16 @@ from app.domain.openai_models import (
|
||||
ImageGenerationRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_openai_logger
|
||||
from app.service.chat.openai_chat_service import OpenAIChatService
|
||||
from app.service.embedding.embedding_service import EmbeddingService
|
||||
from app.service.image.image_create_service import ImageCreateService
|
||||
from app.service.tts.tts_service import TTSService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.service.tts.tts_service import TTSService
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_openai_logger()
|
||||
@@ -52,15 +53,16 @@ async def get_tts_service():
|
||||
@router.get("/v1/models")
|
||||
@router.get("/hf/v1/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 OpenAI 模型列表 (兼容 Gemini 和 OpenAI)。"""
|
||||
operation_name = "list_models"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info("Handling models list request")
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
return await model_service.get_gemini_openai_models(api_key)
|
||||
|
||||
|
||||
@@ -69,7 +71,7 @@ async def list_models(
|
||||
@RetryHandler(key_arg="api_key")
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: OpenAIChatService = Depends(get_openai_chat_service),
|
||||
@@ -84,35 +86,62 @@ async def chat_completion(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling chat completion request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {current_api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(current_api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(request.model):
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {request.model} is not supported"
|
||||
)
|
||||
|
||||
raw_response = None
|
||||
if is_image_chat:
|
||||
response = await chat_service.create_image_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
raw_response = await chat_service.create_image_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
else:
|
||||
response = await chat_service.create_chat_completion(request, current_api_key)
|
||||
if request.stream:
|
||||
return StreamingResponse(response, media_type="text/event-stream")
|
||||
return response
|
||||
raw_response = await chat_service.create_chat_completion(
|
||||
request, current_api_key
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_response.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_response, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_response:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
else:
|
||||
return raw_response
|
||||
|
||||
|
||||
@router.post("/v1/images/generations")
|
||||
@router.post("/hf/v1/images/generations")
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
):
|
||||
"""处理 OpenAI 图像生成请求。"""
|
||||
operation_name = "generate_image"
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling image generation request for prompt: {request.prompt}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
response = image_create_service.generate_images(request)
|
||||
return response
|
||||
|
||||
@@ -121,7 +150,7 @@ async def generate_image(
|
||||
@router.post("/hf/v1/embeddings")
|
||||
async def embedding(
|
||||
request: EmbeddingRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""处理 OpenAI 文本嵌入请求。"""
|
||||
@@ -129,7 +158,8 @@ async def embedding(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling embedding request for model: {request.model}")
|
||||
api_key = await key_manager.get_next_working_key()
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
response = await embedding_service.create_embedding(
|
||||
input_text=request.input, model=request.model, api_key=api_key
|
||||
)
|
||||
@@ -161,7 +191,7 @@ async def get_keys_list(
|
||||
@router.post("/hf/v1/audio/speech")
|
||||
async def text_to_speech(
|
||||
request: TTSRequest,
|
||||
_=Depends(security_service.verify_authorization),
|
||||
allowed_token=Depends(security_service.verify_authorization),
|
||||
api_key: str = Depends(get_next_working_key_wrapper),
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
):
|
||||
@@ -170,6 +200,7 @@ async def text_to_speech(
|
||||
async with handle_route_errors(logger, operation_name):
|
||||
logger.info(f"Handling TTS request for model: {request.model}")
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
audio_data = await tts_service.create_tts(request, api_key)
|
||||
return Response(content=audio_data, media_type="audio/wav")
|
||||
|
||||
@@ -6,15 +6,31 @@ from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.security import verify_auth_token
|
||||
from app.log.logger import get_routes_logger
|
||||
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes
|
||||
from app.router import (
|
||||
config_routes,
|
||||
error_log_routes,
|
||||
files_routes,
|
||||
gemini_routes,
|
||||
key_routes,
|
||||
openai_compatiable_routes,
|
||||
openai_routes,
|
||||
scheduler_routes,
|
||||
stats_routes,
|
||||
version_routes,
|
||||
vertex_express_routes,
|
||||
)
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.stats.stats_service import StatsService
|
||||
from app.utils.static_version import get_static_url
|
||||
|
||||
logger = get_routes_logger()
|
||||
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
# 设置模板全局变量
|
||||
templates.env.globals["static_url"] = get_static_url
|
||||
|
||||
|
||||
def setup_routers(app: FastAPI) -> None:
|
||||
@@ -34,6 +50,8 @@ def setup_routers(app: FastAPI) -> None:
|
||||
app.include_router(version_routes.router)
|
||||
app.include_router(openai_compatiable_routes.router)
|
||||
app.include_router(vertex_express_routes.router)
|
||||
app.include_router(files_routes.router)
|
||||
app.include_router(key_routes.router)
|
||||
|
||||
setup_page_routes(app)
|
||||
|
||||
@@ -66,9 +84,12 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
|
||||
if verify_auth_token(auth_token):
|
||||
logger.info("Successful authentication")
|
||||
response = RedirectResponse(url="/config", status_code=302)
|
||||
response = RedirectResponse(url="/keys", status_code=302)
|
||||
response.set_cookie(
|
||||
key="auth_token", value=auth_token, httponly=True, max_age=3600
|
||||
key="auth_token",
|
||||
value=auth_token,
|
||||
httponly=True,
|
||||
max_age=settings.ADMIN_SESSION_EXPIRE,
|
||||
)
|
||||
return response
|
||||
logger.warning("Failed authentication attempt with invalid token")
|
||||
@@ -88,7 +109,9 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
|
||||
key_manager = await get_key_manager_instance()
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
total_keys = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
|
||||
total_keys = len(keys_status["valid_keys"]) + len(
|
||||
keys_status["invalid_keys"]
|
||||
)
|
||||
valid_key_count = len(keys_status["valid_keys"])
|
||||
invalid_key_count = len(keys_status["invalid_keys"])
|
||||
|
||||
@@ -101,8 +124,8 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
"keys_status.html",
|
||||
{
|
||||
"request": request,
|
||||
"valid_keys": keys_status["valid_keys"],
|
||||
"invalid_keys": keys_status["invalid_keys"],
|
||||
"valid_keys": {},
|
||||
"invalid_keys": {},
|
||||
"total_keys": total_keys,
|
||||
"valid_key_count": valid_key_count,
|
||||
"invalid_key_count": invalid_key_count,
|
||||
@@ -111,8 +134,26 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving keys status or API stats: {str(e)}")
|
||||
raise
|
||||
|
||||
# Even if there's an error, render the page with whatever data is available
|
||||
# or with empty/default values, so the frontend can still load.
|
||||
return templates.TemplateResponse(
|
||||
"keys_status.html",
|
||||
{
|
||||
"request": request,
|
||||
"valid_keys": {},
|
||||
"invalid_keys": {},
|
||||
"total_keys": 0,
|
||||
"valid_key_count": 0,
|
||||
"invalid_key_count": 0,
|
||||
"api_stats": { # Provide a default structure for api_stats
|
||||
"calls_1m": {"total": 0, "success": 0, "failure": 0},
|
||||
"calls_1h": {"total": 0, "success": 0, "failure": 0},
|
||||
"calls_24h": {"total": 0, "success": 0, "failure": 0},
|
||||
"calls_month": {"total": 0, "success": 0, "failure": 0},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@app.get("/config", response_class=HTMLResponse)
|
||||
async def config_page(request: Request):
|
||||
"""配置编辑页面"""
|
||||
@@ -121,13 +162,15 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to config page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
|
||||
logger.info("Config page accessed successfully")
|
||||
return templates.TemplateResponse("config_editor.html", {"request": request})
|
||||
return templates.TemplateResponse(
|
||||
"config_editor.html", {"request": request}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error accessing config page: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@app.get("/logs", response_class=HTMLResponse)
|
||||
async def logs_page(request: Request):
|
||||
"""错误日志页面"""
|
||||
@@ -136,7 +179,7 @@ def setup_page_routes(app: FastAPI) -> None:
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to logs page")
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
|
||||
logger.info("Logs page accessed successfully")
|
||||
return templates.TemplateResponse("error_logs.html", {"request": request})
|
||||
except Exception as e:
|
||||
@@ -166,6 +209,7 @@ def setup_api_stats_routes(app: FastAPI) -> None:
|
||||
Args:
|
||||
app: FastAPI应用程序实例
|
||||
"""
|
||||
|
||||
@app.get("/api/stats/details")
|
||||
async def api_stats_details(request: Request, period: str):
|
||||
"""获取指定时间段内的 API 调用详情"""
|
||||
@@ -180,8 +224,67 @@ def setup_api_stats_routes(app: FastAPI) -> None:
|
||||
details = await stats_service.get_api_call_details(period)
|
||||
return details
|
||||
except ValueError as e:
|
||||
logger.warning(f"Invalid period requested for API stats details: {period} - {str(e)}")
|
||||
logger.warning(
|
||||
f"Invalid period requested for API stats details: {period} - {str(e)}"
|
||||
)
|
||||
return {"error": str(e)}, 400
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching API stats details for period {period}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error fetching API stats details for period {period}: {str(e)}"
|
||||
)
|
||||
return {"error": "Internal server error"}, 500
|
||||
|
||||
@app.get("/api/stats/attention-keys")
|
||||
async def api_stats_attention_keys(
|
||||
request: Request, limit: int = 20, status_code: int = 429
|
||||
):
|
||||
"""返回最近24小时指定错误码次数最多的Key(仅包含内存Key列表中的)。默认错误码429。"""
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to attention-keys")
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
# 支持所有标准HTTP状态码范围
|
||||
# if not isinstance(status_code, int) or status_code < 100 or status_code > 599:
|
||||
# return {"error": f"Unsupported status_code: {status_code}"}, 400
|
||||
|
||||
key_manager = await get_key_manager_instance()
|
||||
keys_status = await key_manager.get_keys_by_status()
|
||||
in_memory_keys = set(keys_status.get("valid_keys", [])) | set(
|
||||
keys_status.get("invalid_keys", [])
|
||||
)
|
||||
stats_service = StatsService()
|
||||
data = await stats_service.get_attention_keys_last_24h(
|
||||
in_memory_keys, limit, status_code
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching attention keys: {e}")
|
||||
return {"error": "Internal server error"}, 500
|
||||
|
||||
@app.get("/api/stats/key-details")
|
||||
async def api_stats_key_details(request: Request, key: str, period: str):
|
||||
"""获取指定密钥在指定时间段内的调用详情"""
|
||||
try:
|
||||
auth_token = request.cookies.get("auth_token")
|
||||
if not auth_token or not verify_auth_token(auth_token):
|
||||
logger.warning("Unauthorized access attempt to API key stats details")
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
logger.info(
|
||||
f"Fetching key call details for key=...{key[-4:] if key else ''}, period: {period}"
|
||||
)
|
||||
stats_service = StatsService()
|
||||
details = await stats_service.get_key_call_details(key, period)
|
||||
return details
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Invalid period requested for key stats details: {period} - {str(e)}"
|
||||
)
|
||||
return {"error": str(e)}, 400
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error fetching key stats details for period {period}: {str(e)}"
|
||||
)
|
||||
return {"error": "Internal server error"}, 500
|
||||
|
||||
@@ -3,6 +3,7 @@ from starlette import status
|
||||
from app.core.security import verify_auth_token
|
||||
from app.service.stats.stats_service import StatsService
|
||||
from app.log.logger import get_stats_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_stats_logger()
|
||||
|
||||
@@ -48,7 +49,7 @@ async def get_key_usage_details(key: str):
|
||||
return {}
|
||||
return usage_details
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching key usage details for key {key[:4]}...: {e}")
|
||||
logger.error(f"Error fetching key usage details for key {redact_key_for_logging(key)}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取密钥使用详情时出错: {e}"
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_vertex_express_logger
|
||||
from app.core.constants import API_VERSION
|
||||
from app.core.security import SecurityService
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.log.logger import get_vertex_express_logger
|
||||
from app.service.chat.vertex_express_chat_service import GeminiChatService
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.model.model_service import ModelService
|
||||
from app.handler.retry_handler import RetryHandler
|
||||
from app.handler.error_handler import handle_route_errors
|
||||
from app.core.constants import API_VERSION
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
|
||||
logger = get_vertex_express_logger()
|
||||
@@ -36,8 +39,8 @@ async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager)
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
):
|
||||
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
||||
operation_name = "list_gemini_models"
|
||||
@@ -45,22 +48,32 @@ async def list_models(
|
||||
logger.info("Handling Gemini models list request")
|
||||
|
||||
try:
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
raise HTTPException(
|
||||
status_code=503, detail="No valid API keys available to fetch models."
|
||||
)
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
models_data = await model_service.get_gemini_models(api_key)
|
||||
if not models_data or "models" not in models_data:
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to fetch base models list."
|
||||
)
|
||||
|
||||
models_json = deepcopy(models_data)
|
||||
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
||||
model_mapping = {
|
||||
x.get("name", "").split("/", maxsplit=1)[-1]: x
|
||||
for x in models_json.get("models", [])
|
||||
}
|
||||
|
||||
def add_derived_model(base_name, suffix, display_suffix):
|
||||
model = model_mapping.get(base_name)
|
||||
if not model:
|
||||
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
||||
logger.warning(
|
||||
f"Base model '{base_name}' not found for derived model '{suffix}'."
|
||||
)
|
||||
return
|
||||
item = deepcopy(model)
|
||||
item["name"] = f"models/{base_name}{suffix}"
|
||||
@@ -74,7 +87,7 @@ async def list_models(
|
||||
add_derived_model(name, "-search", " For Search")
|
||||
if settings.IMAGE_MODELS:
|
||||
for name in settings.IMAGE_MODELS:
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
add_derived_model(name, "-image", " For Image")
|
||||
if settings.THINKING_MODELS:
|
||||
for name in settings.THINKING_MODELS:
|
||||
add_derived_model(name, "-non-thinking", " Non Thinking")
|
||||
@@ -86,7 +99,8 @@ async def list_models(
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Gemini models list: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Internal server error while fetching Gemini models list"
|
||||
status_code=500,
|
||||
detail="Internal server error while fetching Gemini models list",
|
||||
) from e
|
||||
|
||||
|
||||
@@ -95,25 +109,30 @@ async def list_models(
|
||||
async def generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 非流式内容生成请求。"""
|
||||
operation_name = "gemini_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
||||
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Content generation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini content generation request for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response = await chat_service.generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -123,24 +142,50 @@ async def generate_content(
|
||||
async def stream_generate_content(
|
||||
model_name: str,
|
||||
request: GeminiRequest,
|
||||
_=Depends(security_service.verify_key_or_goog_api_key),
|
||||
allowed_token=Depends(security_service.verify_key_or_goog_api_key),
|
||||
api_key: str = Depends(get_next_working_key),
|
||||
key_manager: KeyManager = Depends(get_key_manager),
|
||||
chat_service: GeminiChatService = Depends(get_chat_service)
|
||||
chat_service: GeminiChatService = Depends(get_chat_service),
|
||||
):
|
||||
"""处理 Gemini 流式内容生成请求。"""
|
||||
operation_name = "gemini_stream_generate_content"
|
||||
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
||||
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
||||
async with handle_route_errors(
|
||||
logger, operation_name, failure_message="Streaming request initiation failed"
|
||||
):
|
||||
logger.info(
|
||||
f"Handling Gemini streaming content generation for model: {model_name}"
|
||||
)
|
||||
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
||||
logger.info(f"Using API key: {api_key}")
|
||||
logger.info(f"Using allowed token: {allowed_token}")
|
||||
logger.info(f"Using API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
if not await model_service.check_model_support(model_name):
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Model {model_name} is not supported"
|
||||
)
|
||||
|
||||
response_stream = chat_service.stream_generate_content(
|
||||
model=model_name,
|
||||
request=request,
|
||||
api_key=api_key
|
||||
raw_stream = chat_service.stream_generate_content(
|
||||
model=model_name, request=request, api_key=api_key
|
||||
)
|
||||
return StreamingResponse(response_stream, media_type="text/event-stream")
|
||||
try:
|
||||
# 尝试获取第一条数据,判断是正常 SSE(data: 前缀)还是错误 JSON
|
||||
first_chunk = await raw_stream.__anext__()
|
||||
except StopAsyncIteration:
|
||||
# 如果流直接结束,退回标准 SSE 输出
|
||||
return StreamingResponse(raw_stream, media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
# 初始化流异常,直接返回 500 错误
|
||||
return JSONResponse(
|
||||
content={"error": {"code": e.args[0], "message": e.args[1]}},
|
||||
status_code=e.args[0],
|
||||
)
|
||||
|
||||
# 如果以 "data:" 开头,代表正常 SSE,将首块和后续块一起发送
|
||||
if isinstance(first_chunk, str) and first_chunk.startswith("data:"):
|
||||
|
||||
async def combined():
|
||||
yield first_chunk
|
||||
async for chunk in raw_stream:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(combined(), media_type="text/event-stream")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
from app.config.config import settings
|
||||
@@ -6,8 +5,10 @@ from app.domain.gemini_models import GeminiContent, GeminiRequest
|
||||
from app.log.logger import Logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.error_log.error_log_service import delete_old_error_logs
|
||||
from app.service.files.files_service import get_files_service
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = Logger.setup_logger("scheduler")
|
||||
|
||||
@@ -50,7 +51,7 @@ async def check_failed_keys():
|
||||
|
||||
for key in keys_to_check:
|
||||
# 隐藏部分 key 用于日志记录
|
||||
log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key
|
||||
log_key = redact_key_for_logging(key)
|
||||
logger.info(f"Verifying key: {log_key}...")
|
||||
try:
|
||||
# 构造测试请求
|
||||
@@ -96,38 +97,60 @@ async def check_failed_keys():
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_expired_files():
|
||||
"""
|
||||
定时清理过期的文件记录
|
||||
"""
|
||||
logger.info("Starting scheduled cleanup for expired files...")
|
||||
try:
|
||||
files_service = await get_files_service()
|
||||
deleted_count = await files_service.cleanup_expired_files()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
|
||||
else:
|
||||
logger.info("No expired files to clean up.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An error occurred during the scheduled file cleanup: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def setup_scheduler():
|
||||
"""设置并启动 APScheduler"""
|
||||
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE)) # 从配置读取时区
|
||||
# 添加检查失败密钥的定时任务
|
||||
scheduler.add_job(
|
||||
check_failed_keys,
|
||||
"interval",
|
||||
hours=settings.CHECK_INTERVAL_HOURS,
|
||||
id="check_failed_keys_job",
|
||||
name="Check Failed API Keys",
|
||||
)
|
||||
logger.info(
|
||||
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
||||
)
|
||||
if settings.CHECK_INTERVAL_HOURS != 0:
|
||||
scheduler.add_job(
|
||||
check_failed_keys,
|
||||
"interval",
|
||||
hours=settings.CHECK_INTERVAL_HOURS,
|
||||
id="check_failed_keys_job",
|
||||
name="Check Failed API Keys",
|
||||
)
|
||||
logger.info(
|
||||
f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
|
||||
)
|
||||
|
||||
# 新增:添加自动删除错误日志的定时任务,每天凌晨3点执行
|
||||
# 新增:添加自动删除错误日志的定时任务,每天凌晨0点执行
|
||||
scheduler.add_job(
|
||||
delete_old_error_logs,
|
||||
"cron",
|
||||
hour=3,
|
||||
hour=0,
|
||||
minute=0,
|
||||
id="delete_old_error_logs_job",
|
||||
name="Delete Old Error Logs",
|
||||
)
|
||||
logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")
|
||||
|
||||
# 新增:添加自动删除请求日志的定时任务,每天凌晨3点05分执行
|
||||
# 新增:添加自动删除请求日志的定时任务,每天凌晨0点执行
|
||||
scheduler.add_job(
|
||||
delete_old_request_logs_task,
|
||||
"cron",
|
||||
hour=3,
|
||||
minute=5,
|
||||
hour=0,
|
||||
minute=0,
|
||||
id="delete_old_request_logs_job",
|
||||
name="Delete Old Request Logs",
|
||||
)
|
||||
@@ -135,6 +158,20 @@ def setup_scheduler():
|
||||
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
|
||||
)
|
||||
|
||||
# 新增:添加文件过期清理的定时任务,每小时执行一次
|
||||
if getattr(settings, "FILES_CLEANUP_ENABLED", True):
|
||||
cleanup_interval = getattr(settings, "FILES_CLEANUP_INTERVAL_HOURS", 1)
|
||||
scheduler.add_job(
|
||||
cleanup_expired_files,
|
||||
"interval",
|
||||
hours=cleanup_interval,
|
||||
id="cleanup_expired_files_job",
|
||||
name="Cleanup Expired Files",
|
||||
)
|
||||
logger.info(
|
||||
f"File cleanup job scheduled to run every {cleanup_interval} hour(s)."
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduler started with all jobs.")
|
||||
return scheduler
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.database.services import add_error_log, add_request_log, get_file_api_key
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -28,9 +30,93 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
|
||||
"""從內容中提取文件引用"""
|
||||
file_names = []
|
||||
for content in contents:
|
||||
if "parts" in content:
|
||||
for part in content["parts"]:
|
||||
if not isinstance(part, dict) or "fileData" not in part:
|
||||
continue
|
||||
file_data = part["fileData"]
|
||||
if "fileUri" not in file_data:
|
||||
continue
|
||||
file_uri = file_data["fileUri"]
|
||||
# 從 URI 中提取文件名
|
||||
# 1. https://generativelanguage.googleapis.com/v1beta/files/{file_id}
|
||||
match = re.match(
|
||||
rf"{re.escape(settings.BASE_URL)}/(files/.*)", file_uri
|
||||
)
|
||||
if not match:
|
||||
logger.warning(f"Invalid file URI: {file_uri}")
|
||||
continue
|
||||
file_id = match.group(1)
|
||||
file_names.append(file_id)
|
||||
logger.info(f"Found file reference: {file_id}")
|
||||
return file_names
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
# Gemini API不支持的JSON Schema字段
|
||||
unsupported_fields = {
|
||||
"exclusiveMaximum",
|
||||
"exclusiveMinimum",
|
||||
"const",
|
||||
"examples",
|
||||
"contentEncoding",
|
||||
"contentMediaType",
|
||||
"if",
|
||||
"then",
|
||||
"else",
|
||||
"allOf",
|
||||
"anyOf",
|
||||
"oneOf",
|
||||
"not",
|
||||
"definitions",
|
||||
"$schema",
|
||||
"$id",
|
||||
"$ref",
|
||||
"$comment",
|
||||
"readOnly",
|
||||
"writeOnly",
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in obj.items():
|
||||
if key in unsupported_fields:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = _clean_json_schema_properties(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
|
||||
|
||||
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""检查内容中是否包含 functionCall"""
|
||||
if not contents or not isinstance(contents, list):
|
||||
return False
|
||||
for content in contents:
|
||||
if not content or not isinstance(content, dict) or "parts" not in content:
|
||||
continue
|
||||
parts = content.get("parts", [])
|
||||
if not parts or not isinstance(parts, list):
|
||||
continue
|
||||
for part in parts:
|
||||
if isinstance(part, dict) and "functionCall" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
record = dict()
|
||||
for item in tools:
|
||||
@@ -40,12 +126,28 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
functions.extend(v)
|
||||
# 清理每个函数声明中的不支持字段
|
||||
cleaned_functions = []
|
||||
for func in v:
|
||||
if isinstance(func, dict):
|
||||
cleaned_func = _clean_json_schema_properties(func)
|
||||
cleaned_functions.append(cleaned_func)
|
||||
else:
|
||||
cleaned_functions.append(func)
|
||||
functions.extend(cleaned_functions)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
return record
|
||||
|
||||
def _is_structured_output_request(payload: Dict[str, Any]) -> bool:
|
||||
"""检查请求是否要求结构化JSON输出"""
|
||||
try:
|
||||
generation_config = payload.get("generationConfig", {})
|
||||
return generation_config.get("responseMimeType") == "application/json"
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
tool = dict()
|
||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||
if payload.get("tools") and isinstance(payload.get("tools"), dict):
|
||||
@@ -54,23 +156,48 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
if items and isinstance(items, list):
|
||||
tool.update(_merge_tools(items))
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
# "Tool use with a response mime type: 'application/json' is unsupported"
|
||||
# Gemini API限制:不支持同时使用tools和结构化输出(response_mime_type='application/json')
|
||||
# 当请求指定了JSON响应格式时,跳过所有工具的添加以避免API错误
|
||||
has_structured_output = _is_structured_output_request(payload)
|
||||
if not has_structured_output:
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
if tool.get("functionDeclarations") or _has_function_call(
|
||||
payload.get("contents", [])
|
||||
):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
@@ -78,40 +205,97 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Filters out contents with empty or invalid parts."""
|
||||
if not contents:
|
||||
return []
|
||||
|
||||
filtered_contents = []
|
||||
for content in contents:
|
||||
if (
|
||||
not content
|
||||
or "parts" not in content
|
||||
or not isinstance(content.get("parts"), list)
|
||||
):
|
||||
continue
|
||||
|
||||
valid_parts = [
|
||||
part for part in content["parts"] if isinstance(part, dict) and part
|
||||
]
|
||||
|
||||
if valid_parts:
|
||||
new_content = content.copy()
|
||||
new_content["parts"] = valid_parts
|
||||
filtered_contents.append(new_content)
|
||||
|
||||
return filtered_contents
|
||||
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
if "maxOutputTokens" in request_dict["generationConfig"]:
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
# 检查是否为TTS模型
|
||||
is_tts_model = "tts" in model.lower()
|
||||
|
||||
if is_tts_model:
|
||||
# TTS模型使用简化的payload,不包含tools和safetySettings
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
else:
|
||||
# 非TTS模型使用完整的payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
"safetySettings": _get_safety_settings(model),
|
||||
"generationConfig": request_dict.get("generationConfig"),
|
||||
"systemInstruction": request_dict.get("systemInstruction"),
|
||||
}
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
|
||||
# 处理思考配置:优先使用客户端提供的配置,否则使用默认配置
|
||||
client_thinking_config = None
|
||||
if request.generationConfig and request.generationConfig.thinkingConfig:
|
||||
client_thinking_config = request.generationConfig.thinkingConfig
|
||||
|
||||
|
||||
if client_thinking_config is not None:
|
||||
# 客户端提供了思考配置,直接使用
|
||||
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
|
||||
else:
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
if model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
if "gemini-2.5-pro" in model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000),
|
||||
"includeThoughts": True,
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000)
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
@@ -152,6 +336,21 @@ class GeminiChatService:
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(
|
||||
f"Found API key for file {file_names[0]}: {redact_key_for_logging(file_api_key)}"
|
||||
)
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(
|
||||
f"No API key found for file {file_names[0]}, using default key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
|
||||
payload = _build_payload(model, request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
@@ -166,13 +365,9 @@ class GeminiChatService:
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -180,7 +375,8 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
@@ -192,13 +388,74 @@ class GeminiChatService:
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def count_tokens(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""计算token数量"""
|
||||
# countTokens API只需要contents
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request.model_dump().get("contents", []))
|
||||
}
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.count_tokens(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Count tokens API call failed with error: {error_log_msg}")
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="gemini-count-tokens",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def stream_generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成内容"""
|
||||
# 檢查並獲取文件專用的 API key(如果有文件)
|
||||
file_names = _extract_file_references(request.model_dump().get("contents", []))
|
||||
if file_names:
|
||||
logger.info(f"Request contains file references: {file_names}")
|
||||
file_api_key = await get_file_api_key(file_names[0])
|
||||
if file_api_key:
|
||||
logger.info(
|
||||
f"Found API key for file {file_names[0]}: {redact_key_for_logging(file_api_key)}"
|
||||
)
|
||||
api_key = file_api_key # 使用文件的 API key
|
||||
else:
|
||||
logger.warning(
|
||||
f"No API key found for file {file_names[0]}, using default key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
|
||||
retries = 0
|
||||
max_retries = settings.MAX_RETRIES
|
||||
payload = _build_payload(model, request)
|
||||
@@ -243,15 +500,11 @@ class GeminiChatService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
@@ -259,21 +512,26 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
logger.info(
|
||||
f"Switched to new API key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
raise
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
)
|
||||
break
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -283,5 +541,5 @@ class GeminiChatService:
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
@@ -26,16 +25,59 @@ from app.service.key.key_manager import KeyManager
|
||||
logger = get_openai_logger()
|
||||
|
||||
|
||||
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含图片、音频或视频部分 (inline_data)"""
|
||||
for content in contents:
|
||||
if content and "parts" in content and isinstance(content["parts"], list):
|
||||
for part in content["parts"]:
|
||||
if isinstance(part, dict) and "inline_data" in part:
|
||||
def _has_media_parts(messages: List[Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否包含多媒体部分"""
|
||||
for message in messages:
|
||||
if "parts" in message:
|
||||
for part in message["parts"]:
|
||||
if "image_url" in part or "inline_data" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
# Gemini API不支持的JSON Schema字段
|
||||
unsupported_fields = {
|
||||
"exclusiveMaximum",
|
||||
"exclusiveMinimum",
|
||||
"const",
|
||||
"examples",
|
||||
"contentEncoding",
|
||||
"contentMediaType",
|
||||
"if",
|
||||
"then",
|
||||
"else",
|
||||
"allOf",
|
||||
"anyOf",
|
||||
"oneOf",
|
||||
"not",
|
||||
"definitions",
|
||||
"$schema",
|
||||
"$id",
|
||||
"$ref",
|
||||
"$comment",
|
||||
"readOnly",
|
||||
"writeOnly",
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in obj.items():
|
||||
if key in unsupported_fields:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = _clean_json_schema_properties(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _build_tools(
|
||||
request: ChatRequest, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -61,6 +103,10 @@ def _build_tools(
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 将 request 中的 tools 合并到 tools 中
|
||||
if request.tools:
|
||||
function_declarations = []
|
||||
@@ -76,6 +122,8 @@ def _build_tools(
|
||||
):
|
||||
function.pop("parameters", None)
|
||||
|
||||
# 清理函数中的不支持字段
|
||||
function = _clean_json_schema_properties(function)
|
||||
function_declarations.append(function)
|
||||
|
||||
if function_declarations:
|
||||
@@ -83,7 +131,7 @@ def _build_tools(
|
||||
names, functions = set(), []
|
||||
for fc in function_declarations:
|
||||
if fc.get("name") not in names:
|
||||
if fc.get("name")=="googleSearch":
|
||||
if fc.get("name") == "googleSearch":
|
||||
# cherry开启内置搜索时,添加googleSearch工具
|
||||
tool["googleSearch"] = {}
|
||||
else:
|
||||
@@ -97,10 +145,23 @@ def _build_tools(
|
||||
if tool.get("functionDeclarations"):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
# if (
|
||||
@@ -113,6 +174,23 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
return settings.SAFETY_SETTINGS
|
||||
|
||||
|
||||
def _validate_and_set_max_tokens(
|
||||
payload: Dict[str, Any], max_tokens: Optional[int], logger_instance
|
||||
) -> None:
|
||||
"""验证并设置 max_tokens 参数"""
|
||||
if max_tokens is None:
|
||||
return
|
||||
|
||||
# 参数验证和处理
|
||||
if max_tokens <= 0:
|
||||
logger_instance.warning(
|
||||
f"Invalid max_tokens value: {max_tokens}, will not set maxOutputTokens"
|
||||
)
|
||||
# 不设置 maxOutputTokens,让 Gemini API 使用默认值
|
||||
else:
|
||||
payload["generationConfig"]["maxOutputTokens"] = max_tokens
|
||||
|
||||
|
||||
def _build_payload(
|
||||
request: ChatRequest,
|
||||
messages: List[Dict[str, Any]],
|
||||
@@ -130,16 +208,33 @@ def _build_payload(
|
||||
"tools": _build_tools(request, messages),
|
||||
"safetySettings": _get_safety_settings(request.model),
|
||||
}
|
||||
if request.max_tokens is not None:
|
||||
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
|
||||
|
||||
# 处理 max_tokens 参数
|
||||
_validate_and_set_max_tokens(payload, request.max_tokens, logger)
|
||||
|
||||
# 处理 n 参数
|
||||
if request.n is not None and request.n > 0:
|
||||
payload["generationConfig"]["candidateCount"] = request.n
|
||||
|
||||
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if request.model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if request.model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
|
||||
}
|
||||
if "gemini-2.5-pro" in request.model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
|
||||
elif _get_real_model(request.model) in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000),
|
||||
"includeThoughts": True,
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
|
||||
}
|
||||
|
||||
if (
|
||||
instruction
|
||||
@@ -189,7 +284,9 @@ class OpenAIChatService:
|
||||
api_key: str,
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
"""创建聊天完成"""
|
||||
messages, instruction = self.message_converter.convert(request.messages)
|
||||
messages, instruction = self.message_converter.convert(
|
||||
request.messages, request.model
|
||||
)
|
||||
|
||||
payload = _build_payload(request, messages, instruction)
|
||||
|
||||
@@ -206,27 +303,55 @@ class OpenAIChatService:
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
usage_metadata = response.get("usageMetadata", {})
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
|
||||
# 尝试处理响应,捕获可能的响应处理异常
|
||||
try:
|
||||
result = self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=False,
|
||||
finish_reason="stop",
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
return result
|
||||
except Exception as response_error:
|
||||
logger.error(
|
||||
f"Response processing failed for model {model}: {str(response_error)}"
|
||||
)
|
||||
|
||||
# 记录详细的错误信息
|
||||
if "parts" in str(response_error):
|
||||
logger.error("Response structure issue - missing or invalid parts")
|
||||
if response.get("candidates"):
|
||||
candidate = response["candidates"][0]
|
||||
content = candidate.get("content", {})
|
||||
logger.error(f"Content structure: {content}")
|
||||
|
||||
# 重新抛出异常
|
||||
raise response_error
|
||||
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"API call failed for model {model}: {error_log_msg}")
|
||||
|
||||
# 特别记录 max_tokens 相关的错误
|
||||
gen_config = payload.get("generationConfig", {})
|
||||
if "maxOutputTokens" in gen_config:
|
||||
logger.error(
|
||||
f"Request had maxOutputTokens: {gen_config['maxOutputTokens']}"
|
||||
)
|
||||
|
||||
# 如果是响应处理错误,记录更多信息
|
||||
if "parts" in error_log_msg:
|
||||
logger.error("This is likely a response processing error")
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -234,12 +359,17 @@ class OpenAIChatService:
|
||||
error_type="openai-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
logger.info(
|
||||
f"Normal completion finished - Success: {is_success}, Latency: {latency_ms}ms"
|
||||
)
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
@@ -256,49 +386,44 @@ class OpenAIChatService:
|
||||
logger.info(
|
||||
f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
|
||||
)
|
||||
keep_sending_empty_data = True
|
||||
|
||||
async def send_empty_data_locally() -> AsyncGenerator[str, None]:
|
||||
"""定期发送空数据以保持连接"""
|
||||
while keep_sending_empty_data:
|
||||
await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS)
|
||||
if keep_sending_empty_data:
|
||||
empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
||||
yield f"data: {json.dumps(empty_chunk)}\n\n"
|
||||
logger.debug("Sent empty data chunk for fake stream heartbeat.")
|
||||
|
||||
empty_data_generator = send_empty_data_locally()
|
||||
api_response_task = asyncio.create_task(
|
||||
self.api_client.generate_content(payload, model, api_key)
|
||||
)
|
||||
|
||||
i = 0
|
||||
try:
|
||||
while not api_response_task.done():
|
||||
try:
|
||||
next_empty_chunk = await asyncio.wait_for(
|
||||
empty_data_generator.__anext__(), timeout=0.1
|
||||
i = i + 1
|
||||
"""定期发送空数据以保持连接"""
|
||||
if i >= settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS:
|
||||
i = 0
|
||||
empty_chunk = self.response_handler.handle_response(
|
||||
{},
|
||||
model,
|
||||
stream=True,
|
||||
finish_reason="stop",
|
||||
usage_metadata=None,
|
||||
)
|
||||
yield next_empty_chunk
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except (
|
||||
StopAsyncIteration
|
||||
):
|
||||
break
|
||||
|
||||
response = await api_response_task
|
||||
yield f"data: {json.dumps(empty_chunk)}\n\n"
|
||||
logger.debug("Sent empty data chunk for fake stream heartbeat.")
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
keep_sending_empty_data = False
|
||||
response = await api_response_task
|
||||
|
||||
if response and response.get("candidates"):
|
||||
response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {}))
|
||||
response = self.response_handler.handle_response(
|
||||
response,
|
||||
model,
|
||||
stream=True,
|
||||
finish_reason="stop",
|
||||
usage_metadata=response.get("usageMetadata", {}),
|
||||
)
|
||||
yield f"data: {json.dumps(response)}\n\n"
|
||||
logger.info(f"Sent full response content for fake stream: {model}")
|
||||
else:
|
||||
error_message = "Failed to get response from model"
|
||||
if (
|
||||
response and isinstance(response, dict) and response.get("error")
|
||||
):
|
||||
if response and isinstance(response, dict) and response.get("error"):
|
||||
error_details = response.get("error")
|
||||
if isinstance(error_details, dict):
|
||||
error_message = error_details.get("message", error_message)
|
||||
@@ -306,7 +431,9 @@ class OpenAIChatService:
|
||||
logger.error(
|
||||
f"No candidates or error in response for fake stream model {model}: {response}"
|
||||
)
|
||||
error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
||||
error_chunk = self.response_handler.handle_response(
|
||||
{}, model, stream=True, finish_reason="stop", usage_metadata=None
|
||||
)
|
||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
||||
|
||||
async def _real_stream_logic_impl(
|
||||
@@ -334,7 +461,11 @@ class OpenAIChatService:
|
||||
)
|
||||
continue
|
||||
openai_chunk = self.response_handler.handle_response(
|
||||
chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata
|
||||
chunk,
|
||||
model,
|
||||
stream=True,
|
||||
finish_reason=None,
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
if openai_chunk:
|
||||
text = self._extract_text_from_openai_chunk(openai_chunk)
|
||||
@@ -348,7 +479,9 @@ class OpenAIChatService:
|
||||
):
|
||||
yield optimized_chunk_data
|
||||
else:
|
||||
if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"):
|
||||
if openai_chunk.get("choices") and openai_chunk["choices"][
|
||||
0
|
||||
].get("delta", {}).get("tool_calls"):
|
||||
tool_call_flag = True
|
||||
|
||||
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||
@@ -404,27 +537,22 @@ class OpenAIChatService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
|
||||
)
|
||||
|
||||
match = re.search(r"status code (\\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
status_code = 408
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
model_name=model,
|
||||
error_type="openai-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
if self.key_manager:
|
||||
@@ -440,7 +568,7 @@ class OpenAIChatService:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries, ceasing attempts for this request."
|
||||
)
|
||||
break
|
||||
raise
|
||||
else:
|
||||
logger.error(
|
||||
"KeyManager not available, cannot switch API key. Ceasing attempts for this request."
|
||||
@@ -451,6 +579,7 @@ class OpenAIChatService:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming model {model}."
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -463,13 +592,6 @@ class OpenAIChatService:
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
if not is_success:
|
||||
logger.error(
|
||||
f"Streaming failed permanently for model {model} after {retries} attempts."
|
||||
)
|
||||
yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def create_image_chat_completion(
|
||||
self, request: ChatRequest, api_key: str
|
||||
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
||||
@@ -528,19 +650,23 @@ class OpenAIChatService:
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Stream image completion failed for model {model}: {e}"
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
request_msg=(
|
||||
{"image_data_truncated": image_data[:1000]}
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -578,18 +704,23 @@ class OpenAIChatService:
|
||||
return result
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = f"Normal image completion failed for model {model}: {e}"
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(error_log_msg)
|
||||
status_code = 500
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="openai-image-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg={"image_data_truncated": image_data[:1000]},
|
||||
request_msg=(
|
||||
{"image_data_truncated": image_data[:1000]}
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
# app/services/chat_service.py
|
||||
|
||||
import json
|
||||
import re
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.handler.stream_optimizer import gemini_optimizer
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
@@ -28,9 +29,67 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _clean_json_schema_properties(obj: Any) -> Any:
|
||||
"""清理JSON Schema中Gemini API不支持的字段"""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
# Gemini API不支持的JSON Schema字段
|
||||
unsupported_fields = {
|
||||
"exclusiveMaximum",
|
||||
"exclusiveMinimum",
|
||||
"const",
|
||||
"examples",
|
||||
"contentEncoding",
|
||||
"contentMediaType",
|
||||
"if",
|
||||
"then",
|
||||
"else",
|
||||
"allOf",
|
||||
"anyOf",
|
||||
"oneOf",
|
||||
"not",
|
||||
"definitions",
|
||||
"$schema",
|
||||
"$id",
|
||||
"$ref",
|
||||
"$comment",
|
||||
"readOnly",
|
||||
"writeOnly",
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in obj.items():
|
||||
if key in unsupported_fields:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
cleaned[key] = _clean_json_schema_properties(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""构建工具"""
|
||||
|
||||
|
||||
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
|
||||
"""检查内容中是否包含 functionCall"""
|
||||
if not contents or not isinstance(contents, list):
|
||||
return False
|
||||
for content in contents:
|
||||
if not content or not isinstance(content, dict) or "parts" not in content:
|
||||
continue
|
||||
parts = content.get("parts", [])
|
||||
if not parts or not isinstance(parts, list):
|
||||
continue
|
||||
for part in parts:
|
||||
if isinstance(part, dict) and "functionCall" in part:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
record = dict()
|
||||
for item in tools:
|
||||
@@ -40,12 +99,28 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
for k, v in item.items():
|
||||
if k == "functionDeclarations" and v and isinstance(v, list):
|
||||
functions = record.get("functionDeclarations", [])
|
||||
functions.extend(v)
|
||||
# 清理每个函数声明中的不支持字段
|
||||
cleaned_functions = []
|
||||
for func in v:
|
||||
if isinstance(func, dict):
|
||||
cleaned_func = _clean_json_schema_properties(func)
|
||||
cleaned_functions.append(cleaned_func)
|
||||
else:
|
||||
cleaned_functions.append(func)
|
||||
functions.extend(cleaned_functions)
|
||||
record["functionDeclarations"] = functions
|
||||
else:
|
||||
record[k] = v
|
||||
return record
|
||||
|
||||
def _is_structured_output_request(payload: Dict[str, Any]) -> bool:
|
||||
"""检查请求是否要求结构化JSON输出"""
|
||||
try:
|
||||
generation_config = payload.get("generationConfig", {})
|
||||
return generation_config.get("responseMimeType") == "application/json"
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
tool = dict()
|
||||
if payload and isinstance(payload, dict) and "tools" in payload:
|
||||
if payload.get("tools") and isinstance(payload.get("tools"), dict):
|
||||
@@ -54,23 +129,48 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
if items and isinstance(items, list):
|
||||
tool.update(_merge_tools(items))
|
||||
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
# "Tool use with a response mime type: 'application/json' is unsupported"
|
||||
# Gemini API限制:不支持同时使用tools和结构化输出(response_mime_type='application/json')
|
||||
# 当请求指定了JSON响应格式时,跳过所有工具的添加以避免API错误
|
||||
has_structured_output = _is_structured_output_request(payload)
|
||||
if not has_structured_output:
|
||||
if (
|
||||
settings.TOOLS_CODE_EXECUTION_ENABLED
|
||||
and not (model.endswith("-search") or "-thinking" in model)
|
||||
and not _has_image_parts(payload.get("contents", []))
|
||||
):
|
||||
tool["codeExecution"] = {}
|
||||
|
||||
if model.endswith("-search"):
|
||||
tool["googleSearch"] = {}
|
||||
|
||||
real_model = _get_real_model(model)
|
||||
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
|
||||
tool["urlContext"] = {}
|
||||
|
||||
# 解决 "Tool use with function calling is unsupported" 问题
|
||||
if tool.get("functionDeclarations"):
|
||||
if tool.get("functionDeclarations") or _has_function_call(
|
||||
payload.get("contents", [])
|
||||
):
|
||||
tool.pop("googleSearch", None)
|
||||
tool.pop("codeExecution", None)
|
||||
tool.pop("urlContext", None)
|
||||
|
||||
return [tool] if tool else []
|
||||
|
||||
|
||||
def _get_real_model(model: str) -> str:
|
||||
if model.endswith("-search"):
|
||||
model = model[:-7]
|
||||
if model.endswith("-image"):
|
||||
model = model[:-6]
|
||||
if model.endswith("-non-thinking"):
|
||||
model = model[:-13]
|
||||
if "-search" in model and "-non-thinking" in model:
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
|
||||
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
"""获取安全设置"""
|
||||
if model == "gemini-2.0-flash-exp":
|
||||
@@ -80,12 +180,12 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
|
||||
|
||||
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
"""构建请求payload"""
|
||||
request_dict = request.model_dump()
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
if request.generationConfig:
|
||||
if request.generationConfig.maxOutputTokens is None:
|
||||
# 如果未指定最大输出长度,则不传递该字段,解决截断的问题
|
||||
request_dict["generationConfig"].pop("maxOutputTokens")
|
||||
|
||||
|
||||
payload = {
|
||||
"contents": request_dict.get("contents", []),
|
||||
"tools": _build_tools(model, request_dict),
|
||||
@@ -97,11 +197,32 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
|
||||
if model.endswith("-image") or model.endswith("-image-generation"):
|
||||
payload.pop("systemInstruction")
|
||||
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
|
||||
|
||||
if model.endswith("-non-thinking"):
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
if model in settings.THINKING_BUDGET_MAP:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
|
||||
|
||||
# 处理思考配置:优先使用客户端提供的配置,否则使用默认配置
|
||||
client_thinking_config = None
|
||||
if request.generationConfig and request.generationConfig.thinkingConfig:
|
||||
client_thinking_config = request.generationConfig.thinkingConfig
|
||||
|
||||
if client_thinking_config is not None:
|
||||
# 客户端提供了思考配置,直接使用
|
||||
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
|
||||
else:
|
||||
# 客户端没有提供思考配置,使用默认配置
|
||||
if model.endswith("-non-thinking"):
|
||||
if "gemini-2.5-pro" in model:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
|
||||
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
|
||||
if settings.SHOW_THINKING_PROCESS:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000),
|
||||
"includeThoughts": True,
|
||||
}
|
||||
else:
|
||||
payload["generationConfig"]["thinkingConfig"] = {
|
||||
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000)
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
@@ -156,13 +277,9 @@ class GeminiChatService:
|
||||
return self.response_handler.handle_response(response, model, stream=False)
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -170,7 +287,8 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
@@ -182,7 +300,7 @@ class GeminiChatService:
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def stream_generate_content(
|
||||
@@ -200,7 +318,7 @@ class GeminiChatService:
|
||||
request_datetime = datetime.datetime.now()
|
||||
start_time = time.perf_counter()
|
||||
current_attempt_key = api_key
|
||||
final_api_key = current_attempt_key # Update final key used
|
||||
final_api_key = current_attempt_key # Update final key used
|
||||
try:
|
||||
async for line in self.api_client.stream_generate_content(
|
||||
payload, model, current_attempt_key
|
||||
@@ -233,15 +351,11 @@ class GeminiChatService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
@@ -249,21 +363,26 @@ class GeminiChatService:
|
||||
error_type="gemini-chat-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
|
||||
api_key = await self.key_manager.handle_api_failure(
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
logger.info(
|
||||
f"Switched to new API key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"No valid API key available after {retries} retries.")
|
||||
break
|
||||
raise
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(
|
||||
f"Max retries ({max_retries}) reached for streaming."
|
||||
)
|
||||
break
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -273,5 +392,5 @@ class GeminiChatService:
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
@@ -1,24 +1,31 @@
|
||||
# app/services/chat/api_client.py
|
||||
|
||||
from typing import Dict, Any, AsyncGenerator, Optional
|
||||
import httpx
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_api_client_logger
|
||||
from app.core.constants import DEFAULT_TIMEOUT
|
||||
from app.log.logger import get_api_client_logger
|
||||
|
||||
logger = get_api_client_logger()
|
||||
|
||||
|
||||
class ApiClient(ABC):
|
||||
"""API客户端基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
async def generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
async def stream_generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -40,10 +47,17 @@ class GeminiApiClient(ApiClient):
|
||||
model = model[:-20]
|
||||
return model
|
||||
|
||||
def _prepare_headers(self) -> Dict[str, str]:
|
||||
headers = {}
|
||||
if settings.CUSTOM_HEADERS:
|
||||
headers.update(settings.CUSTOM_HEADERS)
|
||||
logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
|
||||
return headers
|
||||
|
||||
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取可用的 Gemini 模型列表"""
|
||||
timeout = httpx.Timeout(timeout=5)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -52,10 +66,11 @@ class GeminiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models?key={api_key}&pageSize=1000"
|
||||
try:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -65,8 +80,10 @@ class GeminiApiClient(ApiClient):
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求模型列表失败: {e}")
|
||||
return None
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
|
||||
|
||||
async def generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
@@ -77,19 +94,33 @@ class GeminiApiClient(ApiClient):
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
|
||||
headers = self._prepare_headers()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
|
||||
response = await client.post(url, json=payload)
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
logger.error(
|
||||
f"API call failed - Status: {response.status_code}, Content: {error_content}"
|
||||
)
|
||||
raise Exception(response.status_code, error_content)
|
||||
response_data = response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
|
||||
# 检查响应结构的基本信息
|
||||
if not response_data.get("candidates"):
|
||||
logger.warning("No candidates found in API response")
|
||||
|
||||
return response_data
|
||||
|
||||
async def stream_generate_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -98,16 +129,96 @@ class GeminiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
|
||||
async with client.stream(method="POST", url=url, json=payload) as response:
|
||||
async with client.stream(
|
||||
method="POST", url=url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
raise Exception(response.status_code, error_msg)
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def count_tokens(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for counting tokens: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:countTokens?key={api_key}"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def embed_content(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""单一嵌入内容生成"""
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for embedding: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:embedContent?key={api_key}"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(
|
||||
f"Embedding API call failed - Status: {response.status_code}, Content: {error_content}"
|
||||
)
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def batch_embed_contents(
|
||||
self, payload: Dict[str, Any], model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""批量嵌入内容生成"""
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
model = self._get_real_model(model)
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
|
||||
else:
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for batch embedding: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers()
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/models/{model}:batchEmbedContents?key={api_key}"
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
logger.error(
|
||||
f"Batch embedding API call failed - Status: {response.status_code}, Content: {error_content}"
|
||||
)
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
|
||||
class OpenaiApiClient(ApiClient):
|
||||
"""OpenAI API客户端"""
|
||||
@@ -115,7 +226,14 @@ class OpenaiApiClient(ApiClient):
|
||||
def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def _prepare_headers(self, api_key: str) -> Dict[str, str]:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if settings.CUSTOM_HEADERS:
|
||||
headers.update(settings.CUSTOM_HEADERS)
|
||||
logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
|
||||
return headers
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
@@ -127,18 +245,22 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/models"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
async def generate_content(
|
||||
self, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
logger.info(f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}")
|
||||
logger.info(
|
||||
f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}"
|
||||
)
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -147,16 +269,18 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
|
||||
async def stream_generate_content(
|
||||
self, payload: Dict[str, Any], api_key: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
@@ -166,20 +290,24 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/chat/completions"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
|
||||
async with client.stream(
|
||||
method="POST", url=url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
error_msg = error_content.decode("utf-8")
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
|
||||
raise Exception(response.status_code, error_msg)
|
||||
async for line in response.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
|
||||
|
||||
async def create_embeddings(
|
||||
self, input: str, model: str, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
|
||||
proxy_to_use = None
|
||||
if settings.PROXIES:
|
||||
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
|
||||
@@ -188,9 +316,9 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/embeddings"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
payload = {
|
||||
"input": input,
|
||||
"model": model,
|
||||
@@ -198,10 +326,12 @@ class OpenaiApiClient(ApiClient):
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
|
||||
|
||||
async def generate_images(
|
||||
self, payload: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
timeout = httpx.Timeout(self.timeout, read=self.timeout)
|
||||
|
||||
proxy_to_use = None
|
||||
@@ -212,11 +342,11 @@ class OpenaiApiClient(ApiClient):
|
||||
proxy_to_use = random.choice(settings.PROXIES)
|
||||
logger.info(f"Using proxy for getting models: {proxy_to_use}")
|
||||
|
||||
headers = self._prepare_headers(api_key)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
|
||||
url = f"{self.base_url}/openai/images/generations"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
error_content = response.text
|
||||
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
|
||||
return response.json()
|
||||
raise Exception(response.status_code, error_content)
|
||||
return response.json()
|
||||
|
||||
@@ -230,7 +230,7 @@ class ConfigService:
|
||||
key_manager = await get_key_manager_instance()
|
||||
model_service = ModelService()
|
||||
|
||||
api_key = await key_manager.get_first_valid_key()
|
||||
api_key = await key_manager.get_random_valid_key()
|
||||
if not api_key:
|
||||
logger.error("No valid API keys available to fetch model list for UI.")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
import time
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
import openai
|
||||
@@ -8,8 +7,8 @@ from openai import APIStatusError
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_embeddings_logger
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.log.logger import get_embeddings_logger
|
||||
|
||||
logger = get_embeddings_logger()
|
||||
|
||||
@@ -27,12 +26,20 @@ class EmbeddingService:
|
||||
response = None
|
||||
error_log_msg = ""
|
||||
if isinstance(input_text, list):
|
||||
request_msg_log = {"input_truncated": [str(item)[:100] + "..." if len(str(item)) > 100 else str(item) for item in input_text[:5]]}
|
||||
request_msg_log = {
|
||||
"input_truncated": [
|
||||
str(item)[:100] + "..." if len(str(item)) > 100 else str(item)
|
||||
for item in input_text[:5]
|
||||
]
|
||||
}
|
||||
if len(input_text) > 5:
|
||||
request_msg_log["input_truncated"].append("...")
|
||||
request_msg_log["input_truncated"].append("...")
|
||||
else:
|
||||
request_msg_log = {"input_truncated": input_text[:1000] + "..." if len(input_text) > 1000 else input_text}
|
||||
|
||||
request_msg_log = {
|
||||
"input_truncated": (
|
||||
input_text[:1000] + "..." if len(input_text) > 1000 else input_text
|
||||
)
|
||||
}
|
||||
|
||||
try:
|
||||
client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
|
||||
@@ -48,13 +55,9 @@ class EmbeddingService:
|
||||
raise e
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
status_code = 500
|
||||
error_log_msg = f"Generic error: {e}"
|
||||
logger.error(f"Error creating embedding (Exception): {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", str(e))
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
@@ -66,13 +69,18 @@ class EmbeddingService:
|
||||
error_type="openai-embedding",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request_msg_log
|
||||
)
|
||||
request_msg=(
|
||||
request_msg_log
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
141
app/service/embedding/gemini_embedding_service.py
Normal file
141
app/service/embedding/gemini_embedding_service.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# app/service/embedding/gemini_embedding_service.py
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.gemini_models import GeminiBatchEmbedRequest, GeminiEmbedRequest
|
||||
from app.log.logger import get_gemini_embedding_logger
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
|
||||
logger = get_gemini_embedding_logger()
|
||||
|
||||
|
||||
def _build_embed_payload(request: GeminiEmbedRequest) -> Dict[str, Any]:
|
||||
"""构建嵌入请求payload"""
|
||||
payload = {"content": request.content.model_dump()}
|
||||
|
||||
if request.taskType:
|
||||
payload["taskType"] = request.taskType
|
||||
if request.title:
|
||||
payload["title"] = request.title
|
||||
if request.outputDimensionality:
|
||||
payload["outputDimensionality"] = request.outputDimensionality
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _build_batch_embed_payload(
|
||||
request: GeminiBatchEmbedRequest, model: str
|
||||
) -> Dict[str, Any]:
|
||||
"""构建批量嵌入请求payload"""
|
||||
requests = []
|
||||
for embed_request in request.requests:
|
||||
embed_payload = _build_embed_payload(embed_request)
|
||||
embed_payload["model"] = (
|
||||
f"models/{model}" # Gemini API要求每个请求包含model字段
|
||||
)
|
||||
requests.append(embed_payload)
|
||||
|
||||
return {"requests": requests}
|
||||
|
||||
|
||||
class GeminiEmbeddingService:
|
||||
"""Gemini嵌入服务"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager):
|
||||
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
|
||||
self.key_manager = key_manager
|
||||
|
||||
async def embed_content(
|
||||
self, model: str, request: GeminiEmbedRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成单一嵌入内容"""
|
||||
payload = _build_embed_payload(request)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.embed_content(payload, model, api_key)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Single embedding API call failed: {error_log_msg}")
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="gemini-embed-single",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
async def batch_embed_contents(
|
||||
self, model: str, request: GeminiBatchEmbedRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""生成批量嵌入内容"""
|
||||
payload = _build_batch_embed_payload(request, model)
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = await self.api_client.batch_embed_contents(
|
||||
payload, model, api_key
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Batch embedding API call failed: {error_log_msg}")
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="gemini-embed-batch",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
@@ -28,7 +28,7 @@ async def delete_old_error_logs():
|
||||
)
|
||||
return
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
logger.info(
|
||||
f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
|
||||
@@ -121,6 +121,30 @@ async def process_get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]
|
||||
raise
|
||||
|
||||
|
||||
async def process_find_error_log_by_info(
|
||||
gemini_key: str,
|
||||
timestamp: datetime,
|
||||
status_code: Optional[int] = None,
|
||||
window_seconds: int = 100,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据 key/状态码/时间窗口 查询最匹配的一条错误日志,未找到则返回 None。
|
||||
"""
|
||||
try:
|
||||
return await db_services.find_error_log_by_info(
|
||||
gemini_key=gemini_key,
|
||||
timestamp=timestamp,
|
||||
status_code=status_code,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Service error in process_find_error_log_by_info: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_delete_error_logs_by_ids(log_ids: List[int]) -> int:
|
||||
"""
|
||||
按 ID 批量删除错误日志。
|
||||
|
||||
1
app/service/files/__init__.py
Normal file
1
app/service/files/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Intentionally empty __init__.py file
|
||||
248
app/service/files/file_upload_handler.py
Normal file
248
app/service/files/file_upload_handler.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
文件上传处理器
|
||||
处理 Google 的可恢复上传协议
|
||||
"""
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import Request, Response, HTTPException
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.log.logger import get_files_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
|
||||
class FileUploadHandler:
|
||||
"""处理文件分块上传"""
|
||||
|
||||
def __init__(self):
|
||||
self.chunk_size = 8 * 1024 * 1024 # 8MB
|
||||
|
||||
async def handle_upload_chunk(
|
||||
self,
|
||||
upload_url: str,
|
||||
request: Request,
|
||||
files_service=None # 添加 files_service 參數
|
||||
) -> Response:
|
||||
"""
|
||||
处理上传分块
|
||||
|
||||
Args:
|
||||
upload_url: 上传 URL
|
||||
request: FastAPI 请求对象
|
||||
files_service: 文件服務實例
|
||||
|
||||
Returns:
|
||||
Response: 响应对象
|
||||
"""
|
||||
try:
|
||||
# 获取请求头
|
||||
headers = {}
|
||||
|
||||
# 复制必要的上传头
|
||||
upload_headers = [
|
||||
"x-goog-upload-command",
|
||||
"x-goog-upload-offset",
|
||||
"content-type",
|
||||
"content-length"
|
||||
]
|
||||
|
||||
for header in upload_headers:
|
||||
if header in request.headers:
|
||||
# 转换为正确的格式
|
||||
key = "-".join(word.capitalize() for word in header.split("-"))
|
||||
headers[key] = request.headers[header]
|
||||
|
||||
# 读取请求体
|
||||
body = await request.body()
|
||||
|
||||
# 检查是否是最后一块
|
||||
is_final = "finalize" in headers.get("X-Goog-Upload-Command", "")
|
||||
logger.debug(f"Upload command: {headers.get('X-Goog-Upload-Command', '')}, is_final: {is_final}")
|
||||
|
||||
# 转发到真实的上传 URL
|
||||
async with AsyncClient() as client:
|
||||
response = await client.post(
|
||||
upload_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
timeout=300.0 # 5分钟超时
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201, 308]:
|
||||
logger.error(f"Upload chunk failed: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Upload failed")
|
||||
|
||||
# 如果是最后一块,更新文件状态
|
||||
if is_final and response.status_code in [200, 201]:
|
||||
logger.debug(f"Upload finalized with status {response.status_code}")
|
||||
try:
|
||||
# 解析響應獲取文件信息
|
||||
response_data = response.json()
|
||||
logger.debug(f"Upload complete response data: {response_data}")
|
||||
file_data = response_data.get("file", {})
|
||||
|
||||
# 獲取真實的文件名
|
||||
real_file_name = file_data.get("name")
|
||||
logger.debug(f"Upload response: {response_data}")
|
||||
if real_file_name and files_service:
|
||||
logger.info(f"Upload completed, file name: {real_file_name}")
|
||||
|
||||
# 從會話中獲取信息
|
||||
session_info = await files_service.get_upload_session(upload_url)
|
||||
logger.debug(f"Retrieved session info for {upload_url}: {session_info}")
|
||||
if session_info:
|
||||
# 創建文件記錄
|
||||
now = datetime.now(timezone.utc)
|
||||
expiration_time = now + timedelta(hours=48)
|
||||
|
||||
# 處理過期時間格式(Google 可能返回納秒級精度)
|
||||
expiration_time_str = file_data.get("expirationTime", expiration_time.isoformat() + "Z")
|
||||
# 處理納秒格式:2025-07-11T02:02:52.531916141Z -> 2025-07-11T02:02:52.531916Z
|
||||
if expiration_time_str.endswith("Z"):
|
||||
# 移除 Z
|
||||
expiration_time_str = expiration_time_str[:-1]
|
||||
# 如果有納秒(超過6位小數),截斷到微秒
|
||||
if "." in expiration_time_str:
|
||||
date_part, frac_part = expiration_time_str.rsplit(".", 1)
|
||||
if len(frac_part) > 6:
|
||||
frac_part = frac_part[:6]
|
||||
expiration_time_str = f"{date_part}.{frac_part}"
|
||||
# 添加時區
|
||||
expiration_time_str += "+00:00"
|
||||
|
||||
# 獲取文件狀態(Google 可能返回 PROCESSING)
|
||||
file_state = file_data.get("state", "PROCESSING")
|
||||
logger.debug(f"File state from Google: {file_state}")
|
||||
|
||||
# 將字符串狀態轉換為枚舉
|
||||
if file_state == "ACTIVE":
|
||||
state_enum = FileState.ACTIVE
|
||||
elif file_state == "PROCESSING":
|
||||
state_enum = FileState.PROCESSING
|
||||
elif file_state == "FAILED":
|
||||
state_enum = FileState.FAILED
|
||||
else:
|
||||
logger.warning(f"Unknown file state: {file_state}, defaulting to PROCESSING")
|
||||
state_enum = FileState.PROCESSING
|
||||
|
||||
await db_services.create_file_record(
|
||||
name=real_file_name,
|
||||
mime_type=file_data.get("mimeType", session_info["mime_type"]),
|
||||
size_bytes=int(file_data.get("sizeBytes", session_info["size_bytes"])),
|
||||
api_key=session_info["api_key"],
|
||||
uri=file_data.get("uri", f"{settings.BASE_URL}/{real_file_name}"),
|
||||
create_time=now,
|
||||
update_time=now,
|
||||
expiration_time=datetime.fromisoformat(expiration_time_str),
|
||||
state=state_enum,
|
||||
display_name=file_data.get("displayName", session_info.get("display_name", "")),
|
||||
sha256_hash=file_data.get("sha256Hash"),
|
||||
user_token=session_info["user_token"]
|
||||
)
|
||||
logger.info(f"Created file record: name={real_file_name}, api_key={redact_key_for_logging(session_info['api_key'])}")
|
||||
else:
|
||||
logger.warning(f"No upload session found for URL: {upload_url}")
|
||||
else:
|
||||
logger.warning(f"Missing real_file_name or files_service: real_file_name={real_file_name}, files_service={files_service}")
|
||||
|
||||
# 返回完整的文件信息
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create file record: {str(e)}", exc_info=True)
|
||||
else:
|
||||
logger.debug(f"Upload chunk processed: is_final={is_final}, status={response.status_code}")
|
||||
|
||||
# 返回响应
|
||||
response_headers = dict(response.headers)
|
||||
|
||||
# 确保包含必要的头
|
||||
if response.status_code == 308: # Resume Incomplete
|
||||
if "x-goog-upload-status" not in response_headers:
|
||||
response_headers["x-goog-upload-status"] = "active"
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle upload chunk: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def proxy_upload_request(
|
||||
self,
|
||||
request: Request,
|
||||
upload_url: str,
|
||||
files_service=None
|
||||
) -> Response:
|
||||
"""
|
||||
代理上传请求
|
||||
|
||||
Args:
|
||||
request: FastAPI 请求对象
|
||||
upload_url: 目标上传 URL
|
||||
files_service: 文件服務實例
|
||||
|
||||
Returns:
|
||||
Response: 代理响应
|
||||
"""
|
||||
logger.debug(f"Proxy upload request: {request.method}, {upload_url}")
|
||||
try:
|
||||
# 如果是 GET 请求,返回上传状态
|
||||
if request.method == "GET":
|
||||
return await self._get_upload_status(upload_url)
|
||||
|
||||
# 处理 POST/PUT 请求
|
||||
return await self.handle_upload_chunk(upload_url, request, files_service)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to proxy upload request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def _get_upload_status(self, upload_url: str) -> Response:
|
||||
"""
|
||||
获取上传状态
|
||||
|
||||
Args:
|
||||
upload_url: 上传 URL
|
||||
|
||||
Returns:
|
||||
Response: 状态响应
|
||||
"""
|
||||
try:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(upload_url)
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get upload status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
|
||||
# 单例实例
|
||||
_upload_handler_instance: Optional[FileUploadHandler] = None
|
||||
|
||||
|
||||
def get_upload_handler() -> FileUploadHandler:
|
||||
"""获取上传处理器单例实例"""
|
||||
global _upload_handler_instance
|
||||
if _upload_handler_instance is None:
|
||||
_upload_handler_instance = FileUploadHandler()
|
||||
return _upload_handler_instance
|
||||
499
app/service/files/files_service.py
Normal file
499
app/service/files/files_service.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
文件管理服务
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from httpx import AsyncClient
|
||||
import asyncio
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database import services as db_services
|
||||
from app.database.models import FileState
|
||||
from app.domain.file_models import FileMetadata, ListFilesResponse
|
||||
from fastapi import HTTPException
|
||||
from app.log.logger import get_files_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
from app.service.client.api_client import GeminiApiClient
|
||||
from app.service.key.key_manager import get_key_manager_instance
|
||||
|
||||
logger = get_files_logger()
|
||||
|
||||
# 全局上傳會話存儲
|
||||
_upload_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
_upload_sessions_lock = asyncio.Lock()
|
||||
|
||||
|
||||
class FilesService:
|
||||
"""文件管理服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_client = GeminiApiClient(base_url=settings.BASE_URL)
|
||||
self.key_manager = None
|
||||
|
||||
async def _get_key_manager(self):
|
||||
"""获取 KeyManager 实例"""
|
||||
if not self.key_manager:
|
||||
self.key_manager = await get_key_manager_instance(
|
||||
settings.API_KEYS,
|
||||
settings.VERTEX_API_KEYS
|
||||
)
|
||||
return self.key_manager
|
||||
|
||||
async def initialize_upload(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
body: Optional[bytes],
|
||||
user_token: str,
|
||||
request_host: str = None # 添加請求主機參數
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
初始化文件上传
|
||||
|
||||
Args:
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], Dict[str, str]]: (响应体, 响应头)
|
||||
"""
|
||||
try:
|
||||
# 获取可用的 API key
|
||||
key_manager = await self._get_key_manager()
|
||||
api_key = await key_manager.get_next_key()
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="No available API keys")
|
||||
|
||||
# 转发请求到真实的 Gemini API
|
||||
async with AsyncClient() as client:
|
||||
# 准备请求头
|
||||
forward_headers = {
|
||||
"X-Goog-Upload-Protocol": headers.get("x-goog-upload-protocol", "resumable"),
|
||||
"X-Goog-Upload-Command": headers.get("x-goog-upload-command", "start"),
|
||||
"Content-Type": headers.get("content-type", "application/json"),
|
||||
}
|
||||
|
||||
# 添加其他必要的头
|
||||
if "x-goog-upload-header-content-length" in headers:
|
||||
forward_headers["X-Goog-Upload-Header-Content-Length"] = headers["x-goog-upload-header-content-length"]
|
||||
if "x-goog-upload-header-content-type" in headers:
|
||||
forward_headers["X-Goog-Upload-Header-Content-Type"] = headers["x-goog-upload-header-content-type"]
|
||||
|
||||
# 发送请求
|
||||
response = await client.post(
|
||||
"https://generativelanguage.googleapis.com/upload/v1beta/files",
|
||||
headers=forward_headers,
|
||||
content=body,
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Upload initialization failed: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Upload initialization failed")
|
||||
|
||||
# 获取上传 URL
|
||||
upload_url = response.headers.get("x-goog-upload-url")
|
||||
if not upload_url:
|
||||
raise HTTPException(status_code=500, detail="No upload URL in response")
|
||||
|
||||
logger.info(f"Original upload URL from Google: {upload_url}")
|
||||
|
||||
|
||||
# 儲存上傳資訊到 headers 中,供後續使用
|
||||
# 不在這裡創建數據庫記錄,等到上傳完成後再創建
|
||||
logger.info(f"Upload initialized with API key: {redact_key_for_logging(api_key)}")
|
||||
|
||||
# 解析响应 - 初始化响应可能是空的
|
||||
response_data = {}
|
||||
|
||||
# 從請求體中解析文件信息(如果有)
|
||||
display_name = ""
|
||||
if body:
|
||||
try:
|
||||
request_data = json.loads(body)
|
||||
display_name = request_data.get("displayName", "")
|
||||
except Exception:
|
||||
pass
|
||||
# 從 upload URL 中提取 upload_id
|
||||
import urllib.parse
|
||||
parsed_url = urllib.parse.urlparse(upload_url)
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
upload_id = query_params.get('upload_id', [None])[0]
|
||||
|
||||
if upload_id:
|
||||
# 儲存上傳會話信息,使用 upload_id 作為 key
|
||||
async with _upload_sessions_lock:
|
||||
_upload_sessions[upload_id] = {
|
||||
"api_key": api_key,
|
||||
"user_token": user_token,
|
||||
"display_name": display_name,
|
||||
"mime_type": headers.get("x-goog-upload-header-content-type", "application/octet-stream"),
|
||||
"size_bytes": int(headers.get("x-goog-upload-header-content-length", "0")),
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"upload_url": upload_url
|
||||
}
|
||||
logger.info(f"Stored upload session for upload_id={upload_id}: api_key={redact_key_for_logging(api_key)}")
|
||||
logger.debug(f"Total active sessions: {len(_upload_sessions)}")
|
||||
else:
|
||||
logger.warning(f"No upload_id found in upload URL: {upload_url}")
|
||||
|
||||
# 定期清理過期的會話(超過1小時)
|
||||
asyncio.create_task(self._cleanup_expired_sessions())
|
||||
|
||||
# 替換 Google 的 URL 為我們的代理 URL
|
||||
proxy_upload_url = upload_url
|
||||
if request_host:
|
||||
# 原始: https://generativelanguage.googleapis.com/upload/v1beta/files?key=AIzaSyDc...&upload_id=xxx&upload_protocol=resumable
|
||||
# 替換為: http://request-host/upload/v1beta/files?key=sk-123456&upload_id=xxx&upload_protocol=resumable
|
||||
|
||||
# 先替換域名
|
||||
proxy_upload_url = upload_url.replace(
|
||||
"https://generativelanguage.googleapis.com",
|
||||
request_host.rstrip('/')
|
||||
)
|
||||
|
||||
# 再替換 key 參數
|
||||
import re
|
||||
# 匹配 key=xxx 參數
|
||||
key_pattern = r'(\?|&)key=([^&]+)'
|
||||
match = re.search(key_pattern, proxy_upload_url)
|
||||
if match:
|
||||
# 替換為我們的 token
|
||||
proxy_upload_url = proxy_upload_url.replace(
|
||||
f"{match.group(1)}key={match.group(2)}",
|
||||
f"{match.group(1)}key={user_token}"
|
||||
)
|
||||
|
||||
logger.info(f"Replaced upload URL: {upload_url} -> {proxy_upload_url}")
|
||||
|
||||
return response_data, {
|
||||
"X-Goog-Upload-URL": proxy_upload_url,
|
||||
"X-Goog-Upload-Status": "active"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize upload: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def _cleanup_expired_sessions(self):
|
||||
"""清理過期的上傳會話"""
|
||||
try:
|
||||
async with _upload_sessions_lock:
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_keys = []
|
||||
for key, session in _upload_sessions.items():
|
||||
if now - session["created_at"] > timedelta(hours=1):
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
del _upload_sessions[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up upload sessions: {str(e)}")
|
||||
|
||||
async def get_upload_session(self, key: str) -> Optional[Dict[str, Any]]:
|
||||
"""獲取上傳會話信息(支持 upload_id 或完整 URL)"""
|
||||
async with _upload_sessions_lock:
|
||||
# 先嘗試直接查找
|
||||
session = _upload_sessions.get(key)
|
||||
if session:
|
||||
logger.debug(f"Found session by direct key {redact_key_for_logging(key)}")
|
||||
return session
|
||||
|
||||
# 如果是 URL,嘗試提取 upload_id
|
||||
if key.startswith("http"):
|
||||
import urllib.parse
|
||||
parsed_url = urllib.parse.urlparse(key)
|
||||
query_params = urllib.parse.parse_qs(parsed_url.query)
|
||||
upload_id = query_params.get('upload_id', [None])[0]
|
||||
if upload_id:
|
||||
session = _upload_sessions.get(upload_id)
|
||||
if session:
|
||||
logger.debug(f"Found session by upload_id {upload_id} from URL")
|
||||
return session
|
||||
|
||||
logger.debug(f"No session found for key: {redact_key_for_logging(key)}")
|
||||
return None
|
||||
|
||||
async def get_file(self, file_name: str, user_token: str) -> FileMetadata:
|
||||
"""
|
||||
获取文件信息
|
||||
|
||||
Args:
|
||||
file_name: 文件名称 (格式: files/{file_id})
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
FileMetadata: 文件元数据
|
||||
"""
|
||||
try:
|
||||
# 查询文件记录
|
||||
file_record = await db_services.get_file_record_by_name(file_name)
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# 检查是否过期
|
||||
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
|
||||
# 如果是 naive datetime,假设为 UTC
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
|
||||
if expiration_time <= datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=404, detail="File has expired")
|
||||
|
||||
# 使用原始 API key 获取文件信息
|
||||
api_key = file_record["api_key"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get file: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to get file")
|
||||
|
||||
file_data = response.json()
|
||||
|
||||
# 檢查並更新文件狀態
|
||||
google_state = file_data.get("state", "PROCESSING")
|
||||
if google_state != file_record.get("state", "").value if file_record.get("state") else None:
|
||||
logger.info(f"File state changed from {file_record.get('state')} to {google_state}")
|
||||
# 更新數據庫中的狀態
|
||||
if google_state == "ACTIVE":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.ACTIVE,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
elif google_state == "FAILED":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.FAILED,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
return FileMetadata(
|
||||
name=file_data["name"],
|
||||
displayName=file_data.get("displayName"),
|
||||
mimeType=file_data["mimeType"],
|
||||
sizeBytes=str(file_data["sizeBytes"]),
|
||||
createTime=file_data["createTime"],
|
||||
updateTime=file_data["updateTime"],
|
||||
expirationTime=file_data["expirationTime"],
|
||||
sha256Hash=file_data.get("sha256Hash"),
|
||||
uri=file_data["uri"],
|
||||
state=google_state
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file {file_name}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
page_size: int = 10,
|
||||
page_token: Optional[str] = None,
|
||||
user_token: Optional[str] = None
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
列出文件
|
||||
|
||||
Args:
|
||||
page_size: 每页大小
|
||||
page_token: 分页标记
|
||||
user_token: 用户令牌(可选,如果提供则只返回该用户的文件)
|
||||
|
||||
Returns:
|
||||
ListFilesResponse: 文件列表响应
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"list_files called with page_size={page_size}, page_token={page_token}")
|
||||
|
||||
# 从数据库获取文件列表
|
||||
files, next_page_token = await db_services.list_file_records(
|
||||
user_token=user_token,
|
||||
page_size=page_size,
|
||||
page_token=page_token
|
||||
)
|
||||
|
||||
logger.debug(f"Database returned {len(files)} files, next_page_token={next_page_token}")
|
||||
|
||||
# 转换为响应格式
|
||||
file_list = []
|
||||
for file_record in files:
|
||||
file_list.append(FileMetadata(
|
||||
name=file_record["name"],
|
||||
displayName=file_record.get("display_name"),
|
||||
mimeType=file_record["mime_type"],
|
||||
sizeBytes=str(file_record["size_bytes"]),
|
||||
createTime=file_record["create_time"].isoformat() + "Z",
|
||||
updateTime=file_record["update_time"].isoformat() + "Z",
|
||||
expirationTime=file_record["expiration_time"].isoformat() + "Z",
|
||||
sha256Hash=file_record.get("sha256_hash"),
|
||||
uri=file_record["uri"],
|
||||
state=file_record["state"].value if file_record.get("state") else "ACTIVE"
|
||||
))
|
||||
|
||||
response = ListFilesResponse(
|
||||
files=file_list,
|
||||
nextPageToken=next_page_token
|
||||
)
|
||||
|
||||
logger.debug(f"Returning response with {len(response.files)} files, nextPageToken={response.nextPageToken}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list files: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def delete_file(self, file_name: str, user_token: str) -> bool:
|
||||
"""
|
||||
删除文件
|
||||
|
||||
Args:
|
||||
file_name: 文件名称
|
||||
user_token: 用户令牌
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
try:
|
||||
# 查询文件记录
|
||||
file_record = await db_services.get_file_record_by_name(file_name)
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# 使用原始 API key 删除文件
|
||||
api_key = file_record["api_key"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 204]:
|
||||
logger.error(f"Failed to delete file: {response.status_code} - {response.text}")
|
||||
# 如果 API 删除失败,但文件已过期,仍然删除数据库记录
|
||||
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
|
||||
if expiration_time.tzinfo is None:
|
||||
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
|
||||
if expiration_time <= datetime.now(timezone.utc):
|
||||
await db_services.delete_file_record(file_name)
|
||||
return True
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to delete file")
|
||||
|
||||
# 删除数据库记录
|
||||
await db_services.delete_file_record(file_name)
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {file_name}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
||||
|
||||
async def check_file_state(self, file_name: str, api_key: str) -> str:
|
||||
"""
|
||||
檢查並更新文件狀態
|
||||
|
||||
Args:
|
||||
file_name: 文件名稱
|
||||
api_key: API密鑰
|
||||
|
||||
Returns:
|
||||
str: 當前狀態
|
||||
"""
|
||||
try:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to check file state: {response.status_code}")
|
||||
return "UNKNOWN"
|
||||
|
||||
file_data = response.json()
|
||||
google_state = file_data.get("state", "PROCESSING")
|
||||
|
||||
# 更新數據庫狀態
|
||||
if google_state == "ACTIVE":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.ACTIVE,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
elif google_state == "FAILED":
|
||||
await db_services.update_file_record_state(
|
||||
file_name=file_name,
|
||||
state=FileState.FAILED,
|
||||
update_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
return google_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check file state: {str(e)}")
|
||||
return "UNKNOWN"
|
||||
|
||||
async def cleanup_expired_files(self) -> int:
|
||||
"""
|
||||
清理过期文件
|
||||
|
||||
Returns:
|
||||
int: 清理的文件数量
|
||||
"""
|
||||
try:
|
||||
# 获取过期文件
|
||||
expired_files = await db_services.delete_expired_file_records()
|
||||
|
||||
if not expired_files:
|
||||
return 0
|
||||
|
||||
# 尝试从 Gemini API 删除文件
|
||||
for file_record in expired_files:
|
||||
try:
|
||||
api_key = file_record["api_key"]
|
||||
file_name = file_record["name"]
|
||||
|
||||
async with AsyncClient() as client:
|
||||
await client.delete(
|
||||
f"{settings.BASE_URL}/{file_name}",
|
||||
params={"key": api_key}
|
||||
)
|
||||
except Exception as e:
|
||||
# 记录错误但继续处理其他文件
|
||||
logger.error(f"Failed to delete file {file_record['name']} from API: {str(e)}")
|
||||
|
||||
return len(expired_files)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup expired files: {str(e)}")
|
||||
return 0
|
||||
|
||||
|
||||
# 单例实例
|
||||
_files_service_instance: Optional[FilesService] = None
|
||||
|
||||
|
||||
async def get_files_service() -> FilesService:
|
||||
"""获取文件服务单例实例"""
|
||||
global _files_service_instance
|
||||
if _files_service_instance is None:
|
||||
_files_service_instance = FilesService()
|
||||
return _files_service_instance
|
||||
@@ -9,6 +9,7 @@ from app.config.config import settings
|
||||
from app.core.constants import VALID_IMAGE_RATIOS
|
||||
from app.domain.openai_models import ImageGenerationRequest
|
||||
from app.log.logger import get_image_create_logger
|
||||
from app.utils.helpers import is_image_upload_configured
|
||||
from app.utils.uploader import ImageUploaderFactory
|
||||
|
||||
logger = get_image_create_logger()
|
||||
@@ -97,12 +98,18 @@ class ImageCreateService:
|
||||
image_data = generated_image.image.image_bytes
|
||||
image_uploader = None
|
||||
|
||||
if request.response_format == "b64_json":
|
||||
# Return base64 if explicitly requested or if no uploader is configured
|
||||
if (
|
||||
request.response_format == "b64_json"
|
||||
or not is_image_upload_configured(settings)
|
||||
):
|
||||
base64_image = base64.b64encode(image_data).decode("utf-8")
|
||||
images_data.append(
|
||||
{"b64_json": base64_image, "revised_prompt": request.prompt}
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Upload to configured provider
|
||||
current_date = time.strftime("%Y/%m/%d")
|
||||
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
|
||||
|
||||
@@ -115,12 +122,24 @@ class ImageCreateService:
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
api_key=settings.PICGO_API_KEY,
|
||||
api_url=settings.PICGO_API_URL,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
base_url=settings.CLOUDFLARE_IMGBED_URL,
|
||||
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
|
||||
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
|
||||
)
|
||||
elif settings.UPLOAD_PROVIDER == "aliyun_oss":
|
||||
image_uploader = ImageUploaderFactory.create(
|
||||
provider=settings.UPLOAD_PROVIDER,
|
||||
access_key=settings.OSS_ACCESS_KEY,
|
||||
access_key_secret=settings.OSS_ACCESS_KEY_SECRET,
|
||||
bucket_name=settings.OSS_BUCKET_NAME,
|
||||
endpoint=settings.OSS_ENDPOINT,
|
||||
region=settings.OSS_REGION,
|
||||
use_internal=False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import random
|
||||
from itertools import cycle
|
||||
from typing import Dict, Union
|
||||
|
||||
from app.config.config import settings
|
||||
from app.log.logger import get_key_manager_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_key_manager_logger()
|
||||
|
||||
@@ -34,7 +36,7 @@ class KeyManager:
|
||||
return next(self.key_cycle)
|
||||
|
||||
async def get_next_vertex_key(self) -> str:
|
||||
"""获取下一个 Vertex API key"""
|
||||
"""获取下一个 Vertex Express API key"""
|
||||
async with self.vertex_key_cycle_lock:
|
||||
return next(self.vertex_key_cycle)
|
||||
|
||||
@@ -65,7 +67,7 @@ class KeyManager:
|
||||
async with self.failure_count_lock:
|
||||
if key in self.key_failure_counts:
|
||||
self.key_failure_counts[key] = 0
|
||||
logger.info(f"Reset failure count for key: {key}")
|
||||
logger.info(f"Reset failure count for key: {redact_key_for_logging(key)}")
|
||||
return True
|
||||
logger.warning(
|
||||
f"Attempt to reset failure count for non-existent key: {key}"
|
||||
@@ -77,7 +79,7 @@ class KeyManager:
|
||||
async with self.vertex_failure_count_lock:
|
||||
if key in self.vertex_key_failure_counts:
|
||||
self.vertex_key_failure_counts[key] = 0
|
||||
logger.info(f"Reset failure count for Vertex key: {key}")
|
||||
logger.info(f"Reset failure count for Vertex key: {redact_key_for_logging(key)}")
|
||||
return True
|
||||
logger.warning(
|
||||
f"Attempt to reset failure count for non-existent Vertex key: {key}"
|
||||
@@ -98,7 +100,7 @@ class KeyManager:
|
||||
return current_key
|
||||
|
||||
async def get_next_working_vertex_key(self) -> str:
|
||||
"""获取下一可用的 Vertex API key"""
|
||||
"""获取下一可用的 Vertex Express API key"""
|
||||
initial_key = await self.get_next_vertex_key()
|
||||
current_key = initial_key
|
||||
|
||||
@@ -116,7 +118,7 @@ class KeyManager:
|
||||
self.key_failure_counts[api_key] += 1
|
||||
if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
f"API key {redact_key_for_logging(api_key)} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
if retries < settings.MAX_RETRIES:
|
||||
return await self.get_next_working_key()
|
||||
@@ -124,12 +126,12 @@ class KeyManager:
|
||||
return ""
|
||||
|
||||
async def handle_vertex_api_failure(self, api_key: str, retries: int) -> str:
|
||||
"""处理 Vertex API 调用失败"""
|
||||
"""处理 Vertex Express API 调用失败"""
|
||||
async with self.vertex_failure_count_lock:
|
||||
self.vertex_key_failure_counts[api_key] += 1
|
||||
if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
|
||||
logger.warning(
|
||||
f"Vertex API key {api_key} has failed {self.MAX_FAILURES} times"
|
||||
f"Vertex Express API key {redact_key_for_logging(api_key)} has failed {self.MAX_FAILURES} times"
|
||||
)
|
||||
|
||||
def get_fail_count(self, key: str) -> int:
|
||||
@@ -140,6 +142,18 @@ class KeyManager:
|
||||
"""获取指定 Vertex 密钥的失败次数"""
|
||||
return self.vertex_key_failure_counts.get(key, 0)
|
||||
|
||||
async def get_all_keys_with_fail_count(self) -> dict:
|
||||
"""获取所有API key及其失败次数"""
|
||||
all_keys = {}
|
||||
async with self.failure_count_lock:
|
||||
for key in self.api_keys:
|
||||
all_keys[key] = self.key_failure_counts.get(key, 0)
|
||||
|
||||
valid_keys = {k: v for k, v in all_keys.items() if v < self.MAX_FAILURES}
|
||||
invalid_keys = {k: v for k, v in all_keys.items() if v >= self.MAX_FAILURES}
|
||||
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys, "all_keys": all_keys}
|
||||
|
||||
async def get_keys_by_status(self) -> dict:
|
||||
"""获取分类后的API key列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
@@ -156,7 +170,7 @@ class KeyManager:
|
||||
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
||||
|
||||
async def get_vertex_keys_by_status(self) -> dict:
|
||||
"""获取分类后的 Vertex API key 列表,包括失败次数"""
|
||||
"""获取分类后的 Vertex Express API key 列表,包括失败次数"""
|
||||
valid_keys = {}
|
||||
invalid_keys = {}
|
||||
|
||||
@@ -178,11 +192,29 @@ class KeyManager:
|
||||
if self.api_keys:
|
||||
return self.api_keys[0]
|
||||
if not self.api_keys:
|
||||
logger.warning(
|
||||
"API key list is empty, cannot get first valid key.")
|
||||
logger.warning("API key list is empty, cannot get first valid key.")
|
||||
return ""
|
||||
return self.api_keys[0]
|
||||
|
||||
async def get_random_valid_key(self) -> str:
|
||||
"""获取随机的有效API key"""
|
||||
valid_keys = []
|
||||
async with self.failure_count_lock:
|
||||
for key in self.key_failure_counts:
|
||||
if self.key_failure_counts[key] < self.MAX_FAILURES:
|
||||
valid_keys.append(key)
|
||||
|
||||
if valid_keys:
|
||||
return random.choice(valid_keys)
|
||||
|
||||
# 如果没有有效的key,返回第一个key作为fallback
|
||||
if self.api_keys:
|
||||
logger.warning("No valid keys available, returning first key as fallback.")
|
||||
return self.api_keys[0]
|
||||
|
||||
logger.warning("API key list is empty, cannot get random valid key.")
|
||||
return ""
|
||||
|
||||
|
||||
_singleton_instance = None
|
||||
_singleton_lock = asyncio.Lock()
|
||||
@@ -214,7 +246,7 @@ async def get_key_manager_instance(
|
||||
)
|
||||
if vertex_api_keys is None:
|
||||
raise ValueError(
|
||||
"Vertex API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
"Vertex Express API keys are required to initialize or re-initialize the KeyManager instance."
|
||||
)
|
||||
|
||||
if not api_keys:
|
||||
@@ -223,12 +255,12 @@ async def get_key_manager_instance(
|
||||
)
|
||||
if not vertex_api_keys:
|
||||
logger.warning(
|
||||
"Initializing KeyManager with an empty list of Vertex API keys."
|
||||
"Initializing KeyManager with an empty list of Vertex Express API keys."
|
||||
)
|
||||
|
||||
_singleton_instance = KeyManager(api_keys, vertex_api_keys)
|
||||
logger.info(
|
||||
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex API keys."
|
||||
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex Express API keys."
|
||||
)
|
||||
|
||||
# 1. 恢复失败计数
|
||||
@@ -253,8 +285,7 @@ async def get_key_manager_instance(
|
||||
_singleton_instance.vertex_key_failure_counts = (
|
||||
current_vertex_failure_counts
|
||||
)
|
||||
logger.info(
|
||||
"Inherited failure counts for applicable Vertex keys.")
|
||||
logger.info("Inherited failure counts for applicable Vertex keys.")
|
||||
_preserved_vertex_failure_counts = None
|
||||
|
||||
# 2. 调整 key_cycle 的起始点
|
||||
@@ -351,7 +382,7 @@ async def get_key_manager_instance(
|
||||
break
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex API keys. "
|
||||
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex Express API keys. "
|
||||
"New cycle will start from the beginning of the new list."
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -372,12 +403,12 @@ async def get_key_manager_instance(
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex API keys during cycle advancement. "
|
||||
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex Express API keys during cycle advancement. "
|
||||
"New cycle will start from the beginning."
|
||||
)
|
||||
except StopIteration:
|
||||
logger.error(
|
||||
"StopIteration while advancing Vertex key cycle, implies empty new Vertex API key list previously missed."
|
||||
"StopIteration while advancing Vertex key cycle, implies empty new Vertex Express API key list previously missed."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -386,11 +417,11 @@ async def get_key_manager_instance(
|
||||
else:
|
||||
if _singleton_instance.vertex_api_keys:
|
||||
logger.info(
|
||||
"New Vertex key cycle will start from the beginning of the new Vertex API key list (no specific start key determined or needed)."
|
||||
"New Vertex key cycle will start from the beginning of the new Vertex Express API key list (no specific start key determined or needed)."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"New Vertex key cycle not applicable as the new Vertex API key list is empty."
|
||||
"New Vertex key cycle not applicable as the new Vertex Express API key list is empty."
|
||||
)
|
||||
|
||||
# 清理所有保存的状态
|
||||
@@ -411,11 +442,15 @@ async def reset_key_manager_instance():
|
||||
if _singleton_instance:
|
||||
# 1. 保存失败计数
|
||||
_preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
|
||||
_preserved_vertex_failure_counts = _singleton_instance.vertex_key_failure_counts.copy()
|
||||
_preserved_vertex_failure_counts = (
|
||||
_singleton_instance.vertex_key_failure_counts.copy()
|
||||
)
|
||||
|
||||
# 2. 保存旧的 API keys 列表
|
||||
_preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
|
||||
_preserved_vertex_old_api_keys_for_reset = _singleton_instance.vertex_api_keys.copy()
|
||||
_preserved_vertex_old_api_keys_for_reset = (
|
||||
_singleton_instance.vertex_api_keys.copy()
|
||||
)
|
||||
|
||||
# 3. 保存 key_cycle 的下一个 key 提示
|
||||
try:
|
||||
@@ -431,8 +466,7 @@ async def reset_key_manager_instance():
|
||||
)
|
||||
_preserved_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error preserving next key hint during reset: {e}")
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_next_key_in_cycle = None
|
||||
|
||||
# 4. 保存 vertex_key_cycle 的下一个 key 提示
|
||||
@@ -449,8 +483,7 @@ async def reset_key_manager_instance():
|
||||
)
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error preserving next key hint during reset: {e}")
|
||||
logger.error(f"Error preserving next key hint during reset: {e}")
|
||||
_preserved_vertex_next_key_in_cycle = None
|
||||
|
||||
_singleton_instance = None
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, Union
|
||||
|
||||
@@ -11,19 +8,21 @@ from app.database.services import (
|
||||
add_request_log,
|
||||
)
|
||||
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.service.client.api_client import OpenaiApiClient
|
||||
from app.service.key.key_manager import KeyManager
|
||||
from app.log.logger import get_openai_compatible_logger
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
|
||||
logger = get_openai_compatible_logger()
|
||||
|
||||
|
||||
class OpenAICompatiableService:
|
||||
|
||||
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
||||
self.key_manager = key_manager
|
||||
self.base_url = base_url
|
||||
self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)
|
||||
|
||||
|
||||
async def get_models(self, api_key: str) -> Dict[str, Any]:
|
||||
return await self.api_client.get_models(api_key)
|
||||
|
||||
@@ -36,10 +35,12 @@ class OpenAICompatiableService:
|
||||
request_dict = request.model_dump()
|
||||
# 移除值为null的
|
||||
request_dict = {k: v for k, v in request_dict.items() if v is not None}
|
||||
del request_dict["top_k"] # 删除top_k参数,目前不支持该参数
|
||||
del request_dict["top_k"] # 删除top_k参数,目前不支持该参数
|
||||
if request.stream:
|
||||
return self._handle_stream_completion(request.model, request_dict, api_key)
|
||||
return await self._handle_normal_completion(request.model, request_dict, api_key)
|
||||
return await self._handle_normal_completion(
|
||||
request.model, request_dict, api_key
|
||||
)
|
||||
|
||||
async def generate_images(
|
||||
self,
|
||||
@@ -77,13 +78,9 @@ class OpenAICompatiableService:
|
||||
return response
|
||||
except Exception as e:
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
@@ -91,7 +88,7 @@ class OpenAICompatiableService:
|
||||
error_type="openai-compatiable-non-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request,
|
||||
request_msg=request if settings.ERROR_LOG_RECORD_REQUEST_BODY else None,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
@@ -135,15 +132,11 @@ class OpenAICompatiableService:
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
is_success = False
|
||||
error_log_msg = str(e)
|
||||
status_code = e.args[0]
|
||||
error_log_msg = e.args[1]
|
||||
logger.warning(
|
||||
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
|
||||
)
|
||||
match = re.search(r"status code (\d+)", error_log_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
await add_error_log(
|
||||
gemini_key=current_attempt_key,
|
||||
@@ -151,7 +144,10 @@ class OpenAICompatiableService:
|
||||
error_type="openai-compatiable-stream",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=payload,
|
||||
request_msg=(
|
||||
payload if settings.ERROR_LOG_RECORD_REQUEST_BODY else None
|
||||
),
|
||||
request_datetime=request_datetime,
|
||||
)
|
||||
|
||||
if self.key_manager:
|
||||
@@ -159,19 +155,21 @@ class OpenAICompatiableService:
|
||||
current_attempt_key, retries
|
||||
)
|
||||
if api_key:
|
||||
logger.info(f"Switched to new API key: {api_key}")
|
||||
logger.info(
|
||||
f"Switched to new API key: {redact_key_for_logging(api_key)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"No valid API key available after {retries} retries."
|
||||
)
|
||||
break
|
||||
raise
|
||||
else:
|
||||
logger.error("KeyManager not available for retry logic.")
|
||||
break
|
||||
break
|
||||
|
||||
if retries >= max_retries:
|
||||
logger.error(f"Max retries ({max_retries}) reached for streaming.")
|
||||
break
|
||||
raise
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
@@ -183,8 +181,3 @@ class OpenAICompatiableService:
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
if not is_success and retries >= max_retries:
|
||||
yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
|
||||
7
app/service/proxy/__init__.py
Normal file
7
app/service/proxy/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Proxy service module
|
||||
"""
|
||||
|
||||
from .proxy_check_service import ProxyCheckService
|
||||
|
||||
__all__ = ["ProxyCheckService"]
|
||||
219
app/service/proxy/proxy_check_service.py
Normal file
219
app/service/proxy/proxy_check_service.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Proxy detection service module
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.log.logger import get_config_routes_logger
|
||||
|
||||
logger = get_config_routes_logger()
|
||||
|
||||
|
||||
class ProxyCheckResult(BaseModel):
|
||||
"""Proxy check result model"""
|
||||
proxy: str
|
||||
is_available: bool
|
||||
response_time: Optional[float] = None
|
||||
error_message: Optional[str] = None
|
||||
checked_at: float
|
||||
|
||||
|
||||
class ProxyCheckService:
|
||||
"""Proxy detection service class"""
|
||||
|
||||
# Target URL for checking
|
||||
CHECK_URL = "https://www.google.com"
|
||||
# Timeout in seconds
|
||||
TIMEOUT_SECONDS = 10
|
||||
# Cache duration in seconds
|
||||
CACHE_DURATION = 10 # 10s
|
||||
|
||||
def __init__(self):
|
||||
self._cache: Dict[str, ProxyCheckResult] = {}
|
||||
|
||||
def _is_valid_proxy_format(self, proxy: str) -> bool:
|
||||
"""Validate proxy format"""
|
||||
try:
|
||||
parsed = urlparse(proxy)
|
||||
return parsed.scheme in ['http', 'https', 'socks5'] and parsed.hostname
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _get_cached_result(self, proxy: str) -> Optional[ProxyCheckResult]:
|
||||
"""Get cached check result"""
|
||||
if proxy in self._cache:
|
||||
result = self._cache[proxy]
|
||||
# Check if cache is expired
|
||||
if time.time() - result.checked_at < self.CACHE_DURATION:
|
||||
logger.debug(f"Using cached proxy check result: {proxy}")
|
||||
return result
|
||||
else:
|
||||
# Remove expired cache
|
||||
del self._cache[proxy]
|
||||
return None
|
||||
|
||||
def _cache_result(self, result: ProxyCheckResult) -> None:
|
||||
"""Cache check result"""
|
||||
self._cache[result.proxy] = result
|
||||
|
||||
async def check_single_proxy(self, proxy: str, use_cache: bool = True) -> ProxyCheckResult:
|
||||
"""
|
||||
Check if a single proxy is available
|
||||
|
||||
Args:
|
||||
proxy: Proxy address in format like http://host:port or socks5://host:port
|
||||
use_cache: Whether to use cached results
|
||||
|
||||
Returns:
|
||||
ProxyCheckResult: Check result
|
||||
"""
|
||||
# Check cache first
|
||||
if use_cache:
|
||||
cached = self._get_cached_result(proxy)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# Validate proxy format
|
||||
if not self._is_valid_proxy_format(proxy):
|
||||
result = ProxyCheckResult(
|
||||
proxy=proxy,
|
||||
is_available=False,
|
||||
error_message="Invalid proxy format",
|
||||
checked_at=time.time()
|
||||
)
|
||||
self._cache_result(result)
|
||||
return result
|
||||
|
||||
# Perform check
|
||||
start_time = time.time()
|
||||
try:
|
||||
logger.info(f"Starting proxy check: {proxy}")
|
||||
|
||||
timeout = httpx.Timeout(self.TIMEOUT_SECONDS, read=self.TIMEOUT_SECONDS)
|
||||
async with httpx.AsyncClient(timeout=timeout, proxy=proxy) as client:
|
||||
response = await client.head(self.CHECK_URL)
|
||||
|
||||
response_time = time.time() - start_time
|
||||
|
||||
# Check response status
|
||||
is_available = response.status_code in [200, 204, 301, 302, 307, 308]
|
||||
|
||||
result = ProxyCheckResult(
|
||||
proxy=proxy,
|
||||
is_available=is_available,
|
||||
response_time=round(response_time, 3),
|
||||
error_message=None if is_available else f"HTTP {response.status_code}",
|
||||
checked_at=time.time()
|
||||
)
|
||||
|
||||
logger.info(f"Proxy check completed: {proxy}, available: {is_available}, response_time: {response_time:.3f}s")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
result = ProxyCheckResult(
|
||||
proxy=proxy,
|
||||
is_available=False,
|
||||
error_message="Connection timeout",
|
||||
checked_at=time.time()
|
||||
)
|
||||
logger.warning(f"Proxy check timeout: {proxy}")
|
||||
|
||||
except Exception as e:
|
||||
result = ProxyCheckResult(
|
||||
proxy=proxy,
|
||||
is_available=False,
|
||||
error_message=str(e),
|
||||
checked_at=time.time()
|
||||
)
|
||||
logger.error(f"Proxy check failed: {proxy}, error: {str(e)}")
|
||||
|
||||
# Cache result
|
||||
self._cache_result(result)
|
||||
return result
|
||||
|
||||
async def check_multiple_proxies(
|
||||
self,
|
||||
proxies: List[str],
|
||||
use_cache: bool = True,
|
||||
max_concurrent: int = 5
|
||||
) -> List[ProxyCheckResult]:
|
||||
"""
|
||||
Check multiple proxies concurrently
|
||||
|
||||
Args:
|
||||
proxies: List of proxy addresses
|
||||
use_cache: Whether to use cached results
|
||||
max_concurrent: Maximum concurrent check count
|
||||
|
||||
Returns:
|
||||
List[ProxyCheckResult]: List of check results
|
||||
"""
|
||||
if not proxies:
|
||||
return []
|
||||
|
||||
logger.info(f"Starting batch proxy check for {len(proxies)} proxies")
|
||||
|
||||
# Use semaphore to limit concurrency
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def check_with_semaphore(proxy: str) -> ProxyCheckResult:
|
||||
async with semaphore:
|
||||
return await self.check_single_proxy(proxy, use_cache)
|
||||
|
||||
# Execute checks concurrently
|
||||
tasks = [check_with_semaphore(proxy) for proxy in proxies]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle exception results
|
||||
final_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Proxy check task exception: {proxies[i]}, error: {str(result)}")
|
||||
final_results.append(ProxyCheckResult(
|
||||
proxy=proxies[i],
|
||||
is_available=False,
|
||||
error_message=f"Check task exception: {str(result)}",
|
||||
checked_at=time.time()
|
||||
))
|
||||
else:
|
||||
final_results.append(result)
|
||||
|
||||
available_count = sum(1 for r in final_results if r.is_available)
|
||||
logger.info(f"Batch proxy check completed: {available_count}/{len(proxies)} proxies available")
|
||||
|
||||
return final_results
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, int]:
|
||||
"""Get cache statistics"""
|
||||
current_time = time.time()
|
||||
valid_cache_count = sum(
|
||||
1 for result in self._cache.values()
|
||||
if current_time - result.checked_at < self.CACHE_DURATION
|
||||
)
|
||||
|
||||
return {
|
||||
"total_cached": len(self._cache),
|
||||
"valid_cached": valid_cache_count,
|
||||
"expired_cached": len(self._cache) - valid_cache_count
|
||||
}
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear all cache"""
|
||||
self._cache.clear()
|
||||
logger.info("Proxy check cache cleared")
|
||||
|
||||
|
||||
# Global instance
|
||||
_proxy_check_service: Optional[ProxyCheckService] = None
|
||||
|
||||
|
||||
def get_proxy_check_service() -> ProxyCheckService:
|
||||
"""Get proxy check service instance"""
|
||||
global _proxy_check_service
|
||||
if _proxy_check_service is None:
|
||||
_proxy_check_service = ProxyCheckService()
|
||||
return _proxy_check_service
|
||||
@@ -2,12 +2,12 @@
|
||||
Service for request log operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from app.database.connection import database
|
||||
from app.config.config import settings
|
||||
from app.database.connection import database
|
||||
from app.database.models import RequestLog
|
||||
from app.log.logger import get_request_log_logger
|
||||
|
||||
@@ -30,7 +30,7 @@ async def delete_old_request_logs_task():
|
||||
)
|
||||
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)
|
||||
|
||||
@@ -40,7 +40,7 @@ async def delete_old_request_logs_task():
|
||||
|
||||
result = await database.execute(query)
|
||||
logger.info(
|
||||
f"Request logs older than {cutoff_date} potentially deleted. Rows affected: {result.rowcount if result else 'N/A'}"
|
||||
f"Request logs older than {cutoff_date} potentially deleted. Rows affected: {result}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -146,7 +146,7 @@ class StatsService:
|
||||
period: 时间段标识 ('1m', '1h', '24h')
|
||||
|
||||
Returns:
|
||||
包含调用详情的字典列表,每个字典包含 timestamp, key, model, status
|
||||
包含调用详情的字典列表,每个字典包含 timestamp, key, model, status, status_code, latency_ms, error_log_id(可选)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 period 无效
|
||||
@@ -156,6 +156,8 @@ class StatsService:
|
||||
start_time = now - datetime.timedelta(minutes=1)
|
||||
elif period == "1h":
|
||||
start_time = now - datetime.timedelta(hours=1)
|
||||
elif period == "8h":
|
||||
start_time = now - datetime.timedelta(hours=8)
|
||||
elif period == "24h":
|
||||
start_time = now - datetime.timedelta(hours=24)
|
||||
else:
|
||||
@@ -167,7 +169,8 @@ class StatsService:
|
||||
RequestLog.request_time.label("timestamp"),
|
||||
RequestLog.api_key.label("key"),
|
||||
RequestLog.model_name.label("model"),
|
||||
RequestLog.status_code,
|
||||
RequestLog.status_code.label("status_code"),
|
||||
RequestLog.latency_ms.label("latency_ms"),
|
||||
)
|
||||
.where(RequestLog.request_time >= start_time)
|
||||
.order_by(RequestLog.request_time.desc())
|
||||
@@ -175,31 +178,127 @@ class StatsService:
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
details = []
|
||||
details: list[dict] = []
|
||||
for row in results:
|
||||
status = "failure"
|
||||
if row["status_code"] is not None:
|
||||
status = "success" if 200 <= row["status_code"] < 300 else "failure"
|
||||
details.append(
|
||||
{
|
||||
"timestamp": row[
|
||||
"timestamp"
|
||||
].isoformat(),
|
||||
"key": row["key"],
|
||||
"model": row["model"],
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
record = {
|
||||
"timestamp": row["timestamp"].isoformat(),
|
||||
"key": row["key"],
|
||||
"model": row["model"],
|
||||
"status": status,
|
||||
"status_code": row["status_code"],
|
||||
"latency_ms": row["latency_ms"],
|
||||
}
|
||||
|
||||
details.append(record)
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(details)} API call details for period '{period}'"
|
||||
)
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get API call details for period '{period}': {e}")
|
||||
logger.error(f"Failed to get API call details for period '{period}': {e}")
|
||||
raise
|
||||
|
||||
async def get_key_call_details(self, key: str, period: str) -> list[dict]:
|
||||
"""获取指定密钥在指定时间段内的调用详情 (与 get_api_call_details 结构一致)"""
|
||||
now = datetime.datetime.now()
|
||||
if period == "1m":
|
||||
start_time = now - datetime.timedelta(minutes=1)
|
||||
elif period == "1h":
|
||||
start_time = now - datetime.timedelta(hours=1)
|
||||
elif period == "8h":
|
||||
start_time = now - datetime.timedelta(hours=8)
|
||||
elif period == "24h":
|
||||
start_time = now - datetime.timedelta(hours=24)
|
||||
else:
|
||||
raise ValueError(f"无效的时间段标识: {period}")
|
||||
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.request_time.label("timestamp"),
|
||||
RequestLog.api_key.label("key"),
|
||||
RequestLog.model_name.label("model"),
|
||||
RequestLog.status_code.label("status_code"),
|
||||
RequestLog.latency_ms.label("latency_ms"),
|
||||
)
|
||||
.where(RequestLog.request_time >= start_time, RequestLog.api_key == key)
|
||||
.order_by(RequestLog.request_time.desc())
|
||||
)
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
|
||||
details: list[dict] = []
|
||||
for row in results:
|
||||
status = "failure"
|
||||
if row["status_code"] is not None:
|
||||
status = "success" if 200 <= row["status_code"] < 300 else "failure"
|
||||
|
||||
record = {
|
||||
"timestamp": row["timestamp"].isoformat(),
|
||||
"key": row["key"],
|
||||
"model": row["model"],
|
||||
"status": status,
|
||||
"status_code": row["status_code"],
|
||||
"latency_ms": row["latency_ms"],
|
||||
}
|
||||
|
||||
details.append(record)
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(details)} key call details for key=...{key[-4:] if key else ''} period '{period}'"
|
||||
)
|
||||
return details
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get key call details for key=...{key[-4:] if key else ''} period '{period}': {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_attention_keys_last_24h(
|
||||
self, include_keys: set[str], limit: int = 20, status_code: int = 429
|
||||
) -> list[dict]:
|
||||
"""返回最近24小时内指定状态码(默认429)最多的Key列表,仅包含include_keys中的Key。
|
||||
|
||||
Returns: [{"key": str, "count": int, "status_code": int}, ...] 按次数降序
|
||||
"""
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
start_time = now - datetime.timedelta(hours=24)
|
||||
if not include_keys:
|
||||
return []
|
||||
query = (
|
||||
select(
|
||||
RequestLog.api_key.label("key"),
|
||||
func.count(RequestLog.id).label("count"),
|
||||
)
|
||||
.where(
|
||||
RequestLog.request_time >= start_time,
|
||||
RequestLog.status_code == status_code,
|
||||
RequestLog.api_key.isnot(None),
|
||||
RequestLog.api_key.in_(list(include_keys)),
|
||||
)
|
||||
.group_by(RequestLog.api_key)
|
||||
.order_by(func.count(RequestLog.id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
rows = await database.fetch_all(query)
|
||||
return [
|
||||
{"key": row["key"], "count": row["count"], "status_code": status_code}
|
||||
for row in rows
|
||||
if row["key"]
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get attention keys ({status_code}) in last 24h: {e}"
|
||||
)
|
||||
return []
|
||||
|
||||
async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]:
|
||||
"""
|
||||
获取指定 API 密钥在过去 24 小时内按模型统计的调用次数。
|
||||
@@ -220,8 +319,7 @@ class StatsService:
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
RequestLog.model_name, func.count(
|
||||
RequestLog.id).label("call_count")
|
||||
RequestLog.model_name, func.count(RequestLog.id).label("call_count")
|
||||
)
|
||||
.where(
|
||||
RequestLog.api_key == key,
|
||||
@@ -240,8 +338,7 @@ class StatsService:
|
||||
)
|
||||
return {}
|
||||
|
||||
usage_details = {row["model_name"]: row["call_count"]
|
||||
for row in results}
|
||||
usage_details = {row["model_name"]: row["call_count"] for row in results}
|
||||
logger.info(
|
||||
f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
|
||||
)
|
||||
|
||||
363
app/service/tts/native/README.md
Normal file
363
app/service/tts/native/README.md
Normal file
@@ -0,0 +1,363 @@
|
||||
# 原生Gemini TTS功能
|
||||
|
||||
这个模块为Gemini Balance项目添加了原生Gemini TTS(Text-to-Speech)功能,支持单人和多人语音合成,采用智能检测和继承模式设计,保持与原始代码的完全兼容性。
|
||||
|
||||
## 🎯 设计原则
|
||||
|
||||
- **智能检测**:自动检测所有原生Gemini TTS格式的请求(包含responseModalities和speechConfig)
|
||||
- **继承而非修改**:所有扩展都继承自原始类,不修改源码
|
||||
- **完全兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **动态模型选择**:支持用户在请求URL中指定不同的TTS模型
|
||||
- **自动回退**:原生TTS处理失败时自动回退到标准服务
|
||||
- **完整日志记录**:包含请求日志、错误日志和性能监控
|
||||
- **易于维护**:更新原始代码时不会产生冲突
|
||||
|
||||
## 📁 文件结构
|
||||
|
||||
```
|
||||
app/service/tts/
|
||||
├── tts_service.py # 原有的OpenAI兼容TTS服务
|
||||
└── native/ # 原生Gemini TTS扩展
|
||||
├── __init__.py # 模块初始化
|
||||
├── README.md # 使用说明(本文件)
|
||||
├── tts_models.py # TTS数据模型(继承自原始模型)
|
||||
├── tts_response_handler.py # TTS响应处理器(继承自原始处理器)
|
||||
├── tts_chat_service.py # TTS聊天服务(继承自原始服务)
|
||||
└── tts_routes.py # TTS路由扩展和依赖注入
|
||||
```
|
||||
|
||||
## 🚀 原生Gemini TTS功能
|
||||
|
||||
### 智能检测机制(当前实现)
|
||||
|
||||
原生Gemini TTS功能通过智能检测自动启用,无需任何配置:
|
||||
|
||||
1. **自动启用**:
|
||||
```bash
|
||||
# 直接启动服务,原生TTS功能自动可用
|
||||
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
2. **无需配置**:
|
||||
- 不需要环境变量
|
||||
- 不需要修改配置文件
|
||||
- 完全基于请求内容智能判断
|
||||
|
||||
### 工作原理
|
||||
|
||||
系统会智能检测请求内容:
|
||||
- **原生TTS请求**:包含 `responseModalities: ["AUDIO"]` 和 `speechConfig` → 使用TTS增强服务
|
||||
- **单人TTS**:包含 `voiceConfig.prebuiltVoiceConfig`
|
||||
- **多人TTS**:包含 `multiSpeakerVoiceConfig`
|
||||
- **普通请求**:非TTS模型 → 使用原有Gemini聊天服务
|
||||
|
||||
```python
|
||||
# app/router/gemini_routes.py 中的智能检测逻辑
|
||||
if "tts" in model_name.lower() and request.generationConfig:
|
||||
# 直接从解析后的request对象获取TTS配置
|
||||
response_modalities = request.generationConfig.responseModalities or []
|
||||
speech_config = request.generationConfig.speechConfig or {}
|
||||
|
||||
# 如果包含AUDIO模态和语音配置,则认为是原生TTS请求
|
||||
if "AUDIO" in response_modalities and speech_config:
|
||||
# 使用TTS增强服务
|
||||
tts_service = await get_tts_chat_service(key_manager)
|
||||
return await tts_service.generate_content(...)
|
||||
# 否则使用原有服务
|
||||
```
|
||||
|
||||
## 📝 使用示例
|
||||
|
||||
### 1. 原生Gemini单人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `voiceConfig.prebuiltVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Hello, this is a single speaker test."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. 原生Gemini多人TTS请求(使用TTS增强服务)
|
||||
|
||||
包含 `multiSpeakerVoiceConfig` 的原生Gemini格式请求会自动使用TTS增强服务:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "Alice: Hello everyone, welcome to our show today.\nBob: Hi Alice, and hello to all our listeners! Today we are talking about AI development."
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "Alice",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Puck"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"speaker": "Bob",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. OpenAI兼容TTS请求(使用原有服务)
|
||||
|
||||
OpenAI兼容格式的TTS请求使用不同的API路径,不受本模块影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1/audio/speech" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer your-token" \
|
||||
-d '{
|
||||
"model": "tts-1",
|
||||
"input": "这是一个OpenAI兼容格式的TTS测试。",
|
||||
"voice": "alloy"
|
||||
}' \
|
||||
--output openai_tts_test.wav
|
||||
```
|
||||
|
||||
**注意**:OpenAI兼容TTS请求:
|
||||
- 使用路径:`/v1/audio/speech`
|
||||
- 使用Authorization头而不是x-goog-api-key
|
||||
- 返回音频文件而不是JSON响应
|
||||
- 不受本模块的TTS增强服务影响
|
||||
|
||||
### 普通文本生成(使用原有服务)
|
||||
|
||||
非TTS模型的请求会使用原有的Gemini聊天服务,完全不受影响:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash:generateContent" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-goog-api-key: your-token" \
|
||||
-d '{
|
||||
"contents": [{
|
||||
"parts": [{
|
||||
"text": "请简单介绍一下人工智能的发展历程。"
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"maxOutputTokens": 200,
|
||||
"temperature": 0.7
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## 🔧 技术实现
|
||||
|
||||
### 继承关系
|
||||
|
||||
```
|
||||
GeminiChatService
|
||||
↓ (继承)
|
||||
TTSGeminiChatService
|
||||
├── 重写 generate_content() 方法
|
||||
├── 添加 _handle_tts_request() 方法
|
||||
└── 集成完整的日志记录功能
|
||||
|
||||
GeminiResponseHandler
|
||||
↓ (继承)
|
||||
TTSResponseHandler
|
||||
└── 重写 handle_response() 方法
|
||||
|
||||
GenerationConfig (Pydantic模型)
|
||||
↓ (扩展)
|
||||
TTSGenerationConfig
|
||||
├── responseModalities: List[str]
|
||||
└── speechConfig: Dict[str, Any]
|
||||
```
|
||||
|
||||
### 工作流程
|
||||
|
||||
1. **请求接收**:系统接收到API请求
|
||||
2. **智能检测**:
|
||||
- 检查模型名称是否包含 "tts"
|
||||
- 如果是TTS模型,从 `request.generationConfig` 检查是否包含 `responseModalities: ["AUDIO"]` 和 `speechConfig`
|
||||
3. **服务选择**:
|
||||
- **原生TTS请求**:使用 `TTSGeminiChatService` 增强服务
|
||||
- **普通请求**:使用原有 `GeminiChatService`
|
||||
4. **请求处理**:
|
||||
- **原生TTS**:使用 `_handle_tts_request()` 特殊处理
|
||||
- **其他请求**:使用标准 `generate_content()` 方法
|
||||
5. **字段处理**:从 `request.generationConfig` 直接获取TTS字段(`responseModalities`, `speechConfig`)
|
||||
6. **API调用**:构建优化的payload并调用Gemini API
|
||||
7. **自动回退**:如果原生TTS处理失败,自动回退到标准服务
|
||||
8. **响应处理**:
|
||||
- **TTS响应**:检测音频数据,直接返回原始响应
|
||||
- **普通响应**:使用标准处理方法
|
||||
9. **日志记录**:记录请求时间、成功状态、错误信息到数据库
|
||||
|
||||
## 📊 功能特性
|
||||
|
||||
### ✅ 已实现功能
|
||||
|
||||
- **智能原生TTS支持**:支持单人和多人语音合成
|
||||
- **单人TTS**:支持 `voiceConfig.prebuiltVoiceConfig` 配置
|
||||
- **多人TTS**:支持 `multiSpeakerVoiceConfig` 配置
|
||||
- **智能检测机制**:自动检测所有原生Gemini TTS格式的请求
|
||||
- **动态模型选择**:支持用户在URL中指定不同TTS模型
|
||||
- **完全向后兼容**:原有TTS功能(OpenAI兼容TTS)完全不受影响
|
||||
- **自动回退机制**:原生TTS处理失败时自动使用标准服务
|
||||
- **完整日志记录**:请求日志、错误日志、性能监控
|
||||
- **API配额管理**:自动重试和密钥轮换
|
||||
- **零配置启用**:无需环境变量或配置文件修改
|
||||
- **错误处理**:完整的异常捕获和错误记录
|
||||
|
||||
### 🎵 支持的语音配置
|
||||
|
||||
#### 单人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 多人语音配置
|
||||
|
||||
```json
|
||||
{
|
||||
"responseModalities": ["AUDIO"],
|
||||
"speechConfig": {
|
||||
"multiSpeakerVoiceConfig": {
|
||||
"speakerVoiceConfigs": [
|
||||
{
|
||||
"speaker": "角色名称",
|
||||
"voiceConfig": {
|
||||
"prebuiltVoiceConfig": {
|
||||
"voiceName": "Kore|Puck|其他预设语音"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### API要求
|
||||
- 确保API密钥有TTS权限
|
||||
- TTS功能需要 `gemini-2.5-flash-preview-tts` 模型
|
||||
- 注意API配额限制(免费版每天15次)
|
||||
|
||||
### 性能考虑
|
||||
- TTS响应通常比文本响应更大(音频数据)
|
||||
- 建议监控API调用频率和成功率
|
||||
- 扩展功能不影响原始功能的性能和稳定性
|
||||
|
||||
### 部署建议
|
||||
- 生产环境建议先测试普通功能
|
||||
- 逐步启用TTS功能并监控日志
|
||||
- 定期检查API配额使用情况
|
||||
|
||||
## 📈 监控和调试
|
||||
|
||||
### 日志查看
|
||||
- **服务器日志**:查看TTS请求处理过程
|
||||
- **管理界面**:在"API 调用详情"中查看请求记录
|
||||
- **错误日志**:查看失败请求的详细信息
|
||||
|
||||
### 调试技巧
|
||||
```bash
|
||||
# 启用详细日志
|
||||
export LOG_LEVEL=DEBUG
|
||||
|
||||
# 查看实时日志
|
||||
tail -f logs/app.log
|
||||
|
||||
# 多人TTS功能无需配置,自动启用
|
||||
# 可通过请求内容智能检测
|
||||
```
|
||||
|
||||
## 🔄 TTS系统对比
|
||||
|
||||
项目中现在有三套TTS系统,各自服务不同的用途:
|
||||
|
||||
| TTS类型 | 路径 | 模型选择 | 语音配置 | 使用场景 | 我们的影响 |
|
||||
|---------|------|----------|----------|----------|------------|
|
||||
| **OpenAI兼容TTS** | `/v1/audio/speech` | 固定配置文件 | 单人语音 | OpenAI API兼容 | ✅ 无影响 |
|
||||
| **Gemini单人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 单人语音 | 原生Gemini TTS | ✅ 我们的增强 |
|
||||
| **Gemini多人TTS** | `/v1beta/models/{model}:generateContent` | 用户指定 | 多人语音 | 对话场景 | ✅ 我们的增强 |
|
||||
|
||||
### 智能路由机制
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[API请求] --> B{路径检查}
|
||||
B -->|/v1/audio/speech| C[OpenAI兼容TTS服务]
|
||||
B -->|/v1beta/models/{model}:generateContent| D{模型名包含'tts'?}
|
||||
D -->|否| E[标准Gemini聊天服务]
|
||||
D -->|是| F{包含responseModalities和speechConfig?}
|
||||
F -->|否| G[标准Gemini聊天服务]
|
||||
F -->|是| H[原生TTS增强服务]
|
||||
H --> I{处理成功?}
|
||||
I -->|是| J[返回原生TTS响应]
|
||||
I -->|否| K[自动回退到标准服务]
|
||||
C --> L[完成]
|
||||
E --> L
|
||||
G --> L
|
||||
J --> L
|
||||
K --> L
|
||||
```
|
||||
|
||||
## 🎉 成功案例
|
||||
|
||||
基于智能检测的原生Gemini TTS解决方案已经成功实现:
|
||||
|
||||
- ✅ **零配置启用**:无需任何环境变量或配置修改
|
||||
- ✅ **智能检测**:自动检测所有原生Gemini TTS格式的请求
|
||||
- ✅ **完全向后兼容**:所有原有TTS功能零影响
|
||||
- ✅ **动态模型选择**:支持用户指定不同TTS模型
|
||||
- ✅ **自动回退机制**:处理失败时自动使用标准服务
|
||||
- ✅ **单人和多人语音合成**:支持所有原生Gemini TTS场景
|
||||
- ✅ **完整日志记录**:可在管理界面查看所有请求
|
||||
- ✅ **错误处理完善**:API配额和重试机制
|
||||
- ✅ **易于维护**:更新原始代码无冲突
|
||||
|
||||
这个实现展示了如何在不修改原始代码的情况下,优雅地扩展复杂系统的功能,同时保持完美的向后兼容性。
|
||||
19
app/service/tts/native/__init__.py
Normal file
19
app/service/tts/native/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
原生Gemini TTS功能模块
|
||||
Native Gemini TTS functionality for both single and multi-speaker scenarios
|
||||
"""
|
||||
|
||||
from .tts_chat_service import TTSGeminiChatService
|
||||
from .tts_models import TTSGenerationConfig, MultiSpeakerVoiceConfig, SpeechConfig, TTSRequest
|
||||
from .tts_response_handler import TTSResponseHandler
|
||||
from .tts_routes import get_tts_chat_service
|
||||
|
||||
__all__ = [
|
||||
"TTSGeminiChatService",
|
||||
"TTSGenerationConfig",
|
||||
"MultiSpeakerVoiceConfig",
|
||||
"SpeechConfig",
|
||||
"TTSRequest",
|
||||
"TTSResponseHandler",
|
||||
"get_tts_chat_service"
|
||||
]
|
||||
172
app/service/tts/native/tts_chat_service.py
Normal file
172
app/service/tts/native/tts_chat_service.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
原生Gemini TTS聊天服务扩展
|
||||
继承自原始聊天服务,添加原生Gemini TTS支持(单人和多人),保持向后兼容
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.config.config import settings
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.gemini_models import GeminiRequest
|
||||
from app.log.logger import get_gemini_logger
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSGeminiChatService(GeminiChatService):
|
||||
"""
|
||||
支持TTS的Gemini聊天服务
|
||||
继承自原始的GeminiChatService,添加TTS功能
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, key_manager):
|
||||
"""
|
||||
初始化TTS聊天服务
|
||||
"""
|
||||
super().__init__(base_url, key_manager)
|
||||
# 使用TTS响应处理器替换原始处理器
|
||||
self.response_handler = TTSResponseHandler()
|
||||
logger.info(
|
||||
"TTS Gemini Chat Service initialized with multi-speaker TTS support"
|
||||
)
|
||||
|
||||
async def generate_content(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成内容,支持TTS
|
||||
"""
|
||||
try:
|
||||
# 添加调试日志
|
||||
logger.info(f"TTS request model: {model}")
|
||||
logger.info(f"TTS request generationConfig: {request.generationConfig}")
|
||||
|
||||
# 检查是否是TTS模型,如果是,需要特殊处理
|
||||
if "tts" in model.lower():
|
||||
logger.info("Detected TTS model, applying TTS-specific processing")
|
||||
# 对于TTS模型,我们需要确保正确的字段被传递
|
||||
response = await self._handle_tts_request(model, request, api_key)
|
||||
return response
|
||||
else:
|
||||
# 对于非TTS模型,使用父类的方法
|
||||
response = await super().generate_content(model, request, api_key)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"TTS API call failed with error: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_tts_request(
|
||||
self, model: str, request: GeminiRequest, api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理TTS特定的请求,包含完整的日志记录功能
|
||||
"""
|
||||
# 记录开始时间和请求时间
|
||||
start_time = time.perf_counter()
|
||||
request_datetime = datetime.datetime.now()
|
||||
is_success = False
|
||||
status_code = None
|
||||
|
||||
try:
|
||||
# 构建TTS专用的payload - 不包含tools和safetySettings
|
||||
from app.service.chat.gemini_chat_service import _filter_empty_parts
|
||||
|
||||
request_dict = request.model_dump(exclude_none=False)
|
||||
|
||||
# 构建TTS专用的简化payload
|
||||
payload = {
|
||||
"contents": _filter_empty_parts(request_dict.get("contents", [])),
|
||||
"generationConfig": request_dict.get("generationConfig", {}),
|
||||
}
|
||||
|
||||
# 只在有systemInstruction时才添加
|
||||
if request_dict.get("systemInstruction"):
|
||||
payload["systemInstruction"] = request_dict.get("systemInstruction")
|
||||
|
||||
# 确保 generationConfig 不为 None
|
||||
if payload["generationConfig"] is None:
|
||||
payload["generationConfig"] = {}
|
||||
|
||||
# 从request.generationConfig直接获取TTS相关字段
|
||||
if request.generationConfig:
|
||||
# 添加TTS特定字段
|
||||
if request.generationConfig.responseModalities:
|
||||
payload["generationConfig"][
|
||||
"responseModalities"
|
||||
] = request.generationConfig.responseModalities
|
||||
logger.info(
|
||||
f"Added responseModalities: {request.generationConfig.responseModalities}"
|
||||
)
|
||||
|
||||
if request.generationConfig.speechConfig:
|
||||
payload["generationConfig"][
|
||||
"speechConfig"
|
||||
] = request.generationConfig.speechConfig
|
||||
logger.info(
|
||||
f"Added speechConfig: {request.generationConfig.speechConfig}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No generationConfig found in request, TTS fields may be missing"
|
||||
)
|
||||
|
||||
logger.info(f"TTS payload before API call: {payload}")
|
||||
|
||||
# 调用API
|
||||
response = await self.api_client.generate_content(payload, model, api_key)
|
||||
|
||||
# 如果到达这里,说明API调用成功
|
||||
is_success = True
|
||||
status_code = 200
|
||||
|
||||
# 使用TTS响应处理器处理响应
|
||||
return self.response_handler.handle_response(response, model, False, None)
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
is_success = False
|
||||
error_msg = str(e)
|
||||
|
||||
# 尝试从错误消息中提取状态码
|
||||
import re
|
||||
|
||||
match = re.search(r"status code (\d+)", error_msg)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
else:
|
||||
status_code = 500
|
||||
|
||||
# 添加错误日志
|
||||
await add_error_log(
|
||||
gemini_key=api_key,
|
||||
model_name=model,
|
||||
error_type="tts-api-error",
|
||||
error_log=error_msg,
|
||||
error_code=status_code,
|
||||
request_msg=(
|
||||
request.model_dump(exclude_none=False)
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
logger.error(f"TTS API call failed: {error_msg}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
# 记录请求日志
|
||||
end_time = time.perf_counter()
|
||||
latency_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
await add_request_log(
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime,
|
||||
)
|
||||
37
app/service/tts/native/tts_config.py
Normal file
37
app/service/tts/native/tts_config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
TTS扩展配置
|
||||
控制是否启用TTS功能
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
from app.service.chat.gemini_chat_service import GeminiChatService
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
class TTSConfig:
|
||||
"""TTS配置管理"""
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled() -> bool:
|
||||
"""
|
||||
检查是否启用TTS功能
|
||||
通过环境变量 ENABLE_TTS 控制,默认为 False
|
||||
"""
|
||||
return os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
@staticmethod
|
||||
def get_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
|
||||
"""
|
||||
工厂方法:根据配置返回合适的聊天服务
|
||||
"""
|
||||
if TTSConfig.is_tts_enabled():
|
||||
return TTSGeminiChatService(base_url, key_manager)
|
||||
else:
|
||||
return GeminiChatService(base_url, key_manager)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
|
||||
"""创建聊天服务实例"""
|
||||
return TTSConfig.get_chat_service(base_url, key_manager)
|
||||
36
app/service/tts/native/tts_models.py
Normal file
36
app/service/tts/native/tts_models.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
原生Gemini TTS扩展数据模型
|
||||
继承自原始模型,添加原生Gemini TTS相关字段,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.domain.gemini_models import GenerationConfig as BaseGenerationConfig
|
||||
|
||||
|
||||
class TTSGenerationConfig(BaseGenerationConfig):
|
||||
"""
|
||||
支持TTS的生成配置类
|
||||
继承自原始的GenerationConfig,添加TTS相关字段
|
||||
"""
|
||||
# TTS 相关配置
|
||||
responseModalities: Optional[List[str]] = None
|
||||
speechConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MultiSpeakerVoiceConfig(BaseModel):
|
||||
"""多人语音配置"""
|
||||
speakerVoiceConfigs: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SpeechConfig(BaseModel):
|
||||
"""语音配置"""
|
||||
multiSpeakerVoiceConfig: Optional[MultiSpeakerVoiceConfig] = None
|
||||
voiceConfig: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
"""TTS请求模型"""
|
||||
contents: List[Dict[str, Any]]
|
||||
generationConfig: TTSGenerationConfig
|
||||
53
app/service/tts/native/tts_response_handler.py
Normal file
53
app/service/tts/native/tts_response_handler.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
原生Gemini TTS响应处理器扩展
|
||||
继承自原始响应处理器,添加原生Gemini TTS支持,保持向后兼容
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from app.handler.response_handler import GeminiResponseHandler
|
||||
from app.log.logger import get_gemini_logger
|
||||
|
||||
logger = get_gemini_logger()
|
||||
|
||||
|
||||
class TTSResponseHandler(GeminiResponseHandler):
|
||||
"""
|
||||
支持TTS的响应处理器
|
||||
继承自原始的GeminiResponseHandler,添加TTS响应处理
|
||||
"""
|
||||
|
||||
def handle_response(
|
||||
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理响应,支持TTS音频数据
|
||||
"""
|
||||
# 检查是否是TTS响应(包含音频数据)
|
||||
if self._is_tts_response(response):
|
||||
logger.info("Detected TTS response with audio data, returning original response")
|
||||
return response
|
||||
|
||||
# 对于非TTS响应,使用父类的处理方法
|
||||
return super().handle_response(response, model, stream, usage_metadata)
|
||||
|
||||
def _is_tts_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查是否是TTS响应
|
||||
"""
|
||||
try:
|
||||
if (response.get("candidates") and
|
||||
len(response["candidates"]) > 0 and
|
||||
response["candidates"][0].get("content") and
|
||||
response["candidates"][0]["content"].get("parts") and
|
||||
len(response["candidates"][0]["content"]["parts"]) > 0):
|
||||
|
||||
parts = response["candidates"][0]["content"]["parts"]
|
||||
for part in parts:
|
||||
if "inlineData" in part:
|
||||
mime_type = part["inlineData"].get("mimeType", "")
|
||||
if mime_type.startswith("audio/"):
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking TTS response: {e}")
|
||||
return False
|
||||
24
app/service/tts/native/tts_routes.py
Normal file
24
app/service/tts/native/tts_routes.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
TTS路由扩展
|
||||
提供原生Gemini TTS增强服务,支持单人和多人语音
|
||||
"""
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from app.config.config import settings
|
||||
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
||||
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
|
||||
|
||||
|
||||
async def get_key_manager():
|
||||
"""获取密钥管理器实例"""
|
||||
return get_key_manager_instance()
|
||||
|
||||
|
||||
async def get_tts_chat_service(key_manager: KeyManager = Depends(get_key_manager)) -> TTSGeminiChatService:
|
||||
"""
|
||||
获取原生Gemini TTS增强聊天服务实例,支持单人和多人语音
|
||||
"""
|
||||
return TTSGeminiChatService(settings.BASE_URL, key_manager)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional
|
||||
from google import genai
|
||||
|
||||
from app.config.config import settings
|
||||
from app.core.constants import TTS_VOICE_NAMES
|
||||
from app.database.services import add_error_log, add_request_log
|
||||
from app.domain.openai_models import TTSRequest
|
||||
from app.log.logger import get_openai_logger
|
||||
@@ -39,7 +40,7 @@ class TTSService:
|
||||
error_log_msg = ""
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
response =await client.aio.models.generate_content(
|
||||
response = await client.aio.models.generate_content(
|
||||
model=settings.TTS_MODEL,
|
||||
contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
|
||||
config={
|
||||
@@ -47,7 +48,11 @@ class TTSService:
|
||||
"speech_config": {
|
||||
"voice_config": {
|
||||
"prebuilt_voice_config": {
|
||||
"voice_name": settings.TTS_VOICE_NAME
|
||||
"voice_name": (
|
||||
request.voice
|
||||
if request.voice in TTS_VOICE_NAMES
|
||||
else settings.TTS_VOICE_NAME
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -58,7 +63,9 @@ class TTSService:
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].inline_data
|
||||
):
|
||||
raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
|
||||
raw_audio_data = (
|
||||
response.candidates[0].content.parts[0].inline_data.data
|
||||
)
|
||||
is_success = True
|
||||
status_code = 200
|
||||
return _create_wav_file(raw_audio_data)
|
||||
@@ -82,13 +89,17 @@ class TTSService:
|
||||
error_type="google-tts",
|
||||
error_log=error_log_msg,
|
||||
error_code=status_code,
|
||||
request_msg=request.input
|
||||
)
|
||||
request_msg=(
|
||||
request.input
|
||||
if settings.ERROR_LOG_RECORD_REQUEST_BODY
|
||||
else None
|
||||
),
|
||||
)
|
||||
await add_request_log(
|
||||
model_name=settings.TTS_MODEL,
|
||||
api_key=api_key,
|
||||
is_success=is_success,
|
||||
status_code=status_code,
|
||||
latency_ms=latency_ms,
|
||||
request_time=request_datetime
|
||||
)
|
||||
request_time=request_datetime,
|
||||
)
|
||||
|
||||
315
app/static/css/fonts.css
Normal file
315
app/static/css/fonts.css
Normal file
@@ -0,0 +1,315 @@
|
||||
/* cyrillic-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
|
||||
}
|
||||
/* cyrillic */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
|
||||
}
|
||||
/* greek-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+1F00-1FFF;
|
||||
}
|
||||
/* greek */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
|
||||
}
|
||||
/* vietnamese */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
|
||||
}
|
||||
/* latin-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
||||
}
|
||||
/* latin */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
|
||||
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
||||
}
|
||||
/* cyrillic-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
|
||||
}
|
||||
/* cyrillic */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
|
||||
}
|
||||
/* greek-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+1F00-1FFF;
|
||||
}
|
||||
/* greek */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
|
||||
}
|
||||
/* vietnamese */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
|
||||
}
|
||||
/* latin-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
||||
}
|
||||
/* latin */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 400;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
|
||||
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
||||
}
|
||||
/* cyrillic-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
|
||||
}
|
||||
/* cyrillic */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
|
||||
}
|
||||
/* greek-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+1F00-1FFF;
|
||||
}
|
||||
/* greek */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
|
||||
}
|
||||
/* vietnamese */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
|
||||
}
|
||||
/* latin-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
||||
}
|
||||
/* latin */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 500;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
|
||||
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
||||
}
|
||||
/* cyrillic-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
|
||||
}
|
||||
/* cyrillic */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
|
||||
}
|
||||
/* greek-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+1F00-1FFF;
|
||||
}
|
||||
/* greek */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
|
||||
}
|
||||
/* vietnamese */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
|
||||
}
|
||||
/* latin-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
||||
}
|
||||
/* latin */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
|
||||
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
||||
}
|
||||
/* cyrillic-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
|
||||
}
|
||||
/* cyrillic */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
|
||||
}
|
||||
/* greek-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+1F00-1FFF;
|
||||
}
|
||||
/* greek */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
|
||||
}
|
||||
/* vietnamese */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
|
||||
}
|
||||
/* latin-ext */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
|
||||
unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
||||
}
|
||||
/* latin */
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 700;
|
||||
font-display: swap;
|
||||
src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
|
||||
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
83
app/static/js/tailwindcss.js
Normal file
83
app/static/js/tailwindcss.js
Normal file
File diff suppressed because one or more lines are too long
@@ -51,7 +51,7 @@
|
||||
</div>
|
||||
|
||||
<h2 class="text-3xl font-extrabold text-center text-gray-800 mb-8 animate-slide-down">
|
||||
<img src="/static/icons/logo.png" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
|
||||
<img src="{{ static_url('icons/logo.png') }}" alt="Gemini Balance Logo" class="h-9 inline-block align-middle mr-2">
|
||||
Gemini Balance
|
||||
</h2>
|
||||
|
||||
|
||||
@@ -4,21 +4,21 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>{% block title %}Gemini Balance{% endblock %}</title>
|
||||
<link rel="manifest" href="/static/manifest.json" />
|
||||
<link rel="manifest" href="{{ static_url('manifest.json') }}" />
|
||||
<meta name="theme-color" content="#4F46E5" />
|
||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black" />
|
||||
<meta name="apple-mobile-web-app-title" content="GBalance" />
|
||||
<link rel="icon" href="/static/icons/icon-192x192.png" />
|
||||
<link rel="icon" href="{{ static_url('icons/icon-192x192.png') }}" />
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap"
|
||||
href="{{ static_url('css/fonts.css') }}"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css"
|
||||
/>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script src="{{ static_url('js/tailwindcss.js') }}"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
theme: {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -45,6 +45,179 @@ endblock %} {% block head_extra_styles %}
|
||||
.search-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
/* 移动端主容器布局 */
|
||||
.mobile-buttons-container {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
gap: 1rem !important;
|
||||
align-items: stretch !important;
|
||||
width: 100% !important;
|
||||
padding: 0 !important;
|
||||
margin: 0 !important;
|
||||
}
|
||||
|
||||
/* 移动端搜索控件布局优化 */
|
||||
.mobile-search-controls {
|
||||
grid-template-columns: 1fr !important;
|
||||
gap: 0.75rem !important;
|
||||
width: 100% !important;
|
||||
margin-bottom: 0.5rem !important;
|
||||
}
|
||||
|
||||
/* 按钮容器在移动端的布局 */
|
||||
.buttons-container-responsive {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
gap: 0.5rem !important;
|
||||
width: 100% !important;
|
||||
align-items: stretch !important;
|
||||
justify-content: stretch !important;
|
||||
}
|
||||
|
||||
/* 移动端所有按钮样式 */
|
||||
.buttons-container-responsive button {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
justify-content: center !important;
|
||||
text-align: center !important;
|
||||
min-width: 0 !important;
|
||||
flex-shrink: 0 !important;
|
||||
box-sizing: border-box !important;
|
||||
padding: 0.5rem 1rem !important;
|
||||
font-size: 0.875rem !important;
|
||||
white-space: nowrap !important;
|
||||
overflow: hidden !important;
|
||||
text-overflow: ellipsis !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 中等屏幕优化 */
|
||||
@media (max-width: 1024px) and (min-width: 769px) {
|
||||
.buttons-container-responsive {
|
||||
flex-wrap: wrap !important;
|
||||
justify-content: center !important;
|
||||
}
|
||||
|
||||
.buttons-container-responsive button {
|
||||
flex-shrink: 1 !important;
|
||||
min-width: 0 !important;
|
||||
padding-left: 0.75rem !important;
|
||||
padding-right: 0.75rem !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 小屏幕(手机)特殊优化 - 确保按钮在边框内 */
|
||||
@media (max-width: 640px) {
|
||||
/* 强制重写主容器布局 */
|
||||
.mobile-buttons-container {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
width: 100% !important;
|
||||
padding: 0 !important;
|
||||
margin: 0 !important;
|
||||
gap: 1rem !important;
|
||||
overflow: visible !important;
|
||||
}
|
||||
|
||||
/* 搜索区域在移动端占满宽度 */
|
||||
.mobile-search-controls {
|
||||
width: 100% !important;
|
||||
box-sizing: border-box !important;
|
||||
}
|
||||
|
||||
/* 按钮区域完全重新布局 */
|
||||
.buttons-container-responsive {
|
||||
display: flex !important;
|
||||
flex-direction: column !important;
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
gap: 0.5rem !important;
|
||||
padding: 0 !important;
|
||||
margin: 0 !important;
|
||||
box-sizing: border-box !important;
|
||||
overflow: visible !important;
|
||||
}
|
||||
|
||||
/* 所有按钮统一样式 */
|
||||
.buttons-container-responsive button {
|
||||
display: flex !important;
|
||||
align-items: center !important;
|
||||
justify-content: center !important;
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
box-sizing: border-box !important;
|
||||
padding: 0.5rem 1rem !important;
|
||||
margin: 0 !important;
|
||||
font-size: 0.875rem !important;
|
||||
line-height: 1.25rem !important;
|
||||
border-radius: 0.5rem !important;
|
||||
white-space: nowrap !important;
|
||||
overflow: hidden !important;
|
||||
text-overflow: ellipsis !important;
|
||||
flex-shrink: 0 !important;
|
||||
}
|
||||
|
||||
/* 特别针对清空全部按钮 */
|
||||
#deleteAllLogsBtn {
|
||||
background-color: #f87171 !important;
|
||||
border: 1px solid #f87171 !important;
|
||||
}
|
||||
|
||||
#deleteAllLogsBtn:hover {
|
||||
background-color: #ef4444 !important;
|
||||
border: 1px solid #ef4444 !important;
|
||||
}
|
||||
|
||||
/* 确保容器不会溢出父级 */
|
||||
.mobile-buttons-container,
|
||||
.mobile-buttons-container > *,
|
||||
.buttons-container-responsive,
|
||||
.buttons-container-responsive > * {
|
||||
max-width: 100% !important;
|
||||
box-sizing: border-box !important;
|
||||
}
|
||||
|
||||
/* 额外的安全边距控制 */
|
||||
.mobile-buttons-container .grid {
|
||||
padding-left: 0 !important;
|
||||
padding-right: 0 !important;
|
||||
margin-left: 0 !important;
|
||||
margin-right: 0 !important;
|
||||
}
|
||||
|
||||
/* 确保主内容区域有适当的内边距 */
|
||||
.rounded-xl.p-6 {
|
||||
padding-left: 1rem !important;
|
||||
padding-right: 1rem !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* 超小屏幕额外优化 */
|
||||
@media (max-width: 480px) {
|
||||
.mobile-buttons-container {
|
||||
gap: 0.75rem !important;
|
||||
}
|
||||
|
||||
.buttons-container-responsive {
|
||||
gap: 0.4rem !important;
|
||||
}
|
||||
|
||||
.buttons-container-responsive button {
|
||||
padding: 0.4rem 0.8rem !important;
|
||||
font-size: 0.8rem !important;
|
||||
}
|
||||
|
||||
/* 主容器内边距进一步缩小 */
|
||||
.rounded-xl.p-6 {
|
||||
padding-left: 0.75rem !important;
|
||||
padding-right: 0.75rem !important;
|
||||
}
|
||||
|
||||
/* 确保清空全部按钮文字不会太挤 */
|
||||
#deleteAllLogsBtn i {
|
||||
margin-right: 0.25rem !important;
|
||||
}
|
||||
}
|
||||
|
||||
input[type="text"],
|
||||
@@ -586,7 +759,7 @@ endblock %} {% block head_extra_styles %}
|
||||
class="text-3xl font-extrabold text-center text-gray-800 mb-4"
|
||||
>
|
||||
<img
|
||||
src="/static/icons/logo.png"
|
||||
src="{{ static_url('icons/logo.png') }}"
|
||||
alt="Gemini Balance Logo"
|
||||
class="h-9 inline-block align-middle mr-2"
|
||||
/>
|
||||
@@ -636,10 +809,10 @@ endblock %} {% block head_extra_styles %}
|
||||
|
||||
<!-- 搜索与操作控件 -->
|
||||
<div
|
||||
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6"
|
||||
class="grid grid-cols-1 lg:grid-cols-[1fr_auto] items-center gap-4 mb-6 mobile-buttons-container"
|
||||
>
|
||||
<div
|
||||
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full"
|
||||
class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 w-full mobile-search-controls"
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
@@ -684,7 +857,7 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3 flex-shrink-0">
|
||||
<div class="flex items-center gap-3 flex-shrink-0 buttons-container-responsive">
|
||||
<button
|
||||
id="searchBtn"
|
||||
class="flex items-center justify-center px-4 py-1.5 bg-blue-600 hover:bg-blue-700 text-white rounded-lg font-medium transition-all duration-200 shadow-sm hover:shadow-md whitespace-nowrap"
|
||||
@@ -1041,7 +1214,7 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %} {% block body_scripts %}
|
||||
<script src="/static/js/error_logs.js"></script>
|
||||
<script src="{{ static_url('js/error_logs.js') }}"></script>
|
||||
<script>
|
||||
// error_logs.html specific JS initialization (if any)
|
||||
// e.g., initialize date pickers or other elements if needed
|
||||
|
||||
@@ -38,6 +38,18 @@ endblock %} {% block head_extra_styles %}
|
||||
}
|
||||
}
|
||||
|
||||
/* 让图表卡片在网格中占满整行 */
|
||||
.stats-card.chart-wide {
|
||||
grid-column: 1 / -1;
|
||||
}
|
||||
/* 图表容器固定高度,配合 Chart.js maintainAspectRatio:false */
|
||||
.chart-container {
|
||||
height: 300px;
|
||||
}
|
||||
@media (max-width: 640px) {
|
||||
.chart-container { height: 220px; }
|
||||
}
|
||||
|
||||
/* 统计卡片样式 */
|
||||
.stats-card {
|
||||
background-color: rgba(255, 255, 255, 0.95);
|
||||
@@ -310,12 +322,13 @@ endblock %} {% block head_extra_styles %}
|
||||
border-color: rgba(59, 130, 246, 0.3);
|
||||
}
|
||||
|
||||
/* 隐藏原生复选框 */
|
||||
.key-checkbox {
|
||||
/* 隐藏原生复选框(仅隐藏有效/无效列表中的复选框,保留值得注意的Key列表中的复选框可见) */
|
||||
#validKeys .key-checkbox,
|
||||
#invalidKeys .key-checkbox {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* 自定义复选框样式 */
|
||||
/* 自定义复选框样式(仅针对有效/无效列表) */
|
||||
#validKeys li::before,
|
||||
#invalidKeys li::before {
|
||||
content: "";
|
||||
@@ -351,6 +364,31 @@ endblock %} {% block head_extra_styles %}
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
/* 值得注意的Key列表样式与选中态(保留原生复选框可见) */
|
||||
#attentionKeysList li {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
background-color: rgba(255, 255, 255, 0.9);
|
||||
border: 1px solid rgba(0, 0, 0, 0.08);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem 0.75rem;
|
||||
transition: all 0.2s ease;
|
||||
cursor: pointer;
|
||||
}
|
||||
#attentionKeysList li:hover {
|
||||
border-color: rgba(0, 0, 0, 0.12);
|
||||
box-shadow: 0 4px 12px rgba(0,0,0,0.08);
|
||||
background-color: rgba(249, 250, 251, 0.95);
|
||||
}
|
||||
#attentionKeysList li.selected {
|
||||
background-color: rgba(239, 246, 255, 0.95); /* light blue */
|
||||
border-color: rgba(59, 130, 246, 0.35);
|
||||
}
|
||||
#attentionKeysList .key-checkbox {
|
||||
margin-right: 0.25rem;
|
||||
}
|
||||
|
||||
.key-text {
|
||||
color: #374151 !important; /* gray-700 for light theme */
|
||||
text-shadow: none;
|
||||
@@ -875,7 +913,8 @@ endblock %} {% block head_extra_styles %}
|
||||
}
|
||||
|
||||
/* Fix specific pagination elements by ID and class */
|
||||
#validKeysPageSize, #invalidKeysPageSize {
|
||||
#validKeysPageSize, #invalidKeysPageSize,
|
||||
#itemsPerPageSelect, #invalidItemsPerPageSelect {
|
||||
background-color: rgba(255, 255, 255, 0.95) !important;
|
||||
color: #374151 !important; /* gray-700 */
|
||||
border: 1px solid rgba(0, 0, 0, 0.12) !important;
|
||||
@@ -884,7 +923,8 @@ endblock %} {% block head_extra_styles %}
|
||||
font-size: 0.875rem !important; /* text-sm */
|
||||
}
|
||||
|
||||
#validKeysPageSize:focus, #invalidKeysPageSize:focus {
|
||||
#validKeysPageSize:focus, #invalidKeysPageSize:focus,
|
||||
#itemsPerPageSelect:focus, #invalidItemsPerPageSelect:focus {
|
||||
border-color: #3b82f6 !important; /* blue-500 */
|
||||
box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1) !important;
|
||||
outline: none !important;
|
||||
@@ -947,7 +987,10 @@ endblock %} {% block head_extra_styles %}
|
||||
label[for="selectAllInvalid"],
|
||||
label[for="failCountThreshold"],
|
||||
label[for="keySearchInput"],
|
||||
label[for="itemsPerPageSelect"] {
|
||||
label[for="itemsPerPageSelect"],
|
||||
label[for="invalidFailCountThreshold"],
|
||||
label[for="invalidKeySearchInput"],
|
||||
label[for="invalidItemsPerPageSelect"] {
|
||||
color: #1f2937 !important; /* gray-800 for maximum contrast */
|
||||
font-weight: 600 !important; /* font-semibold for better visibility */
|
||||
text-shadow: none !important;
|
||||
@@ -1026,33 +1069,80 @@ endblock %} {% block head_extra_styles %}
|
||||
color: #fca5b3 !important;
|
||||
}
|
||||
/* End of API Call Details Modal Specific Styling Adjustments */
|
||||
|
||||
/* 下拉菜单样式 */
|
||||
.dropdown-menu {
|
||||
position: absolute;
|
||||
top: 100%;
|
||||
right: 0;
|
||||
background-color: rgba(255, 255, 255, 0.98);
|
||||
border: 1px solid rgba(0, 0, 0, 0.08);
|
||||
border-radius: 0.5rem;
|
||||
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
|
||||
min-width: 200px;
|
||||
z-index: 1000;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transform: translateY(-10px);
|
||||
transition: all 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
.dropdown-menu.show {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.dropdown-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 0.75rem 1rem;
|
||||
color: #374151;
|
||||
text-decoration: none;
|
||||
transition: all 0.2s ease-in-out;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
background: none;
|
||||
width: 100%;
|
||||
text-align: left;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.dropdown-item:hover {
|
||||
background-color: rgba(59, 130, 246, 0.1);
|
||||
color: #3b82f6;
|
||||
}
|
||||
|
||||
.dropdown-item:first-child {
|
||||
border-top-left-radius: 0.5rem;
|
||||
border-top-right-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.dropdown-item:last-child {
|
||||
border-bottom-left-radius: 0.5rem;
|
||||
border-bottom-right-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.dropdown-item i {
|
||||
width: 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.dropdown-toggle {
|
||||
position: relative;
|
||||
}
|
||||
</style>
|
||||
{% endblock %} {% block head_extra_scripts %}
|
||||
<!-- keys_status.js needs to be loaded in head because it might be used by inline scripts -->
|
||||
<script src="/static/js/keys_status.js"></script>
|
||||
<!-- Chart.js for time-series chart -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.1/dist/chart.umd.min.js" defer></script>
|
||||
<!-- Load page script with defer to guarantee DOM is ready and keep execution order -->
|
||||
<script src="/static/js/keys_status.js" defer></script>
|
||||
{% endblock %} {% block content %}
|
||||
<div class="container max-w-6xl mx-auto px-4">
|
||||
<!-- Increased max-width -->
|
||||
<div class="glass-card rounded-2xl shadow-xl p-6 md:p-8">
|
||||
<div class="absolute top-6 right-6 flex items-center gap-3">
|
||||
<!-- 自动刷新开关 -->
|
||||
<div class="flex items-center text-sm select-none font-semibold" style="color: #1f2937 !important;">
|
||||
<span class="mr-2">自动刷新</span>
|
||||
<div
|
||||
class="relative inline-block w-10 mr-2 align-middle select-none transition duration-200 ease-in"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
name="autoRefreshToggle"
|
||||
id="autoRefreshToggle"
|
||||
class="toggle-checkbox absolute block w-6 h-6 rounded-full bg-white border-4 appearance-none cursor-pointer"
|
||||
/>
|
||||
<label
|
||||
for="autoRefreshToggle"
|
||||
class="toggle-label block overflow-hidden h-6 rounded-full bg-gray-300 cursor-pointer"
|
||||
></label>
|
||||
</div>
|
||||
</div>
|
||||
<!-- 手动刷新按钮 -->
|
||||
<button
|
||||
class="bg-white bg-opacity-20 hover:bg-opacity-30 rounded-full w-8 h-8 flex items-center justify-center text-primary-600 transition-all duration-300"
|
||||
@@ -1061,6 +1151,28 @@ endblock %} {% block head_extra_styles %}
|
||||
>
|
||||
<i class="fas fa-sync-alt"></i>
|
||||
</button>
|
||||
<!-- 下拉菜单按钮 -->
|
||||
<div class="dropdown-toggle relative">
|
||||
<button
|
||||
id="dropdownMenuButton"
|
||||
class="bg-white bg-opacity-20 hover:bg-opacity-30 rounded-full w-8 h-8 flex items-center justify-center text-primary-600 transition-all duration-300"
|
||||
onclick="toggleDropdownMenu()"
|
||||
title="更多操作"
|
||||
>
|
||||
<i class="fas fa-ellipsis-v"></i>
|
||||
</button>
|
||||
<!-- 下拉菜单 -->
|
||||
<div id="dropdownMenu" class="dropdown-menu">
|
||||
<button class="dropdown-item" onclick="copyAllKeys()">
|
||||
<i class="fas fa-copy"></i>
|
||||
<span>复制全部密钥</span>
|
||||
</button>
|
||||
<button class="dropdown-item" onclick="verifyAllKeys()">
|
||||
<i class="fas fa-check-double"></i>
|
||||
<span>验证所有密钥</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h1
|
||||
@@ -1173,7 +1285,94 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 可切换时间区间的成功/失败图表卡片 -->
|
||||
<div class="stats-card chart-wide">
|
||||
<div class="stats-card-header">
|
||||
<h3 class="stats-card-title">
|
||||
<i class="fas fa-chart-bar"></i>
|
||||
<span>调用趋势图</span>
|
||||
</h3>
|
||||
<div class="flex items-center gap-2 text-xs">
|
||||
<button id="chartBtn1h" class="px-2 py-1 rounded bg-gray-200 hover:bg-gray-300 text-gray-700">1小时</button>
|
||||
<button id="chartBtn8h" class="px-2 py-1 rounded bg-gray-200 hover:bg-gray-300 text-gray-700">8小时</button>
|
||||
<button id="chartBtn24h" class="px-2 py-1 rounded bg-gray-200 hover:bg-gray-300 text-gray-700">24小时</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="p-4 chart-container">
|
||||
<canvas id="apiStatsChart"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 值得注意的 Key 卡片(错误码统计,可切换) -->
|
||||
<div class="stats-card chart-wide">
|
||||
<div class="stats-card-header">
|
||||
<h3 class="stats-card-title">
|
||||
<i class="fas fa-exclamation-circle"></i>
|
||||
<span>值得注意的Key(24h内错误码最多)</span>
|
||||
</h3>
|
||||
<div class="flex items-center gap-2 text-xs">
|
||||
<button id="attentionErr429" class="px-2 py-1 rounded bg-gray-200 hover:bg-gray-300 text-gray-700" title="429 Too Many Requests">429</button>
|
||||
<button id="attentionErr403" class="px-2 py-1 rounded bg-gray-200 hover:bg-gray-300 text-gray-700" title="403 Forbidden">403</button>
|
||||
<button id="attentionErr400" class="px-2 py-1 rounded bg-gray-200 hover:bg-gray-300 text-gray-700" title="400 Bad Request">400</button>
|
||||
<div class="flex items-center gap-1 ml-2">
|
||||
<input id="attentionErrCustom" type="number" min="100" max="599" placeholder="自定义" class="form-input h-7 w-20 px-2 py-1 text-xs border rounded focus:ring-primary-500 focus:border-primary-500" />
|
||||
<button id="attentionErrGo" class="px-2 py-1 rounded bg-blue-500 hover:bg-blue-600 text-white" title="查询">查询</button>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 ml-3">
|
||||
<label for="attentionLimitInput" class="text-xs text-gray-600">数量</label>
|
||||
<input id="attentionLimitInput" type="number" min="1" max="1000" value="10" class="form-input h-7 w-20 px-2 py-1 text-xs border rounded focus:ring-primary-500 focus:border-primary-500" />
|
||||
<!-- 全选移动到数量输入框右侧 -->
|
||||
<div class="flex items-center gap-1">
|
||||
<input
|
||||
type="checkbox"
|
||||
id="selectAllAttention"
|
||||
class="form-checkbox h-4 w-4 text-primary-600 border-gray-300 rounded focus:ring-primary-500"
|
||||
onchange="toggleSelectAll('attention', this.checked)"
|
||||
/>
|
||||
<label for="selectAllAttention" class="text-xs select-none whitespace-nowrap font-semibold" style="color: #1f2937 !important;">全选</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="p-4">
|
||||
<!-- 批量操作按钮组 (仅在选中时显示) -->
|
||||
<div
|
||||
id="attentionBatchActions"
|
||||
class="p-3 border mb-3 hidden flex items-center flex-wrap gap-3"
|
||||
style="background-color: rgba(249, 250, 251, 0.95); border-color: rgba(0, 0, 0, 0.08);"
|
||||
>
|
||||
<span class="text-sm font-semibold whitespace-nowrap" style="color: #1f2937 !important;">
|
||||
已选择 <span id="attentionSelectedCount">0</span> 项
|
||||
</span>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-success-600 hover:bg-success-700 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200 disabled:cursor-not-allowed"
|
||||
onclick="event.stopPropagation(); showVerifyModal('attention', event)"
|
||||
disabled
|
||||
>
|
||||
<i class="fas fa-check-double"></i> 批量验证
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-blue-500 hover:bg-blue-600 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200 disabled:cursor-not-allowed"
|
||||
onclick="event.stopPropagation(); copySelectedKeys('attention')"
|
||||
disabled
|
||||
>
|
||||
<i class="fas fa-copy"></i> 批量复制
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-red-800 hover:bg-red-900 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200 disabled:cursor-not-allowed"
|
||||
onclick="event.stopPropagation(); showDeleteConfirmationModal('attention', event)"
|
||||
disabled
|
||||
>
|
||||
<i class="fas fa-trash-alt"></i> 批量删除
|
||||
</button>
|
||||
</div>
|
||||
<ul id="attentionKeysList" class="space-y-2">
|
||||
<li class="text-center text-gray-500 py-2">加载中...</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 有效密钥区域 -->
|
||||
<div class="stats-card mb-6 animate-fade-in" style="animation-delay: 0.2s">
|
||||
@@ -1245,6 +1444,7 @@ endblock %} {% block head_extra_styles %}
|
||||
<option value="20">20</option>
|
||||
<option value="50">50</option>
|
||||
<option value="100">100</option>
|
||||
<option value="500">500</option>
|
||||
</select>
|
||||
<span class="text-sm select-none font-semibold" style="color: #1f2937 !important;">项</span>
|
||||
</div>
|
||||
@@ -1318,95 +1518,10 @@ endblock %} {% block head_extra_styles %}
|
||||
<div class="key-content p-4 bg-white bg-opacity-40">
|
||||
<!-- Key list will be populated by JS -->
|
||||
<ul id="validKeys" class="grid grid-cols-1 md:grid-cols-2 gap-3">
|
||||
{# Initial keys rendered by server-side for non-JS users or initial
|
||||
load #} {# JS will replace this content with paginated/filtered
|
||||
results #} {% if valid_keys %} {% for key, fail_count in
|
||||
valid_keys.items() %}
|
||||
<li
|
||||
class="bg-white rounded-lg p-3 shadow-sm hover:shadow-md transition-all duration-300 border border-gray-100 hover:border-success-300 transform hover:-translate-y-1"
|
||||
data-fail-count="{{ fail_count }}"
|
||||
data-key="{{ key }}"
|
||||
>
|
||||
<!-- Checkbox -->
|
||||
<input
|
||||
type="checkbox"
|
||||
class="form-checkbox h-5 w-5 text-primary-600 border-gray-300 rounded focus:ring-primary-500 mt-1 key-checkbox"
|
||||
data-key-type="valid"
|
||||
value="{{ key }}"
|
||||
/>
|
||||
<!-- Key Info -->
|
||||
<div class="flex-grow">
|
||||
<div class="flex flex-col justify-between h-full gap-3">
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<span
|
||||
class="inline-flex items-center px-2.5 py-0.5 rounded-full text-xs font-medium bg-success-50 text-success-600"
|
||||
>
|
||||
<i class="fas fa-check mr-1"></i> 有效
|
||||
</span>
|
||||
<div class="flex items-center gap-1">
|
||||
<span class="key-text font-mono" data-full-key="{{ key }}"
|
||||
>{{ key[:4] + '...' + key[-4:] }}</span
|
||||
>
|
||||
<button
|
||||
class="text-gray-500 hover:text-primary-600 transition-colors"
|
||||
onclick="toggleKeyVisibility(this)"
|
||||
title="显示/隐藏密钥"
|
||||
>
|
||||
<i class="fas fa-eye"></i>
|
||||
</button>
|
||||
</div>
|
||||
<span
|
||||
class="inline-flex items-center px-2.5 py-0.5 rounded-full text-xs font-medium bg-amber-50 text-amber-600"
|
||||
>
|
||||
<i class="fas fa-exclamation-triangle mr-1"></i>
|
||||
失败: {{ fail_count }}
|
||||
</span>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<button
|
||||
class="flex items-center gap-1 bg-success-600 hover:bg-success-700 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="verifyKey('{{ key }}', this)"
|
||||
>
|
||||
<i class="fas fa-check-circle"></i>
|
||||
验证
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-gray-500 hover:bg-gray-600 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="resetKeyFailCount('{{ key }}', this)"
|
||||
>
|
||||
<i class="fas fa-redo-alt"></i>
|
||||
重置
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-blue-500 hover:bg-blue-600 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="copyKey('{{ key }}')"
|
||||
>
|
||||
<i class="fas fa-copy"></i>
|
||||
复制
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-blue-600 hover:bg-blue-700 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="showKeyUsageDetails('{{ key }}')"
|
||||
>
|
||||
<i class="fas fa-chart-pie"></i>
|
||||
详情
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-red-800 hover:bg-red-900 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="showSingleKeyDeleteConfirmModal('{{ key }}', this)"
|
||||
>
|
||||
<i class="fas fa-trash-alt"></i>
|
||||
删除
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
||||
{% endfor %} {% else %}
|
||||
{# This content is now loaded via JavaScript #}
|
||||
<li class="text-center text-gray-500 py-4 col-span-full">
|
||||
暂无有效密钥
|
||||
<i class="fas fa-spinner fa-spin"></i> Loading keys...
|
||||
</li>
|
||||
{% endif %}
|
||||
</ul>
|
||||
<!-- 有效密钥分页控件容器 -->
|
||||
<div
|
||||
@@ -1433,12 +1548,72 @@ endblock %} {% block head_extra_styles %}
|
||||
无效密钥列表 ({{ invalid_key_count }})
|
||||
</h2>
|
||||
</div>
|
||||
<!-- Middle: Filters and Search (Allow wrapping) -->
|
||||
<div
|
||||
class="flex items-center gap-x-4 gap-y-2 flex-grow flex-wrap justify-start md:justify-center"
|
||||
>
|
||||
<!-- Allow wrapping, center on medium+ -->
|
||||
<!-- 失败次数筛选 -->
|
||||
<div class="flex items-center gap-1">
|
||||
<label
|
||||
for="invalidFailCountThreshold"
|
||||
class="text-sm select-none whitespace-nowrap font-semibold"
|
||||
style="color: #1f2937 !important;"
|
||||
>失败次数≥</label
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
id="invalidFailCountThreshold"
|
||||
value="0"
|
||||
min="0"
|
||||
class="form-input h-7 w-16 px-2 py-1 text-sm border rounded focus:ring-primary-500 focus:border-primary-500"
|
||||
onclick="event.stopPropagation();"
|
||||
/>
|
||||
</div>
|
||||
<!-- 密钥搜索 -->
|
||||
<div class="flex items-center gap-1">
|
||||
<label
|
||||
for="invalidKeySearchInput"
|
||||
class="text-sm select-none whitespace-nowrap font-semibold"
|
||||
style="color: #1f2937 !important;"
|
||||
><i class="fas fa-search mr-1"></i>搜索</label
|
||||
>
|
||||
<input
|
||||
type="search"
|
||||
id="invalidKeySearchInput"
|
||||
placeholder="输入密钥..."
|
||||
class="form-input h-7 w-32 px-2 py-1 text-sm border rounded focus:ring-primary-500 focus:border-primary-500"
|
||||
onclick="event.stopPropagation();"
|
||||
/>
|
||||
</div>
|
||||
<!-- 每页显示数量 -->
|
||||
<div class="flex items-center gap-1">
|
||||
<label
|
||||
for="invalidItemsPerPageSelect"
|
||||
class="text-sm select-none whitespace-nowrap font-semibold"
|
||||
style="color: #1f2937 !important;"
|
||||
>每页</label
|
||||
>
|
||||
<select
|
||||
id="invalidItemsPerPageSelect"
|
||||
class="form-select h-7 px-2 py-1 text-sm border rounded focus:ring-primary-500 focus:border-primary-500"
|
||||
onclick="event.stopPropagation();"
|
||||
>
|
||||
<option value="10">10</option>
|
||||
<option value="20">20</option>
|
||||
<option value="50">50</option>
|
||||
<option value="100">100</option>
|
||||
<option value="500">500</option>
|
||||
</select>
|
||||
<span class="text-sm select-none font-semibold" style="color: #1f2937 !important;">项</span>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Right side: Select All -->
|
||||
<div
|
||||
class="flex items-center gap-1 ml-auto flex-shrink-0"
|
||||
class="flex items-center gap-1 flex-shrink-0"
|
||||
onclick="event.stopPropagation();"
|
||||
>
|
||||
<!-- Use ml-auto, Prevent shrinking -->
|
||||
<!-- Prevent shrinking -->
|
||||
<input
|
||||
type="checkbox"
|
||||
id="selectAllInvalid"
|
||||
@@ -1502,93 +1677,10 @@ endblock %} {% block head_extra_styles %}
|
||||
<div class="key-content p-4 bg-white bg-opacity-40">
|
||||
<!-- Key list will be populated by JS -->
|
||||
<ul id="invalidKeys" class="grid grid-cols-1 md:grid-cols-2 gap-3">
|
||||
{# Initial keys rendered by server-side #} {# JS will replace this
|
||||
content with paginated results #} {% if invalid_keys %} {% for key,
|
||||
fail_count in invalid_keys.items() %}
|
||||
<li
|
||||
class="bg-white rounded-lg p-3 shadow-sm hover:shadow-md transition-all duration-300 border border-gray-100 hover:border-danger-300 transform hover:-translate-y-1"
|
||||
data-key="{{ key }}"
|
||||
>
|
||||
<!-- Checkbox -->
|
||||
<input
|
||||
type="checkbox"
|
||||
class="form-checkbox h-5 w-5 text-primary-600 border-gray-300 rounded focus:ring-primary-500 mt-1 key-checkbox"
|
||||
data-key-type="invalid"
|
||||
value="{{ key }}"
|
||||
/>
|
||||
<!-- Key Info -->
|
||||
<div class="flex-grow">
|
||||
<div class="flex flex-col justify-between h-full gap-3">
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<span
|
||||
class="inline-flex items-center px-2.5 py-0.5 rounded-full text-xs font-medium bg-danger-50 text-danger-600"
|
||||
>
|
||||
<i class="fas fa-times mr-1"></i> 无效
|
||||
</span>
|
||||
<div class="flex items-center gap-1">
|
||||
<span class="key-text font-mono" data-full-key="{{ key }}"
|
||||
>{{ key[:4] + '...' + key[-4:] }}</span
|
||||
>
|
||||
<button
|
||||
class="text-gray-500 hover:text-primary-600 transition-colors"
|
||||
onclick="toggleKeyVisibility(this)"
|
||||
title="显示/隐藏密钥"
|
||||
>
|
||||
<i class="fas fa-eye"></i>
|
||||
</button>
|
||||
</div>
|
||||
<span
|
||||
class="inline-flex items-center px-2.5 py-0.5 rounded-full text-xs font-medium bg-amber-50 text-amber-600"
|
||||
>
|
||||
<i class="fas fa-exclamation-triangle mr-1"></i>
|
||||
失败: {{ fail_count }}
|
||||
</span>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<button
|
||||
class="flex items-center gap-1 bg-success-600 hover:bg-success-700 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="verifyKey('{{ key }}', this)"
|
||||
>
|
||||
<i class="fas fa-check-circle"></i>
|
||||
验证
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-gray-500 hover:bg-gray-600 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="resetKeyFailCount('{{ key }}', this)"
|
||||
>
|
||||
<i class="fas fa-redo-alt"></i>
|
||||
重置
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-blue-500 hover:bg-blue-600 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="copyKey('{{ key }}')"
|
||||
>
|
||||
<i class="fas fa-copy"></i>
|
||||
复制
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-blue-600 hover:bg-blue-700 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="showKeyUsageDetails('{{ key }}')"
|
||||
>
|
||||
<i class="fas fa-chart-pie"></i>
|
||||
详情
|
||||
</button>
|
||||
<button
|
||||
class="flex items-center gap-1 bg-red-800 hover:bg-red-900 text-white px-2.5 py-1 rounded-lg text-xs font-medium transition-all duration-200"
|
||||
onclick="showSingleKeyDeleteConfirmModal('{{ key }}', this)"
|
||||
>
|
||||
<i class="fas fa-trash-alt"></i>
|
||||
删除
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
||||
{% endfor %} {% else %}
|
||||
{# This content is now loaded via JavaScript #}
|
||||
<li class="text-center text-gray-500 py-4 col-span-full">
|
||||
暂无无效密钥
|
||||
<i class="fas fa-spinner fa-spin"></i> Loading keys...
|
||||
</li>
|
||||
{% endif %}
|
||||
</ul>
|
||||
<!-- 无效密钥分页控件容器 -->
|
||||
<div
|
||||
@@ -1687,7 +1779,11 @@ endblock %} {% block head_extra_styles %}
|
||||
</button>
|
||||
</div>
|
||||
<div class="mb-6">
|
||||
<p style="color: #374151" id="verifyModalMessage"></p>
|
||||
<p style="color: #374151" id="verifyModalMessage" class="mb-4"></p>
|
||||
<div class="flex items-center gap-2">
|
||||
<label for="batchSize" class="text-sm font-medium" style="color: #374151;">每批次验证数量:</label>
|
||||
<input type="number" id="batchSize" value="10" min="1" class="form-input h-8 w-20 px-2 py-1 text-sm border rounded focus:ring-primary-500 focus:border-primary-500">
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex justify-end gap-3">
|
||||
<button
|
||||
@@ -1811,10 +1907,83 @@ endblock %} {% block head_extra_styles %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 批量操作进度模态框 -->
|
||||
<div
|
||||
id="progressModal"
|
||||
class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50 hidden"
|
||||
>
|
||||
<div
|
||||
class="bg-white rounded-lg p-6 shadow-xl max-w-2xl w-full animate-fade-in"
|
||||
style="
|
||||
background-color: rgba(255, 255, 255, 0.98);
|
||||
color: #374151;
|
||||
border-color: rgba(0, 0, 0, 0.08);
|
||||
"
|
||||
>
|
||||
<div class="flex items-center justify-between mb-4">
|
||||
<h3
|
||||
class="text-lg font-semibold text-gray-800"
|
||||
id="progressModalTitle"
|
||||
style="color: #1f2937; font-weight: 600"
|
||||
>
|
||||
批量操作进度
|
||||
</h3>
|
||||
<button
|
||||
onclick="closeProgressModal()"
|
||||
id="closeProgressModalBtn"
|
||||
class="text-gray-500 hover:text-gray-700 focus:outline-none"
|
||||
disabled
|
||||
>
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
<div class="mb-4">
|
||||
<p id="progressStatusText" class="text-sm text-gray-600 mb-2">
|
||||
准备开始...
|
||||
</p>
|
||||
<div class="w-full bg-gray-200 rounded-full h-4 dark:bg-gray-700">
|
||||
<div
|
||||
id="progressBar"
|
||||
class="bg-primary-600 h-4 rounded-full transition-all duration-300"
|
||||
style="width: 0%"
|
||||
></div>
|
||||
</div>
|
||||
<p
|
||||
id="progressPercentage"
|
||||
class="text-center text-sm font-semibold mt-1"
|
||||
style="color: #1f2937"
|
||||
>
|
||||
0%
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
id="progressLog"
|
||||
class="text-xs max-h-60 overflow-y-auto bg-gray-50 p-3 rounded border border-gray-200 space-y-1 font-mono"
|
||||
style="
|
||||
background-color: rgba(249, 250, 251, 0.95);
|
||||
border-color: rgba(0, 0, 0, 0.08);
|
||||
"
|
||||
>
|
||||
<!-- Log entries will be added here -->
|
||||
</div>
|
||||
<div class="flex justify-end gap-3 mt-6">
|
||||
<button
|
||||
id="progressModalCloseBtn"
|
||||
onclick="closeProgressModal(true)"
|
||||
class="px-4 py-1.5 text-sm font-medium bg-primary-700 hover:bg-primary-800 text-white rounded-lg transition-colors"
|
||||
disabled
|
||||
>
|
||||
完成并刷新
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 操作结果模态框 -->
|
||||
<div
|
||||
id="resultModal"
|
||||
class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50 hidden"
|
||||
class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center hidden"
|
||||
style="z-index: 1001;"
|
||||
>
|
||||
<div
|
||||
class="bg-white rounded-2xl p-0 shadow-2xl max-w-lg w-full animate-fade-in border border-gray-200"
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""
|
||||
通用工具函数模块
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from app.config.config import Settings
|
||||
from app.core.constants import DATA_URL_PATTERN, IMAGE_URL_PATTERN, VALID_IMAGE_RATIOS
|
||||
|
||||
helper_logger = logging.getLogger("app.utils")
|
||||
@@ -20,23 +23,25 @@ VERSION_FILE_PATH = PROJECT_ROOT / "VERSION"
|
||||
def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
从 base64 字符串中提取 MIME 类型和数据
|
||||
|
||||
|
||||
Args:
|
||||
base64_string: 可能包含 MIME 类型信息的 base64 字符串
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (mime_type, encoded_data)
|
||||
"""
|
||||
# 检查字符串是否以 "data:" 格式开始
|
||||
if base64_string.startswith('data:'):
|
||||
if base64_string.startswith("data:"):
|
||||
# 提取 MIME 类型和数据
|
||||
pattern = DATA_URL_PATTERN
|
||||
match = re.match(pattern, base64_string)
|
||||
if match:
|
||||
mime_type = "image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
mime_type = (
|
||||
"image/jpeg" if match.group(1) == "image/jpg" else match.group(1)
|
||||
)
|
||||
encoded_data = match.group(2)
|
||||
return mime_type, encoded_data
|
||||
|
||||
|
||||
# 如果不是预期格式,假定它只是数据部分
|
||||
return None, base64_string
|
||||
|
||||
@@ -44,20 +49,20 @@ def extract_mime_type_and_data(base64_string: str) -> Tuple[Optional[str], str]:
|
||||
def convert_image_to_base64(url: str) -> str:
|
||||
"""
|
||||
将图片URL转换为base64编码
|
||||
|
||||
|
||||
Args:
|
||||
url: 图片URL
|
||||
|
||||
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: 如果获取图片失败
|
||||
"""
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
# 将图片内容转换为base64
|
||||
img_data = base64.b64encode(response.content).decode('utf-8')
|
||||
img_data = base64.b64encode(response.content).decode("utf-8")
|
||||
return img_data
|
||||
else:
|
||||
raise Exception(f"Failed to fetch image: {response.status_code}")
|
||||
@@ -66,64 +71,66 @@ def convert_image_to_base64(url: str) -> str:
|
||||
def format_json_response(data: Dict[str, Any], indent: int = 2) -> str:
|
||||
"""
|
||||
格式化JSON响应
|
||||
|
||||
|
||||
Args:
|
||||
data: 要格式化的数据
|
||||
indent: 缩进空格数
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化后的JSON字符串
|
||||
"""
|
||||
return json.dumps(data, indent=indent, ensure_ascii=False)
|
||||
|
||||
|
||||
def parse_prompt_parameters(prompt: str, default_ratio: str = "1:1") -> Tuple[str, int, str]:
|
||||
def parse_prompt_parameters(
|
||||
prompt: str, default_ratio: str = "1:1"
|
||||
) -> Tuple[str, int, str]:
|
||||
"""
|
||||
从prompt中解析参数
|
||||
|
||||
|
||||
支持的格式:
|
||||
- {n:数量} 例如: {n:2} 生成2张图片
|
||||
- {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
|
||||
|
||||
|
||||
Args:
|
||||
prompt: 提示文本
|
||||
default_ratio: 默认比例
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (清理后的提示文本, 图片数量, 比例)
|
||||
"""
|
||||
# 默认值
|
||||
n = 1
|
||||
aspect_ratio = default_ratio
|
||||
|
||||
|
||||
# 解析n参数
|
||||
n_match = re.search(r'{n:(\d+)}', prompt)
|
||||
n_match = re.search(r"{n:(\d+)}", prompt)
|
||||
if n_match:
|
||||
n = int(n_match.group(1))
|
||||
if n < 1 or n > 4:
|
||||
raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
|
||||
prompt = prompt.replace(n_match.group(0), '').strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r'{ratio:(\d+:\d+)}', prompt)
|
||||
prompt = prompt.replace(n_match.group(0), "").strip()
|
||||
|
||||
# 解析ratio参数
|
||||
ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
|
||||
if ratio_match:
|
||||
aspect_ratio = ratio_match.group(1)
|
||||
if aspect_ratio not in VALID_IMAGE_RATIOS:
|
||||
raise ValueError(
|
||||
f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
|
||||
)
|
||||
prompt = prompt.replace(ratio_match.group(0), '').strip()
|
||||
|
||||
prompt = prompt.replace(ratio_match.group(0), "").strip()
|
||||
|
||||
return prompt, n, aspect_ratio
|
||||
|
||||
|
||||
def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||
"""
|
||||
从Markdown文本中提取图片URL
|
||||
|
||||
|
||||
Args:
|
||||
text: Markdown文本
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 图片URL列表
|
||||
"""
|
||||
@@ -135,38 +142,90 @@ def extract_image_urls_from_markdown(text: str) -> List[str]:
|
||||
def is_valid_api_key(key: str) -> bool:
|
||||
"""
|
||||
检查API密钥格式是否有效
|
||||
|
||||
|
||||
Args:
|
||||
key: API密钥
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果密钥格式有效则返回True
|
||||
"""
|
||||
# 检查Gemini API密钥格式
|
||||
if key.startswith('AIza'):
|
||||
if key.startswith("AIza"):
|
||||
return len(key) >= 30
|
||||
|
||||
|
||||
# 检查OpenAI API密钥格式
|
||||
if key.startswith('sk-'):
|
||||
if key.startswith("sk-"):
|
||||
return len(key) >= 30
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def redact_key_for_logging(key: str) -> str:
|
||||
"""
|
||||
Redacts API key for secure logging by showing only first and last 6 characters.
|
||||
|
||||
Args:
|
||||
key: API key to redact
|
||||
|
||||
Returns:
|
||||
str: Redacted key in format "first6...last6" or descriptive placeholder for edge cases
|
||||
"""
|
||||
if not key:
|
||||
return key
|
||||
|
||||
if len(key) <= 12:
|
||||
return f"{key[:3]}...{key[-3:]}"
|
||||
else:
|
||||
return f"{key[:6]}...{key[-6:]}"
|
||||
|
||||
|
||||
def get_current_version(default_version: str = "0.0.0") -> str:
|
||||
"""Reads the current version from the VERSION file."""
|
||||
version_file = VERSION_FILE_PATH
|
||||
try:
|
||||
with version_file.open('r', encoding='utf-8') as f:
|
||||
with version_file.open("r", encoding="utf-8") as f:
|
||||
version = f.read().strip()
|
||||
if not version:
|
||||
helper_logger.warning(f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'.")
|
||||
helper_logger.warning(
|
||||
f"VERSION file ('{version_file}') is empty. Using default version '{default_version}'."
|
||||
)
|
||||
return default_version
|
||||
return version
|
||||
except FileNotFoundError:
|
||||
helper_logger.warning(f"VERSION file not found at '{version_file}'. Using default version '{default_version}'.")
|
||||
helper_logger.warning(
|
||||
f"VERSION file not found at '{version_file}'. Using default version '{default_version}'."
|
||||
)
|
||||
return default_version
|
||||
except IOError as e:
|
||||
helper_logger.error(f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'.")
|
||||
helper_logger.error(
|
||||
f"Error reading VERSION file ('{version_file}'): {e}. Using default version '{default_version}'."
|
||||
)
|
||||
return default_version
|
||||
|
||||
|
||||
def is_image_upload_configured(settings: Settings) -> bool:
|
||||
"""Return True only if a valid upload provider is selected and all required settings for that provider are present."""
|
||||
|
||||
provider = (getattr(settings, "UPLOAD_PROVIDER", "") or "").strip().lower()
|
||||
if provider == "smms":
|
||||
return bool(getattr(settings, "SMMS_SECRET_TOKEN", ""))
|
||||
if provider == "picgo":
|
||||
return bool(getattr(settings, "PICGO_API_KEY", ""))
|
||||
if provider == "aliyun_oss":
|
||||
return all(
|
||||
[
|
||||
getattr(settings, "OSS_ACCESS_KEY", ""),
|
||||
getattr(settings, "OSS_ACCESS_KEY_SECRET", ""),
|
||||
getattr(settings, "OSS_BUCKET_NAME", ""),
|
||||
getattr(settings, "OSS_ENDPOINT", ""),
|
||||
getattr(settings, "OSS_REGION", "")
|
||||
]
|
||||
)
|
||||
if provider == "cloudflare_imgbed":
|
||||
return all(
|
||||
[
|
||||
getattr(settings, "CLOUDFLARE_IMGBED_URL", ""),
|
||||
getattr(settings, "CLOUDFLARE_IMGBED_AUTH_CODE", ""),
|
||||
]
|
||||
)
|
||||
return False
|
||||
|
||||
127
app/utils/static_version.py
Normal file
127
app/utils/static_version.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
静态资源版本控制工具
|
||||
用于给CSS和JS文件添加版本参数,避免浏览器缓存问题
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.utils.helpers import get_current_version
|
||||
|
||||
|
||||
class StaticVersionManager:
|
||||
"""静态资源版本管理器"""
|
||||
|
||||
def __init__(self, static_dir: str = "app/static"):
|
||||
self.static_dir = Path(static_dir)
|
||||
self._version_cache: Dict[str, str] = {}
|
||||
self._use_file_hash = True # 是否使用文件哈希作为版本号
|
||||
|
||||
def get_version_for_file(self, file_path: str) -> str:
|
||||
"""
|
||||
获取文件的版本号
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径,如 'css/fonts.css'
|
||||
|
||||
Returns:
|
||||
版本号字符串
|
||||
"""
|
||||
if self._use_file_hash:
|
||||
return self._get_file_hash_version(file_path)
|
||||
else:
|
||||
return self._get_app_version()
|
||||
|
||||
def _get_file_hash_version(self, file_path: str) -> str:
|
||||
"""基于文件内容生成哈希版本号"""
|
||||
# 如果已经缓存过,直接返回
|
||||
if file_path in self._version_cache:
|
||||
return self._version_cache[file_path]
|
||||
|
||||
full_path = self.static_dir / file_path
|
||||
|
||||
if not full_path.exists():
|
||||
# 文件不存在,使用应用版本号作为fallback
|
||||
version = self._get_app_version()
|
||||
else:
|
||||
try:
|
||||
# 读取文件内容并计算MD5哈希
|
||||
with open(full_path, "rb") as f:
|
||||
content = f.read()
|
||||
hash_object = hashlib.md5(content)
|
||||
version = hash_object.hexdigest()[:8] # 取前8位
|
||||
except Exception:
|
||||
# 读取失败,使用应用版本号作为fallback
|
||||
version = self._get_app_version()
|
||||
|
||||
# 缓存结果
|
||||
self._version_cache[file_path] = version
|
||||
return version
|
||||
|
||||
def _get_app_version(self) -> str:
|
||||
"""获取应用程序版本号"""
|
||||
try:
|
||||
return get_current_version().replace(".", "")
|
||||
except Exception:
|
||||
# 如果获取版本失败,使用时间戳
|
||||
return str(int(time.time()))
|
||||
|
||||
def get_versioned_url(self, file_path: str) -> str:
|
||||
"""
|
||||
获取带版本参数的URL
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径
|
||||
|
||||
Returns:
|
||||
带版本参数的URL
|
||||
"""
|
||||
version = self.get_version_for_file(file_path)
|
||||
return f"/static/{file_path}?v={version}"
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空版本缓存"""
|
||||
self._version_cache.clear()
|
||||
|
||||
|
||||
# 全局实例
|
||||
_static_version_manager = StaticVersionManager()
|
||||
|
||||
|
||||
def get_static_url(file_path: str) -> str:
|
||||
"""
|
||||
获取静态资源的版本化URL
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径
|
||||
|
||||
Returns:
|
||||
带版本参数的完整URL
|
||||
|
||||
Example:
|
||||
get_static_url('css/fonts.css') -> '/static/css/fonts.css?v=a1b2c3d4'
|
||||
get_static_url('js/config_editor.js') -> '/static/js/config_editor.js?v=e5f6g7h8'
|
||||
"""
|
||||
return _static_version_manager.get_versioned_url(file_path)
|
||||
|
||||
|
||||
def clear_static_cache():
|
||||
"""清空静态资源版本缓存"""
|
||||
_static_version_manager.clear_cache()
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_cached_static_url(file_path: str) -> str:
|
||||
"""
|
||||
获取缓存的静态资源URL(用于开发环境)
|
||||
|
||||
Args:
|
||||
file_path: 相对于static目录的文件路径
|
||||
|
||||
Returns:
|
||||
带版本参数的完整URL
|
||||
"""
|
||||
return get_static_url(file_path)
|
||||
@@ -2,6 +2,12 @@ import requests
|
||||
from app.domain.image_models import ImageMetadata, ImageUploader, UploadResponse
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from app.log.logger import get_image_create_logger
|
||||
|
||||
class UploadErrorType(Enum):
|
||||
"""上传错误类型枚举"""
|
||||
@@ -179,9 +185,22 @@ class PicGoUploader(ImageUploader):
|
||||
"""
|
||||
try:
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"X-API-Key": self.api_key
|
||||
}
|
||||
headers = {}
|
||||
|
||||
# 构建请求URL
|
||||
request_url = self.api_url
|
||||
|
||||
# 判断是否为默认PicGo URL,如果是则使用header认证,否则使用URL参数认证
|
||||
if self.api_url == "https://www.picgo.net/api/1/upload":
|
||||
headers["X-API-Key"] = self.api_key
|
||||
else:
|
||||
# 对于自定义URL,将API key作为查询参数添加到URL中
|
||||
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
|
||||
parsed_url = urlparse(request_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params["key"] = self.api_key
|
||||
new_query = urlencode(query_params, doseq=True)
|
||||
request_url = urlunparse(parsed_url._replace(query=new_query))
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
@@ -190,7 +209,7 @@ class PicGoUploader(ImageUploader):
|
||||
|
||||
# 发送请求
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
request_url,
|
||||
headers=headers,
|
||||
files=files
|
||||
)
|
||||
@@ -201,6 +220,34 @@ class PicGoUploader(ImageUploader):
|
||||
# 解析响应
|
||||
result = response.json()
|
||||
|
||||
# 处理自定义PicGo服务器的响应格式
|
||||
if "success" in result and "result" in result:
|
||||
# 自定义PicGo服务器格式: {"success": true, "result": ["url"]}
|
||||
if result["success"]:
|
||||
image_url = result["result"][0] if result["result"] and len(result["result"]) > 0 else ""
|
||||
image_metadata = ImageMetadata(
|
||||
width=0,
|
||||
height=0,
|
||||
filename=filename,
|
||||
size=0,
|
||||
url=image_url,
|
||||
delete_url=None
|
||||
)
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
code="success",
|
||||
message="Upload success",
|
||||
data=image_metadata
|
||||
)
|
||||
else:
|
||||
raise UploadError(
|
||||
message="Upload failed",
|
||||
error_type=UploadErrorType.SERVER_ERROR,
|
||||
status_code=400,
|
||||
details=result
|
||||
)
|
||||
|
||||
# 处理官方PicGo服务器的响应格式
|
||||
# 验证上传是否成功
|
||||
if result.get("status_code") != 200:
|
||||
error_message = "Upload failed"
|
||||
@@ -259,20 +306,207 @@ class PicGoUploader(ImageUploader):
|
||||
)
|
||||
|
||||
|
||||
class AliyunOSSUploader(ImageUploader):
|
||||
"""阿里云OSS图片上传器"""
|
||||
|
||||
def __init__(self, access_key: str, access_key_secret: str, bucket_name: str,
|
||||
endpoint: str, region: str, use_internal: bool = False):
|
||||
"""
|
||||
初始化阿里云OSS上传器
|
||||
|
||||
Args:
|
||||
access_key: OSS访问密钥ID
|
||||
access_key_secret: OSS访问密钥
|
||||
bucket_name: OSS存储桶名称
|
||||
endpoint: OSS端点地址
|
||||
region: OSS区域
|
||||
use_internal: 是否使用内网端点
|
||||
"""
|
||||
self.access_key = access_key
|
||||
self.access_key_secret = access_key_secret
|
||||
self.bucket_name = bucket_name
|
||||
self.endpoint = endpoint
|
||||
self.region = region
|
||||
self.use_internal = use_internal
|
||||
self.logger = get_image_create_logger()
|
||||
|
||||
# 构建请求URL
|
||||
if not endpoint.startswith(('http://', 'https://')):
|
||||
self.base_url = f"https://{bucket_name}.{endpoint}"
|
||||
else:
|
||||
self.base_url = f"{endpoint}/{bucket_name}"
|
||||
|
||||
self.logger.info(f"Initialized AliyunOSSUploader for bucket: {bucket_name}, region: {region}")
|
||||
|
||||
def _sign_request(self, method: str, path: str, headers: dict, content: bytes = b'') -> dict:
|
||||
"""
|
||||
为OSS请求生成签名
|
||||
|
||||
Args:
|
||||
method: HTTP方法
|
||||
path: 请求路径
|
||||
headers: 请求头
|
||||
content: 请求内容
|
||||
|
||||
Returns:
|
||||
包含签名的请求头
|
||||
"""
|
||||
# 计算Content-MD5
|
||||
content_md5 = base64.b64encode(hashlib.md5(content).digest()).decode('utf-8') if content else ''
|
||||
|
||||
# 设置日期
|
||||
date = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
|
||||
# 更新headers
|
||||
headers['Date'] = date
|
||||
if content_md5:
|
||||
headers['Content-MD5'] = content_md5
|
||||
headers['Content-Type'] = headers.get('Content-Type', 'image/png')
|
||||
|
||||
# 构建CanonicalizedOSSHeaders
|
||||
oss_headers = []
|
||||
for key, value in sorted(headers.items()):
|
||||
if key.lower().startswith('x-oss-'):
|
||||
oss_headers.append(f"{key.lower()}:{value}")
|
||||
canonicalized_oss_headers = '\n'.join(oss_headers)
|
||||
if canonicalized_oss_headers:
|
||||
canonicalized_oss_headers += '\n'
|
||||
|
||||
# 构建CanonicalizedResource
|
||||
canonicalized_resource = f"/{self.bucket_name}{path}"
|
||||
|
||||
# 构建StringToSign
|
||||
string_to_sign = f"{method}\n{content_md5}\n{headers.get('Content-Type', '')}\n{date}\n{canonicalized_oss_headers}{canonicalized_resource}"
|
||||
|
||||
# 计算签名
|
||||
signature = base64.b64encode(
|
||||
hmac.new(
|
||||
self.access_key_secret.encode('utf-8'),
|
||||
string_to_sign.encode('utf-8'),
|
||||
hashlib.sha1
|
||||
).digest()
|
||||
).decode('utf-8')
|
||||
|
||||
# 添加Authorization头
|
||||
headers['Authorization'] = f"OSS {self.access_key}:{signature}"
|
||||
|
||||
return headers
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
"""
|
||||
上传图片到阿里云OSS
|
||||
|
||||
Args:
|
||||
file: 图片文件二进制数据
|
||||
filename: 文件名(将作为OSS对象的key)
|
||||
|
||||
Returns:
|
||||
UploadResponse: 上传响应对象
|
||||
|
||||
Raises:
|
||||
UploadError: 上传失败时抛出异常
|
||||
"""
|
||||
# 记录开始上传的日志
|
||||
self.logger.info(f"Starting OSS upload for file: {filename}, size: {len(file)} bytes")
|
||||
|
||||
try:
|
||||
# 构建对象路径
|
||||
object_key = f"/{filename}"
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
'Content-Type': 'image/png',
|
||||
'x-oss-object-acl': 'public-read' # 设置为公共读
|
||||
}
|
||||
|
||||
# 签名请求
|
||||
signed_headers = self._sign_request('PUT', object_key, headers, file)
|
||||
|
||||
# 构建完整URL
|
||||
upload_url = f"{self.base_url}{object_key}"
|
||||
self.logger.debug(f"OSS upload URL: {upload_url}")
|
||||
|
||||
# 发送请求
|
||||
response = requests.put(
|
||||
upload_url,
|
||||
data=file,
|
||||
headers=signed_headers
|
||||
)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_msg = f"OSS upload failed with status {response.status_code}, response: {response.text}"
|
||||
self.logger.error(f"OSS upload failed for {filename}: {error_msg}")
|
||||
raise UploadError(
|
||||
message=f"OSS upload failed with status {response.status_code}",
|
||||
error_type=UploadErrorType.SERVER_ERROR,
|
||||
status_code=response.status_code,
|
||||
details={'response': response.text}
|
||||
)
|
||||
|
||||
# 构建访问URL
|
||||
if self.endpoint.startswith(('http://', 'https://')):
|
||||
access_url = f"{self.endpoint}/{self.bucket_name}{object_key}"
|
||||
else:
|
||||
access_url = f"https://{self.bucket_name}.{self.endpoint}{object_key}"
|
||||
|
||||
# 构建图片元数据
|
||||
image_metadata = ImageMetadata(
|
||||
width=0, # OSS PUT不返回图片尺寸
|
||||
height=0,
|
||||
filename=filename,
|
||||
size=len(file),
|
||||
url=access_url,
|
||||
delete_url=None # OSS需要单独的删除操作
|
||||
)
|
||||
|
||||
# 记录上传成功的日志
|
||||
self.logger.info(f"OSS upload successful for {filename}, URL: {access_url}")
|
||||
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
code="success",
|
||||
message="Upload to Aliyun OSS success",
|
||||
data=image_metadata
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
error_msg = f"OSS upload request failed: {str(e)}"
|
||||
self.logger.error(f"OSS upload request failed for {filename}: {error_msg}")
|
||||
raise UploadError(
|
||||
message=error_msg,
|
||||
error_type=UploadErrorType.NETWORK_ERROR,
|
||||
original_error=e
|
||||
)
|
||||
except UploadError:
|
||||
# UploadError 已经被记录了,直接重新抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"OSS upload failed: {str(e)}"
|
||||
self.logger.error(f"OSS upload unexpected error for {filename}: {error_msg}")
|
||||
raise UploadError(
|
||||
message=error_msg,
|
||||
error_type=UploadErrorType.UNKNOWN,
|
||||
original_error=e
|
||||
)
|
||||
|
||||
|
||||
class CloudFlareImgBedUploader(ImageUploader):
|
||||
"""CloudFlare图床上传器"""
|
||||
|
||||
def __init__(self, auth_code: str, api_url: str):
|
||||
|
||||
def __init__(self, auth_code: str, api_url: str, upload_folder: str = ""):
|
||||
"""
|
||||
初始化CloudFlare图床上传器
|
||||
|
||||
Args:
|
||||
auth_code: 认证码
|
||||
api_url: 上传API地址
|
||||
upload_folder: 上传文件夹路径(可选)
|
||||
"""
|
||||
self.auth_code = auth_code
|
||||
self.api_url = api_url
|
||||
|
||||
self.upload_folder = upload_folder
|
||||
|
||||
def upload(self, file: bytes, filename: str) -> UploadResponse:
|
||||
"""
|
||||
上传图片到CloudFlare图床
|
||||
@@ -288,12 +522,16 @@ class CloudFlareImgBedUploader(ImageUploader):
|
||||
UploadError: 上传失败时抛出异常
|
||||
"""
|
||||
try:
|
||||
# 准备请求URL(添加认证码参数,如果存在)
|
||||
# 准备请求URL参数
|
||||
params = []
|
||||
if self.upload_folder:
|
||||
params.append(f"uploadFolder={self.upload_folder}")
|
||||
if self.auth_code:
|
||||
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
|
||||
else:
|
||||
request_url = f"{self.api_url}?uploadNameType=origin"
|
||||
|
||||
params.append(f"authCode={self.auth_code}")
|
||||
params.append("uploadNameType=origin")
|
||||
|
||||
request_url = f"{self.api_url}?{'&'.join(params)}"
|
||||
|
||||
# 准备文件数据
|
||||
files = {
|
||||
"file": (filename, file)
|
||||
@@ -383,11 +621,21 @@ class ImageUploaderFactory:
|
||||
credentials["secret_key"]
|
||||
)
|
||||
elif provider == "picgo":
|
||||
api_url = credentials.get("api_url", "https://www.picgo.net/api/1/upload")
|
||||
api_url = credentials.get("api_url") or "https://www.picgo.net/api/1/upload"
|
||||
return PicGoUploader(credentials["api_key"], api_url)
|
||||
elif provider == "cloudflare_imgbed":
|
||||
return CloudFlareImgBedUploader(
|
||||
credentials["auth_code"],
|
||||
credentials["base_url"]
|
||||
credentials["base_url"],
|
||||
credentials.get("upload_folder", ""),
|
||||
)
|
||||
elif provider == "aliyun_oss":
|
||||
return AliyunOSSUploader(
|
||||
credentials["access_key"],
|
||||
credentials["access_key_secret"],
|
||||
credentials["bucket_name"],
|
||||
credentials["endpoint"],
|
||||
credentials["region"],
|
||||
credentials.get("use_internal", False)
|
||||
)
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
@@ -36,4 +36,13 @@ services:
|
||||
interval: 10s # 每隔10秒检查一次
|
||||
timeout: 5s # 每次检查的超时时间为5秒
|
||||
retries: 3 # 重试3次失败后标记为 unhealthy
|
||||
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
|
||||
start_period: 30s # 容器启动后等待30秒再开始第一次健康检查
|
||||
# adminer:
|
||||
# image: adminer:latest
|
||||
# container_name: gemini-balance-adminer
|
||||
# restart: unless-stopped
|
||||
# ports:
|
||||
# - "8080:8080"
|
||||
# depends_on:
|
||||
# mysql:
|
||||
# condition: service_healthy
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
187
tests/test_key_redaction.py
Normal file
187
tests/test_key_redaction.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Unit tests for API key redaction functionality
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import logging
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.helpers import redact_key_for_logging
|
||||
from app.log.logger import AccessLogFormatter
|
||||
|
||||
|
||||
class TestKeyRedaction(unittest.TestCase):
|
||||
"""Test cases for the redact_key_for_logging function"""
|
||||
|
||||
def test_valid_long_key_redaction(self):
|
||||
"""Test redaction of valid long API keys"""
|
||||
# Test Google/Gemini API key
|
||||
# This value is a random generated string for testing
|
||||
gemini_key = "AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI"
|
||||
result = redact_key_for_logging(gemini_key)
|
||||
expected = "AIzaSy...xDfGhI"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test OpenAI API key
|
||||
# This value is a random generated string for testing
|
||||
openai_key = "sk-1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
result = redact_key_for_logging(openai_key)
|
||||
expected = "sk-123...abcdef"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_short_key_handling(self):
|
||||
"""Test handling of short keys"""
|
||||
short_key = "short"
|
||||
result = redact_key_for_logging(short_key)
|
||||
self.assertEqual(result, "[SHORT_KEY]")
|
||||
|
||||
# Test exactly 12 characters (boundary case)
|
||||
boundary_key = "123456789012"
|
||||
result = redact_key_for_logging(boundary_key)
|
||||
self.assertEqual(result, "[SHORT_KEY]")
|
||||
|
||||
def test_empty_and_none_keys(self):
|
||||
"""Test handling of empty and None keys"""
|
||||
# Test empty string
|
||||
result = redact_key_for_logging("")
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
# Test None
|
||||
result = redact_key_for_logging(None)
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
def test_invalid_input_types(self):
|
||||
"""Test handling of invalid input types"""
|
||||
# Test integer
|
||||
result = redact_key_for_logging(123)
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
# Test list
|
||||
result = redact_key_for_logging(["key"])
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
# Test dict
|
||||
result = redact_key_for_logging({"key": "value"})
|
||||
self.assertEqual(result, "[INVALID_KEY]")
|
||||
|
||||
def test_boundary_cases(self):
|
||||
"""Test boundary cases for key length"""
|
||||
# Test 13 characters (just above the threshold)
|
||||
key_13 = "1234567890123"
|
||||
result = redact_key_for_logging(key_13)
|
||||
expected = "123456...890123"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test very long key
|
||||
long_key = "a" * 100
|
||||
result = redact_key_for_logging(long_key)
|
||||
expected = "aaaaaa...aaaaaa"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
class TestAccessLogFormatter(unittest.TestCase):
|
||||
"""Test cases for the AccessLogFormatter class"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.formatter = AccessLogFormatter()
|
||||
|
||||
def test_gemini_key_redaction_in_url(self):
|
||||
"""Test redaction of Gemini API keys in URLs"""
|
||||
log_message = (
|
||||
'POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200'
|
||||
)
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertIn("AIzaSy...xDfGhI", result)
|
||||
self.assertNotIn("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", result)
|
||||
|
||||
def test_openai_key_redaction_in_url(self):
|
||||
"""Test redaction of OpenAI API keys in URLs"""
|
||||
log_message = 'GET /api/models?key=sk-1234567890abcdef1234567890abcdef1234567890abcdef HTTP/1.1" 200'
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertIn("sk-123...abcdef", result)
|
||||
self.assertNotIn("sk-1234567890abcdef1234567890abcdef1234567890abcdef", result)
|
||||
|
||||
def test_multiple_keys_in_message(self):
|
||||
"""Test redaction of multiple API keys in a single message"""
|
||||
log_message = "Request with keys: AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI and sk-1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertIn("AIzaSy...xDfGhI", result)
|
||||
self.assertIn("sk-123...abcdef", result)
|
||||
self.assertNotIn("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", result)
|
||||
self.assertNotIn("sk-1234567890abcdef1234567890abcdef1234567890abcdef", result)
|
||||
|
||||
def test_no_keys_in_message(self):
|
||||
"""Test that messages without API keys are unchanged"""
|
||||
log_message = 'GET /api/health HTTP/1.1" 200'
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertEqual(result, log_message)
|
||||
|
||||
def test_partial_key_patterns_not_redacted(self):
|
||||
"""Test that partial key patterns are not redacted"""
|
||||
log_message = "Message with partial patterns: AIza sk- incomplete"
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertEqual(result, log_message)
|
||||
|
||||
def test_error_handling_in_redaction(self):
|
||||
"""Test error handling in the redaction process"""
|
||||
# Test by directly calling _redact_api_keys_in_message with a broken pattern
|
||||
original_patterns = self.formatter.compiled_patterns
|
||||
# Create a mock pattern that will raise an exception
|
||||
mock_pattern = MagicMock()
|
||||
mock_pattern.sub.side_effect = Exception("Regex error")
|
||||
self.formatter.compiled_patterns = [mock_pattern]
|
||||
|
||||
try:
|
||||
log_message = (
|
||||
'POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200'
|
||||
)
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertEqual(result, "[LOG_REDACTION_ERROR]")
|
||||
finally:
|
||||
# Restore original patterns
|
||||
self.formatter.compiled_patterns = original_patterns
|
||||
|
||||
def test_format_method(self):
|
||||
"""Test the format method of AccessLogFormatter"""
|
||||
# Create a mock log record
|
||||
record = MagicMock()
|
||||
record.getMessage.return_value = (
|
||||
'POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200'
|
||||
)
|
||||
|
||||
# Mock the parent format method
|
||||
with patch(
|
||||
"logging.Formatter.format",
|
||||
return_value='2025-01-01 12:00:00 | INFO | POST /verify-key/AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI HTTP/1.1" 200',
|
||||
):
|
||||
result = self.formatter.format(record)
|
||||
self.assertIn("AIzaSy...xDfGhI", result)
|
||||
self.assertNotIn("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", result)
|
||||
|
||||
def test_regex_patterns_compilation(self):
|
||||
"""Test that regex patterns are properly compiled"""
|
||||
formatter = AccessLogFormatter()
|
||||
self.assertEqual(len(formatter.compiled_patterns), 2)
|
||||
self.assertTrue(
|
||||
all(hasattr(pattern, "sub") for pattern in formatter.compiled_patterns)
|
||||
)
|
||||
|
||||
def test_flexible_openai_pattern(self):
|
||||
"""Test the flexible OpenAI pattern matches various formats"""
|
||||
test_cases = [
|
||||
"sk-1234567890abcdef1234567890abcdef1234567890abcdef", # Standard 48 chars
|
||||
"sk-proj-1234567890abcdef1234567890abcdef1234567890abcdef", # Project key
|
||||
"sk-1234567890abcdef_1234567890abcdef-1234567890abcdef", # With underscores/hyphens
|
||||
"sk-12345678901234567890", # Shorter key (20 chars)
|
||||
]
|
||||
|
||||
for test_key in test_cases:
|
||||
log_message = f"Request with key: {test_key}"
|
||||
result = self.formatter._redact_api_keys_in_message(log_message)
|
||||
self.assertNotIn(test_key, result)
|
||||
self.assertIn("sk-", result) # Should still contain the prefix
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user