Compare commits

84 Commits

Author SHA1 Message Date
snaily
ebfa1d247c chore: bump version to 2.2.1 2025-07-20 13:43:18 +08:00
snaily
cdb85ef9b7 fix: update thinking budget configuration to use the _get_real_model function to resolve the real model 2025-07-20 13:42:36 +08:00
snaily
7006522c13 feat: fix tool calls failing under gemini-cli 2025-07-20 13:19:59 +08:00
snaily
530c958afc chore: bump version to 2.2.0 2025-07-20 01:47:35 +08:00
snaily
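57d861b578 feat: add URL context understanding
This commit introduces a new feature that lets the model understand the context of URLs provided in a conversation.

Main changes:

- **Configuration**: added two settings, `URL_CONTEXT_ENABLED` and `URL_CONTEXT_MODELS`, which control whether the feature is on and which models support it.
- **Backend services**: `gemini_chat_service`, `openai_chat_service`, and `vertex_express_chat_service` now dynamically add the `urlContext` tool for supported models.
- **Frontend**: added UI controls to the configuration editor page so users can enable or disable the feature and manage the supported model list.
- **Documentation**: updated `.env.example`, `README.md`, and `README_ZH.md` with descriptions of the new settings.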
2025-07-20 01:46:18 +08:00
snaily
99664298b9 fix: update thinking configuration to set the thinking budget to 128 for the gemini-2.5-pro model 2025-07-19 22:20:55 +08:00
snaily
a6fe5a7022 fix: update the thinking model budget description; -1 means automatic budget 2025-07-19 22:11:36 +08:00
snaily
1918dad602 chore: bump version to 2.1.13 2025-07-19 15:24:36 +08:00
snaily
69399c291e fix: reset the failure count when key validation succeeds 2025-07-19 10:49:09 +08:00
snaily
9ec33ce320 fix: add default values for API_KEYS and ALLOWED_TOKENS 2025-07-19 09:31:59 +08:00
snaily
c35d3aff7d chore: bump version to 2.1.12 2025-07-19 01:39:32 +08:00
snaily
2a5744d1c4 fix: remove the by_alias parameter from request payload construction 2025-07-19 01:38:48 +08:00
snaily
825511506b chore: bump version to 2.1.11 2025-07-19 00:41:08 +08:00
snaily
5a98a701cb fix: correct generation config field names to match API requirements 2025-07-19 00:40:44 +08:00
snaily
dd1fa35c73 chore: bump version to 2.1.10 2025-07-18 22:34:46 +08:00
snaily
fb572fa849 chore: remove unnecessary json import 2025-07-18 22:33:46 +08:00
snaily
c0a473ed19 Merge branch 'pr/hewenyu/220' 2025-07-18 16:47:22 +08:00
snaily
030641adc6 chore: remove unnecessary environment variable configuration 2025-07-18 16:39:23 +08:00
hewenyu
445ef49dc8 fix # 219
Fix the token issue
2025-07-18 10:50:52 +08:00
snaily
32d4c60541 fix: correct the misspelling Callirhoe to Callirrhoe
refactor: tidy constant formatting for readability
2025-07-17 22:07:18 +08:00
snaily
23f865be07 Merge branch 'pr/cxyfer/200' 2025-07-17 21:30:35 +08:00
cxyfer
5d55325c12 refactor: Centralize API base URL and clean up
Replaces hardcoded Google API base URLs with `settings.BASE_URL` for improved configurability and maintainability across services.

Removes unused imports and variables from various modules to reduce code bloat and enhance readability.
2025-07-16 04:50:55 +08:00
zzh2632185
900330509a Delete .augment-guidelines 2025-07-16 01:26:38 +09:00
zzh
cfb682ae3c fix the parts error 2025-07-16 01:25:51 +09:00
zzh
abae90b16d remove redundant code 2025-07-16 00:13:30 +09:00
zzh
470fc37f26 change the model used in the plain text generation example. 2025-07-15 18:47:50 +09:00
zzh
7a7caef1a6 update the README.md example for OpenAI-compatible TTS 2025-07-15 18:39:24 +09:00
zzh
a6aecb5d89 add support for Gemini native-format TTS 2025-07-15 18:04:16 +09:00
zzh
4a004f9aa1 remove extraneous committed content. 2025-07-15 16:00:56 +09:00
zzh
1a6feae23b Update multi-speaker TTS README
- Reflect current smart detection implementation
- Remove outdated ENABLE_TTS environment variable references
- Add TTS systems comparison table
- Update usage examples with correct URLs
- Add intelligent routing flowchart
- Clarify zero-configuration approach
- Update feature list to match current implementation
2025-07-15 15:55:47 +09:00
zzh
af5b2fa2c9 Clean up TTS module dependencies
- Remove references to deleted tts_config.py
- Simplify tts_routes.py to directly return TTSGeminiChatService
- Update __init__.py imports
- Prepare for multi-speaker TTS testing
2025-07-15 15:44:55 +09:00
zzh
eeec45274b Implement smart multi-speaker TTS detection
- Only activate multi-speaker TTS when multiSpeakerVoiceConfig is present
- Preserve original TTS functionality for single-speaker requests
- Support dynamic model selection from user request
- Add fallback mechanism to standard service if multi-speaker TTS fails
- Maintain full backward compatibility with existing TTS systems
2025-07-15 15:43:12 +09:00
zzh
2b48c853fe Refactor: Use TTS service only for TTS models, keep original service for others
- Remove ENABLE_TTS environment variable dependency
- Detect TTS models dynamically by model name
- Use TTS-enhanced service only when needed
- Fallback to standard service if TTS processing fails
- Maintain full backward compatibility
2025-07-15 15:34:55 +09:00
zzh
c47f696691 Merge branch 'main' of https://github.com/zzh2632185/gemini-balance 2025-07-15 15:05:54 +09:00
zzh
9a8e4c8e15 Fix TTS payload - remove tools and safetySettings for TTS requests 2025-07-15 15:05:40 +09:00
zzh2632185
24aab9a658 Delete .augment-guidelines 2025-07-15 15:05:39 +09:00
zzh
afdaaffac5 Trigger Docker build - Add TTS functionality description 2025-07-15 14:46:15 +09:00
zzh
fe721116e2 add support for Gemini multi-speaker voice 2025-07-15 14:39:33 +09:00
cxyfer
8e0a834daa fix: Fix datetime.timezone AttributeError in file cleanup
- Change datetime.timezone.utc to timezone.utc in services.py
- Resolves error: 'type object datetime.datetime has no attribute timezone'
2025-07-12 08:48:46 +08:00
cxyfer
c9fca1561c Merge remote-tracking branch 'origin/main' into feature/upload-compatibility 2025-07-12 03:36:46 +08:00
cxyfer
5eb2dfd822 feat: Add Files API support with upload, list, get and delete operations
- Implement complete Files API compatible with Gemini API format
- Support resumable file uploads with chunked transfer (tested with 15MB video)
- Create file management service with database tracking
- Add file domain models and API request/response objects
- Implement file routes with proper authentication
- Use fixed API key for Files API requests (due to Google API restrictions)
- Support file state management (PROCESSING, ACTIVE, FAILED)
- Add scheduled task for automatic expired file cleanup
- Integrate seamlessly with existing key management and load balancing
2025-07-12 03:33:39 +08:00
snaily
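0b837c3f80 chore: bump version to 2.1.9 2025-07-10 21:33:54 +08:00
snaily
a6cfc12443 feat: update response handling logic to support reasoning content
- Modified the _handle_openai_stream_response and _handle_openai_normal_response methods in response_handler.py to support reasoning content (reasoning_content).
- Updated the return value of the _extract_result method so that reasoning content can be extracted.
- In gemini_chat_service.py and openai_chat_service.py, adjusted the generation config to include the option for the thinking process.
- In vertex_express_chat_service.py, improved handling of client-supplied thinking configuration so that the client's configuration takes priority.
2025-07-10 21:21:55 +08:00
snaily
f6d64dd850 feat: add TTS voice name constants and update TTS service logic
- Added a TTS_VOICE_NAMES list containing multiple voice names to constants.py.
- Updated the voice configuration logic in tts_service.py to use the voice name from the request when it is valid and otherwise fall back to the default configuration.
2025-07-10 01:03:20 +08:00
snaily
eed62caa78 refactor: remove the count_tokens abstract method from ApiClient
- Deleted the abstract definition of count_tokens from the ApiClient class to simplify the interface.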
2025-07-10 00:53:06 +08:00
ripper
204d41d6f3 feat: add JSON Schema cleaning function to remove unsupported fields in Gemini API 2025-07-09 10:29:42 +08:00
ripper
858df0548e fix: ensure generationConfig is not None in payload 2025-07-09 10:17:32 +08:00
snaily
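b3da021803 refactor: improve configuration parsing logic and strengthen support for generic types
- Introduced get_args and get_origin in config.py to better handle parsing of List and Dict types.
- Updated parsing of List[str] and List[Dict[str, str]], adding error handling and logging.
- In keys_status.js, replaced the filterValidKeys function with filterAndSearchValidKeys, keeping the old function to avoid breaking potential legacy calls.
- Added options in keys_status.html to support selecting more items.
2025-07-08 16:35:56 +08:00
snaily
d234f826f4 chore: rename Vertex API to Vertex Express API in comments and regular expressions for consistency and accuracy. Updated descriptions and prompts in multiple files to reflect the API name change. 2025-07-08 15:27:16 +08:00
snaily
231b69ecf8 feat: add custom headers feature
- Added a `CUSTOM_HEADERS` option to the configuration so users can define global request headers.
- Updated the API client to apply the custom headers to all outbound requests.
- Added full frontend editing support for `CUSTOM_HEADERS` on the configuration page.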
2025-07-08 13:58:05 +08:00
snaily
0a08913677 Merge pull request #183 from liucong2013/feature/count-tokens-compatibility 2025-07-07 17:24:45 +08:00
snaily
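49d32813ea chore: update the GitHub Actions workflow to generate release notes
- Changed the quote style of the version tag
- Added a step that generates release notes
- Updated the release creation step to include the release notes
- Adjusted step order and comments
2025-07-07 14:45:07 +08:00
snaily
c5d57e97b1 chore: bump version to 2.1.8 2025-07-07 14:21:41 +08:00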
lc631017672
da8f7539a1 Fix: Handle empty parts in CountTokensRequest and improve payload filtering 2025-07-07 14:13:16 +08:00
lc631017672
64a68f1176 refactor: Remove debug logging for security checks 2025-07-07 10:27:48 +08:00
lc631017672
1199d7cc3c feat: Add support for countTokens API and improve debug logging 2025-07-07 10:08:57 +08:00
ry
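8a827d2acb feat: support a custom upload folder path for CloudFlare image hosting
- Added the CLOUDFLARE_IMGBED_UPLOAD_FOLDER environment variable
- Users can set it to specify where images are uploaded within CloudFlare image hosting
2025-07-05 23:32:45 +08:00
snaily
0e8a943d7f chore: update the README and README_ZH files, adjusting the badge HTML structure so the badges are centered. 2025-07-05 16:49:57 +08:00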
snaily
4f62658440 Update README.md 2025-07-05 16:39:18 +08:00
snaily
6e7c3d5f6a Update README.md 2025-07-05 16:38:35 +08:00
snaily
d5062db9b6 Update README_ZH.md 2025-07-05 16:27:20 +08:00
snaily
a6ad006a49 Update README.md 2025-07-05 16:26:59 +08:00
snaily
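57d593fa17 chore: bump version to 2.1.7 2025-07-05 00:48:50 +08:00
snaily
f38b5ae870 feat: add TTS-related configuration and functionality
- Added TTS model, voice name, and speed options to .env.example
- Updated the README files with descriptions of the TTS settings
- Added TTS settings to the configuration class
- Added a TTS request model to support text-to-speech
- Updated the smart routing middleware to support audio requests
- Added an API endpoint in the routes for handling TTS requests
- Updated the frontend configuration editor to support the TTS options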
2025-07-05 00:47:55 +08:00
snaily
418b3ca13c Merge branch 'pr/BigLiao/172' 2025-07-03 23:44:02 +08:00
jesonliao
09bfa85e69 fix: fix role validation in the API
The official demo does not pass role
2025-07-03 23:08:31 +08:00
jesonliao
62b132208b fix: fix an issue when the database password contains special characters 2025-07-03 22:23:47 +08:00
snaily
fc28f4f74e Merge branch 'pr/chinrain/167' 2025-07-03 17:28:58 +08:00
snaily
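f79a52f839 fix: improve the smart routing middleware and strengthen URL handling
- Added support for new path patterns, including handling of `v1beta/models`
- Unified the log format to make debug output easier to read
- Normalized code style for consistency and maintainability
- Fixed model name extraction from the request body and query parameters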
2025-07-03 17:25:50 +08:00
chinrain
94d1041961 Merge branch 'feat/AutoRoute' of https://github.com/chinrain/gemini-balance into feat/AutoRoute 2025-07-03 03:05:39 +08:00
chinrain
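ada32d526a refactor: simplify the smart routing middleware and improve handling of mixed-format URLs
- Refactored the smart routing logic, keeping it as simple as possible while ensuring chat still works
- Only common mistakes are corrected; everything else is passed through (easier to maintain later, or no maintenance needed)
- Chat works normally for all the common mistakes
- Unified the frontend styles
2025-07-03 03:01:10 +08:00
snaily
ef1e38aba1 fix: add JSON parsing exception handling for the request body in the smart routing middleware to keep model extraction stable 2025-07-03 00:56:57 +08:00
snaily
60b2d59e25 fix: correct the Gemini path pattern by removing the trailing slash so path matching is consistent 2025-07-03 00:45:11 +08:00
chinrain
e18aa73456 add the gemini-prefixed model list 2025-07-02 23:52:03 +08:00
chinrain
24747a5f09 remove duplicate configuration 2025-07-02 23:41:48 +08:00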
chinrain
621dac22dc Merge remote-tracking branch 'origin/main' into feat/AutoRoute 2025-07-01 02:41:18 +08:00
chinrain
23d7004b60 - add vertex-express support
- remove the unnecessary method for detecting streaming requests
2025-07-01 02:25:32 +08:00
snaily
c3b3d34127 Merge branch 'pr/stevessr/160' 2025-06-30 23:54:42 +08:00
chchchchc1023
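18a166afb0 feat: add a smart routing middleware with automatic API path normalization
- Added the SmartRoutingMiddleware smart routing middleware
- Supports automatic detection and conversion of OpenAI/HF/Gemini/default formats
- Repairs malformed URL paths, improving API compatibility
- Added the URL_NORMALIZATION_ENABLED configuration switch, off by default
- Smart routing is off by default and must be enabled manually
2025-06-30 22:58:58 +08:00
stevessr
a41447a96d fix: update the thinkingBudget maximum to 32767 and the minimum to -1 2025-06-30 20:43:27 +08:00
Wangnov
df8d543539 remove the line breaks introduced by ruff formatting 2025-06-30 17:52:10 +08:00
Wangnov
5ecce8e0fe fix: use Union instead of the pipe operator in type annotations so Python 3.9 does not raise errors 2025-06-30 17:37:02 +08:00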
snaily
00f423a622 Update README.md 2025-06-28 00:00:22 +08:00
snaily
05ce04de69 Update README.md 2025-06-27 23:49:05 +08:00
47 changed files with 4234 additions and 575 deletions
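
Two of the commits above (a41447a96d and a6fe5a7022) describe the accepted `thinkingBudget` range: a minimum of -1, which stands for an automatic budget, and a maximum of 32767. A minimal sketch of that bound check follows. The function and constant names are hypothetical, and the commit messages do not say whether out-of-range values are clamped or rejected; this sketch assumes clamping.

```python
# Bounds taken from the commit messages; all names here are illustrative only.
MIN_THINKING_BUDGET = -1      # -1 requests an automatic budget
MAX_THINKING_BUDGET = 32767


def clamp_thinking_budget(value: int) -> int:
    """Clamp a requested thinkingBudget into the documented range (assumed behavior)."""
    return max(MIN_THINKING_BUDGET, min(MAX_THINKING_BUDGET, value))
```

A caller could pass the result straight into the generation config, so that a value such as 40000 is reduced to the maximum instead of being sent upstream as-is.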

View File

@@ -20,6 +20,9 @@ THINKING_BUDGET_MAP={"gemini-2.5-flash-preview-04-17": 4000}
IMAGE_MODELS=["gemini-2.0-flash-exp"]
SEARCH_MODELS=["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]
FILTERED_MODELS=["gemini-1.0-pro-vision-latest", "gemini-pro-vision", "chat-bison-001", "text-bison-001", "embedding-gecko-001"]
# Whether to enable URL context (enabled by default)
URL_CONTEXT_ENABLED=true
URL_CONTEXT_MODELS=["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
TOOLS_CODE_EXECUTION_ENABLED=false
SHOW_SEARCH_LINK=true
SHOW_THINKING_PROCESS=true
@@ -43,6 +46,7 @@ SMMS_SECRET_TOKEN=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
PICGO_API_KEY=xxxx
CLOUDFLARE_IMGBED_URL=https://xxxxxxx.pages.dev/upload
CLOUDFLARE_IMGBED_AUTH_CODE=xxxxxxxxx
CLOUDFLARE_IMGBED_UPLOAD_FOLDER=
##########################################################################
#########################stream_optimizer settings########################
STREAM_OPTIMIZER_ENABLED=false
@@ -74,3 +78,16 @@ FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS=5
# Safety settings (JSON string format)
# Note: the example values here may need adjusting to match what the model actually supports
SAFETY_SETTINGS=[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]
URL_NORMALIZATION_ENABLED=false
# TTS settings
TTS_MODEL=gemini-2.5-flash-preview-tts
TTS_VOICE_NAME=Zephyr
TTS_SPEED=normal
#########################Files API settings########################
# Whether to enable automatic cleanup of expired files
FILES_CLEANUP_ENABLED=true
# Interval between expired-file cleanups (hours)
FILES_CLEANUP_INTERVAL_HOURS=1
# Whether to enable per-user file isolation (each user can see only the files they uploaded)
FILES_USER_ISOLATION_ENABLED=true
##########################################################################
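
The `URL_CONTEXT_ENABLED` and `URL_CONTEXT_MODELS` settings above gate the `urlContext` tool per model. The sketch below shows one way such a gate could be read from the environment. It is not the project's implementation: the function name `url_context_tools` is hypothetical, the fallback values follow the defaults listed in the README table (`false` and `[]`), and the `{"urlContext": {}}` payload shape is an assumption based on the tool name in the commit message.

```python
import json
import os
from typing import Dict, List, Mapping


def url_context_tools(model: str, env: Mapping[str, str] = os.environ) -> List[Dict]:
    """Return the tool entries to attach for `model` (hypothetical helper)."""
    # Boolean settings are written as true/false in the .env file.
    enabled = env.get("URL_CONTEXT_ENABLED", "false").strip().lower() == "true"
    # List settings are written as JSON arrays in the .env file.
    models = json.loads(env.get("URL_CONTEXT_MODELS", "[]"))
    if enabled and model in models:
        return [{"urlContext": {}}]  # assumed payload shape
    return []
```

Passing a plain dictionary as `env` makes the gate easy to exercise without touching the real process environment.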

View File

@@ -3,7 +3,7 @@ name: Publish Release
on:
push:
tags:
- 'v*' # Triggered when a tag starting with "v" is pushed (e.g. v1.0.0, v2.1.0)
- "v*" # Triggered when a tag starting with "v" is pushed (e.g. v1.0.0, v2.1.0)
jobs:
update-release-draft:
@@ -15,8 +15,17 @@ jobs:
# Step 1: Check out the repository
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
# Step 2: Automatically create the release
# Step 2: Automatically generate release notes
- name: Generate release notes
id: changelog
uses: mikepenz/release-changelog-builder-action@v4
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Step 3: Automatically create the release
- name: Create Release
id: create_release
uses: actions/create-release@v1
@@ -25,15 +34,16 @@ jobs:
with:
tag_name: ${{ github.ref_name }}
release_name: ${{ github.ref_name }}
body: ${{ steps.changelog.outputs.changelog }}
draft: false
prerelease: false
# Step 3: Optional, build a zip file
# Step 4: Optional, build a zip file
- name: Create ZIP file
run: |
zip -r gemini-balance.zip . -x "*.git*" "*.github*" "*.env*" "logs/*" "tests/*"
# Step 4: Optional, upload the build artifact
# Step 5: Optional, upload the build artifact
- name: Upload Release Asset
uses: actions/upload-release-asset@v1
env:
@@ -41,5 +51,5 @@ jobs:
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: ./gemini-balance.zip # Replace with the path to your build artifact
asset_name: gemini-balance.zip # Replace with your file name
asset_name: gemini-balance.zip # Replace with your file name
asset_content_type: application/zip

View File

@@ -8,12 +8,6 @@ COPY ./VERSION /app
RUN pip install --no-cache-dir -r requirements.txt
COPY ./app /app/app
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
# Expose port
EXPOSE 8000

136
README.md
View File

@@ -2,6 +2,12 @@
# Gemini Balance - Gemini API Proxy and Load Balancer
<p align="center">
<a href="https://trendshift.io/repositories/13692" target="_blank">
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
</a>
</p>
> ⚠️ This project is licensed under the CC BY-NC 4.0 (Attribution-NonCommercial) license. Any form of commercial resale service is prohibited. See the LICENSE file for details.
> I have never sold this service on any platform. If you encounter someone selling this service, they are definitely a reseller. Please be careful not to be deceived.
@@ -11,7 +17,7 @@
[![Uvicorn](https://img.shields.io/badge/Uvicorn-running-purple.svg)](https://www.uvicorn.org/)
[![Telegram Group](https://img.shields.io/badge/Telegram-Group-blue.svg?logo=telegram)](https://t.me/+soaHax5lyI0wZDVl)
> Telegram Group: https://t.me/+soaHax5lyI0wZDVl
> Telegram Group: <https://t.me/+soaHax5lyI0wZDVl>
## Project Introduction
@@ -40,39 +46,39 @@ app/
## ✨ Feature Highlights
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling, improving availability and concurrency.
* **Visual Configuration Takes Effect Immediately**: Configurations modified through the admin backend take effect without restarting the service. Remember to click save for changes to apply.
* **Multi-Key Load Balancing**: Supports configuring multiple Gemini API Keys (`API_KEYS`) for automatic sequential polling, improving availability and concurrency.
* **Visual Configuration Takes Effect Immediately**: Configurations modified through the admin backend take effect without restarting the service. Remember to click save for changes to apply.
![Configuration Panel](files/image4.png)
* **Dual Protocol API Compatibility**: Supports forwarding CHAT API requests in both Gemini and OpenAI formats.
* **Dual Protocol API Compatibility**: Supports forwarding CHAT API requests in both Gemini and OpenAI formats.
```plaintext
openai baseurl `http://localhost:8000(/hf)/v1`
gemini baseurl `http://localhost:8000(/gemini)/v1beta`
```
* **Supports Image-Text Chat and Image Modification**: `IMAGE_MODELS` configures which models can perform image-text chat and image editing. When actually calling, use the `configured_model-image` model name to use this feature.
* **Supports Image-Text Chat and Image Modification**: `IMAGE_MODELS` configures which models can perform image-text chat and image editing. When actually calling, use the `configured_model-image` model name to use this feature.
![Chat with Image Generation](files/image6.png)
![Modify Image](files/image7.png)
* **Supports Web Search**: Supports web search. `SEARCH_MODELS` configures which models can perform web searches. When actually calling, use the `configured_model-search` model name to use this feature.
* **Supports Web Search**: Supports web search. `SEARCH_MODELS` configures which models can perform web searches. When actually calling, use the `configured_model-search` model name to use this feature.
![Web Search](files/image8.png)
* **Key Status Monitoring**: Provides a `/keys_status` page (requires authentication) to view the status and usage of each Key in real-time.
* **Key Status Monitoring**: Provides a `/keys_status` page (requires authentication) to view the status and usage of each Key in real-time.
![Monitoring Panel](files/image.png)
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
* **Detailed Logging**: Provides detailed error logs for easy troubleshooting.
![Call Details](files/image1.png)
![Log List](files/image2.png)
![Log Details](files/image3.png)
* **Support for Custom Gemini Proxy**: Supports custom Gemini proxies, such as those built on Deno or Cloudflare.
* **OpenAI Image Generation API Compatibility**: Adapts the `imagen-3.0-generate-002` model interface to be compatible with the OpenAI image generation API, supporting client calls.
* **Flexible Key Addition**: Flexible way to add keys using regex matching for `gemini_key`, with key deduplication.
* **Support for Custom Gemini Proxy**: Supports custom Gemini proxies, such as those built on Deno or Cloudflare.
* **OpenAI Image Generation API Compatibility**: Adapts the `imagen-3.0-generate-002` model interface to be compatible with the OpenAI image generation API, supporting client calls.
* **Flexible Key Addition**: Flexible way to add keys using regex matching for `gemini_key`, with key deduplication.
![Add Key](files/image5.png)
* **OpenAI Format Embeddings API Compatibility**: Perfectly adapts to the OpenAI format `embeddings` interface, usable for local document vectorization.
* **Streamlined Response Optimization**: Optional stream output optimizer (`STREAM_OPTIMIZER_ENABLED`) to improve the experience of long-text stream responses.
* **Failure Retry and Key Management**: Automatically handles API request failures, retries (`MAX_RETRIES`), automatically disables Keys after too many failures (`MAX_FAILURES`), and periodically checks for recovery (`CHECK_INTERVAL_HOURS`).
* **Docker Support**: Supports AMD and ARM architecture Docker deployments. You can also build your own Docker image.
* **OpenAI Format Embeddings API Compatibility**: Perfectly adapts to the OpenAI format `embeddings` interface, usable for local document vectorization.
* **Streamlined Response Optimization**: Optional stream output optimizer (`STREAM_OPTIMIZER_ENABLED`) to improve the experience of long-text stream responses.
* **Failure Retry and Key Management**: Automatically handles API request failures, retries (`MAX_RETRIES`), automatically disables Keys after too many failures (`MAX_FAILURES`), and periodically checks for recovery (`CHECK_INTERVAL_HOURS`).
* **Docker Support**: Supports AMD and ARM architecture Docker deployments. You can also build your own Docker image.
> Image address: docker pull ghcr.io/snailyp/gemini-balance:latest
* **Automatic Model List Maintenance**: Supports fetching OpenAI and Gemini model lists, perfectly compatible with NewAPI's automatic model list fetching, no manual entry required.
* **Support for Removing Unused Models**: Too many default models are provided, many of which are not used. You can filter them out using `FILTERED_MODELS`.
* **Proxy Support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, convenient for use in special network environments. Supports batch adding proxies.
* **Automatic Model List Maintenance**: Supports fetching OpenAI and Gemini model lists, perfectly compatible with NewAPI's automatic model list fetching, no manual entry required.
* **Support for Removing Unused Models**: Too many default models are provided, many of which are not used. You can filter them out using `FILTERED_MODELS`.
* **Proxy Support**: Supports configuring HTTP/SOCKS5 proxy servers (`PROXIES`) for accessing the Gemini API, convenient for use in special network environments. Supports batch adding proxies.
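
The multi-key load balancing and failure handling described in the feature list (sequential polling over `API_KEYS`, disabling a key after `MAX_FAILURES` failures) can be sketched as follows. This is an illustration under stated assumptions, not the project's key manager: the `KeyRotator` class and its method names are hypothetical, and the periodic recovery check (`CHECK_INTERVAL_HOURS`) is left out.

```python
from itertools import cycle
from typing import Dict, Iterable, List


class KeyRotator:
    """Round-robin key selection that skips keys with too many failures."""

    def __init__(self, keys: Iterable[str], max_failures: int = 3) -> None:
        self._keys: List[str] = list(keys)
        self._cycle = cycle(self._keys)
        self._failures: Dict[str, int] = {key: 0 for key in self._keys}
        self._max_failures = max_failures

    def next_key(self) -> str:
        # Look at each key at most once per call so a fully disabled pool cannot loop forever.
        for _ in range(len(self._keys)):
            key = next(self._cycle)
            if self._failures[key] < self._max_failures:
                return key
        raise RuntimeError("no usable API keys")

    def report_failure(self, key: str) -> None:
        self._failures[key] += 1

    def report_success(self, key: str) -> None:
        # Mirrors the commit that resets the failure count when key validation succeeds.
        self._failures[key] = 0
```

A request handler would call `next_key()` before each upstream request and then report the outcome, so a key that keeps failing drops out of rotation until it is marked healthy again.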
## 🚀 Quick Start
@@ -80,79 +86,83 @@ app/
#### a) Build with Dockerfile
1. **Build Image**:
1. **Build Image**:
```bash
docker build -t gemini-balance .
```
2. **Run Container**:
2. **Run Container**:
```bash
docker run -d -p 8000:8000 --env-file .env gemini-balance
```
* `-d`: Run in detached mode.
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host.
* `--env-file .env`: Use the `.env` file to set environment variables.
* `-d`: Run in detached mode.
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host.
* `--env-file .env`: Use the `.env` file to set environment variables.
> Note: If using an SQLite database, you need to mount a data volume to persist
> Note: If using an SQLite database, you need to mount a data volume to persist
>
> ```bash
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data gemini-balance
> ```
>
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
#### b) Deploy with an Existing Docker Image
1. **Pull Image**:
1. **Pull Image**:
```bash
docker pull ghcr.io/snailyp/gemini-balance:latest
```
2. **Run Container**:
2. **Run Container**:
```bash
docker run -d -p 8000:8000 --env-file .env ghcr.io/snailyp/gemini-balance:latest
```
* `-d`: Run in detached mode.
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host (adjust as needed).
* `--env-file .env`: Use the `.env` file to set environment variables (ensure the `.env` file exists in the directory where the command is executed).
* `-d`: Run in detached mode.
* `-p 8000:8000`: Map port 8000 of the container to port 8000 of the host (adjust as needed).
* `--env-file .env`: Use the `.env` file to set environment variables (ensure the `.env` file exists in the directory where the command is executed).
> Note: If using an SQLite database, you need to mount a data volume to persist
> Note: If using an SQLite database, you need to mount a data volume to persist
>
> ```bash
> docker run -d -p 8000:8000 --env-file .env -v /path/to/data:/app/data ghcr.io/snailyp/gemini-balance:latest
> ```
>
> Where `/path/to/data` is the data storage path on the host, and `/app/data` is the data directory inside the container.
### Run Locally (Suitable for Development and Testing)
If you want to run the source code directly locally for development or testing, follow these steps:
1. **Ensure Prerequisites are Met**:
* Clone the repository locally.
* Install Python 3.9 or higher.
* Create and configure the `.env` file in the project root directory (refer to the "Configure Environment Variables" section above).
* Install project dependencies:
1. **Ensure Prerequisites are Met**:
* Clone the repository locally.
* Install Python 3.9 or higher.
* Create and configure the `.env` file in the project root directory (refer to the "Configure Environment Variables" section above).
* Install project dependencies:
```bash
pip install -r requirements.txt
```
2. **Start Application**:
2. **Start Application**:
Run the following command in the project root directory:
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```
* `app.main:app`: Specifies the location of the FastAPI application instance (the `app` object in the `main.py` file within the `app` module).
* `--host 0.0.0.0`: Makes the application accessible from any IP address on the local network.
* `--port 8000`: Specifies the port number the application listens on (you can change this as needed).
* `--reload`: Enables automatic reloading. When you modify the code, the service will automatically restart, which is very suitable for development environments (remove this option in production environments).
* `app.main:app`: Specifies the location of the FastAPI application instance (the `app` object in the `main.py` file within the `app` module).
* `--host 0.0.0.0`: Makes the application accessible from any IP address on the local network.
* `--port 8000`: Specifies the port number the application listens on (you can change this as needed).
* `--reload`: Enables automatic reloading. When you modify the code, the service will automatically restart, which is very suitable for development environments (remove this option in production environments).
3. **Access Application**:
3. **Access Application**:
After the application starts, you can access `http://localhost:8000` (or the host and port you specified) through a browser or API tool.
### Complete Configuration List
@@ -181,6 +191,9 @@ If you want to run the source code directly locally for development or testing,
| `SHOW_THINKING_PROCESS` | Optional, whether to display the model's thinking process | `true` |
| `THINKING_MODELS` | Optional, list of models that support thinking functions | `[]` |
| `THINKING_BUDGET_MAP` | Optional, thinking function budget mapping (model_name:budget_value) | `{}` |
| `URL_NORMALIZATION_ENABLED` | Optional, whether to enable intelligent URL routing mapping | `false` |
| `URL_CONTEXT_ENABLED` | Optional, whether to enable URL context understanding | `false` |
| `URL_CONTEXT_MODELS` | Optional, list of models that support URL context understanding | `[]` |
| `BASE_URL` | Optional, Gemini API base URL, no modification needed by default | `https://generativelanguage.googleapis.com/v1beta` |
| `MAX_FAILURES` | Optional, number of times a single key is allowed to fail | `3` |
| `MAX_RETRIES` | Optional, maximum number of retries for failed API requests | `3` |
@@ -194,6 +207,10 @@ If you want to run the source code directly locally for development or testing,
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Optional, whether to enable automatic deletion of request logs | `false` |
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | Optional, automatically delete request logs older than this many days (e.g., 1, 7, 30) | `30` |
| `SAFETY_SETTINGS` | Optional, safety settings (JSON string format), used to configure content safety thresholds. Example values may need adjustment based on actual model support. | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
| **TTS Related** | | |
| `TTS_MODEL` | Optional, TTS model name | `gemini-2.5-flash-preview-tts` |
| `TTS_VOICE_NAME` | Optional, TTS voice name | `Zephyr` |
| `TTS_SPEED` | Optional, TTS speed | `normal` |
| **Image Generation Related** | | |
| `PAID_KEY` | Optional, paid API Key for advanced features like image generation | `your-paid-api-key` |
| `CREATE_IMAGE_MODEL` | Optional, image generation model | `imagen-3.0-generate-002` |
@@ -202,6 +219,7 @@ If you want to run the source code directly locally for development or testing,
| `PICGO_API_KEY` | Optional, API Key for [PicGo](https://www.picgo.net/) image hosting | `your-picogo-apikey` |
| `CLOUDFLARE_IMGBED_URL` | Optional, [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) image hosting upload address | `https://xxxxxxx.pages.dev/upload` |
| `CLOUDFLARE_IMGBED_AUTH_CODE` | Optional, authentication key for CloudFlare image hosting | `your-cloudflare-imgber-auth-code` |
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER` | Optional, upload folder path for CloudFlare image hosting | `""` |
| **Stream Optimizer Related** | | |
| `STREAM_OPTIMIZER_ENABLED` | Optional, whether to enable stream output optimization | `false` |
| `STREAM_MIN_DELAY` | Optional, minimum delay for stream output | `0.016` |
@@ -219,20 +237,20 @@ The following are the main API endpoints provided by the service:
### Gemini API Related (`(/gemini)/v1beta`)
* `GET /models`: List available Gemini models.
* `POST /models/{model_name}:generateContent`: Generate content using the specified Gemini model.
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation using the specified Gemini model.
* `GET /models`: List available Gemini models.
* `POST /models/{model_name}:generateContent`: Generate content using the specified Gemini model.
* `POST /models/{model_name}:streamGenerateContent`: Stream content generation using the specified Gemini model.
### OpenAI API Related
* `GET (/hf)/v1/models`: List available models (uses Gemini format underneath).
* `POST (/hf)/v1/chat/completions`: Perform chat completion (uses Gemini format underneath, supports streaming).
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses Gemini format underneath).
* `POST (/hf)/v1/images/generations`: Generate images (uses Gemini format underneath).
* `GET /openai/v1/models`: List available models (uses OpenAI format underneath).
* `POST /openai/v1/chat/completions`: Perform chat completion (uses OpenAI format underneath, supports streaming, can prevent truncation, and is faster).
* `POST /openai/v1/embeddings`: Create text embeddings (uses OpenAI format underneath).
* `POST /openai/v1/images/generations`: Generate images (uses OpenAI format underneath).
* `GET (/hf)/v1/models`: List available models (uses Gemini format underneath).
* `POST (/hf)/v1/chat/completions`: Perform chat completion (uses Gemini format underneath, supports streaming).
* `POST (/hf)/v1/embeddings`: Create text embeddings (uses Gemini format underneath).
* `POST (/hf)/v1/images/generations`: Generate images (uses Gemini format underneath).
* `GET /openai/v1/models`: List available models (uses OpenAI format underneath).
* `POST /openai/v1/chat/completions`: Perform chat completion (uses OpenAI format underneath, supports streaming, can prevent truncation, and is faster).
* `POST /openai/v1/embeddings`: Create text embeddings (uses OpenAI format underneath).
* `POST /openai/v1/images/generations`: Generate images (uses OpenAI format underneath).
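
As an illustration of the OpenAI-compatible endpoints listed above, the sketch below builds a `POST (/hf)/v1/chat/completions` request using only the standard library. The helper name `build_chat_request` is hypothetical, and sending the access token as a `Bearer` header is an assumption based on the OpenAI convention, not something this document confirms; the base URL comes from the README's `http://localhost:8000(/hf)/v1` example.

```python
import json
import urllib.request


def build_chat_request(base_url: str, token: str, model: str, prompt: str) -> urllib.request.Request:
    """Build (but do not send) an OpenAI-format chat completion request."""
    body = json.dumps(
        {"model": model, "messages": [{"role": "user", "content": prompt}]}
    ).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/chat/completions",
        data=body,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",  # assumed auth scheme
        },
        method="POST",
    )
```

Passing the returned object to `urllib.request.urlopen` would send it to a running instance; swapping the base URL for `http://localhost:8000/openai/v1` would target the variant that uses the OpenAI format underneath.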
## 🤝 Contributing
@@ -242,9 +260,9 @@ Pull Requests or Issues are welcome.
Special thanks to the following projects and platforms for providing image hosting services for this project:
* [PicGo](https://www.picgo.net/)
* [SM.MS](https://smms.app/)
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) open source project
* [PicGo](https://www.picgo.net/)
* [SM.MS](https://smms.app/)
* [CloudFlare-ImgBed](https://github.com/MarSeventh/CloudFlare-ImgBed) open source project
## 🙏 Thanks to Contributors
@@ -254,11 +272,11 @@ Thanks to all developers who contributed to this project!
## Thanks to Our Supporters
We extend our heartfelt gratitude to the following supporters for their invaluable contributions to this project:
A special shout-out to DigitalOcean for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
[![DigitalOcean Logo](files/dataocean.svg)](https://m.do.co/c/b249dd7f3b4c)
A special shout-out to DigitalOcean for providing the rock-solid and dependable cloud infrastructure that keeps this project humming!
CDN acceleration and security protection for this project are sponsored by Tencent EdgeOne.
[![EdgeOne Logo](https://edgeone.ai/media/34fe3a45-492d-4ea4-ae5d-ea1087ca7b4b.png)](https://edgeone.ai/?from=github)
## ⭐ Star History
@@ -266,7 +284,7 @@ A special shout-out to DigitalOcean for providing the rock-solid and dependable
## 💖 Friendly Projects
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine: AI-driven hot event timeline generation tool
* **[OneLine](https://github.com/chengtx809/OneLine)** by [chengtx809](https://github.com/chengtx809) - OneLine: AI-driven hot event timeline generation tool
## 🎁 Project Support


@@ -1,5 +1,11 @@
# Gemini Balance - Gemini API Proxy and Load Balancer
<p align="center">
<a href="https://trendshift.io/repositories/13692" target="_blank">
<img src="https://trendshift.io/api/badge/repositories/13692" alt="snailyp%2Fgemini-balance | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
</a>
</p>
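> ⚠️ This project is licensed under CC BY-NC 4.0 (Attribution-NonCommercial). Commercial resale of the service in any form is prohibited; see the LICENSE file for details.
> I have never sold this service on any platform. Anyone you find selling it is a reseller, so please do not be taken in.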
@@ -178,6 +184,9 @@ app/
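| `SHOW_THINKING_PROCESS` | Optional. Whether to show the model's thinking process | `true` |
| `THINKING_MODELS` | Optional. List of models that support thinking | `[]` |
| `THINKING_BUDGET_MAP` | Optional. Thinking budget map (model name: budget value) | `{}` |
| `URL_NORMALIZATION_ENABLED` | Optional. Whether to enable smart route mapping | `false` |
| `URL_CONTEXT_ENABLED` | Optional. Whether to enable URL context understanding | `false` |
| `URL_CONTEXT_MODELS` | Optional. List of models that support URL context understanding | `[]` |
| `BASE_URL` | Optional. Gemini API base URL; normally does not need to be changed | `https://generativelanguage.googleapis.com/v1beta` |
| `MAX_FAILURES` | Optional. Number of failures allowed for a single key | `3` |
| `MAX_RETRIES` | Optional. Maximum number of retries when an API request fails | `3` |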
@@ -191,6 +200,10 @@ app/
| `AUTO_DELETE_REQUEST_LOGS_ENABLED`| Optional. Whether to delete request logs automatically | `false` |
| `AUTO_DELETE_REQUEST_LOGS_DAYS` | Optional. Delete request logs older than this many days (e.g. 1, 7, 30) | `30` |
| `SAFETY_SETTINGS` | Optional. Safety settings (JSON string) used to configure content safety thresholds. The example value may need adjusting depending on what the model actually supports. | `[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}]` |
| **TTS** | | |
| `TTS_MODEL` | Optional. TTS model name | `gemini-2.5-flash-preview-tts` |
| `TTS_VOICE_NAME` | Optional. TTS voice name | `Zephyr` |
| `TTS_SPEED` | Optional. TTS speaking speed | `normal` |
| **Image generation** | | |
| `PAID_KEY` | Optional. Paid API key, used for advanced features such as image generation | `your-paid-api-key` |
| `CREATE_IMAGE_MODEL` | Optional. Image generation model | `imagen-3.0-generate-002` |
@@ -199,6 +212,7 @@ app/
| `PICGO_API_KEY` | Optional. API key for the [PicGo](https://www.picgo.net/) image host | `your-picogo-apikey` |
| `CLOUDFLARE_IMGBED_URL` | Optional. Upload URL for the [CloudFlare](https://github.com/MarSeventh/CloudFlare-ImgBed) image host | `https://xxxxxxx.pages.dev/upload` |
| `CLOUDFLARE_IMGBED_AUTH_CODE`| Optional. Auth key for the CloudFlare image host | `your-cloudflare-imgber-auth-code` |
| `CLOUDFLARE_IMGBED_UPLOAD_FOLDER`| Optional. Upload folder path for the CloudFlare image host | `""` |
| **Stream optimizer** | | |
| `STREAM_OPTIMIZER_ENABLED` | Optional. Whether to enable streaming output optimization | `false` |
| `STREAM_MIN_DELAY` | Optional. Minimum delay for streaming output | `0.016` |


@@ -1 +1 @@
2.1.6
2.2.1


@@ -4,7 +4,7 @@
import datetime
import json
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Type, get_args, get_origin
from pydantic import ValidationError, ValidationInfo, field_validator
from pydantic_settings import BaseSettings
@@ -51,8 +51,8 @@ class Settings(BaseSettings):
return v
# API-related settings
API_KEYS: List[str]
ALLOWED_TOKENS: List[str]
API_KEYS: List[str]=[]
ALLOWED_TOKENS: List[str]=[]
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
AUTH_TOKEN: str = ""
MAX_FAILURES: int = 3
@@ -63,17 +63,31 @@ class Settings(BaseSettings):
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # Whether to select proxies via consistent hashing
VERTEX_API_KEYS: List[str] = []
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
# Smart routing settings
URL_NORMALIZATION_ENABLED: bool = False # Whether to enable smart route mapping
# Custom headers
CUSTOM_HEADERS: Dict[str, str] = {}
# Model-related settings
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
TOOLS_CODE_EXECUTION_ENABLED: bool = False
# Whether to enable URL context
URL_CONTEXT_ENABLED: bool = True
URL_CONTEXT_MODELS: List[str] = ["gemini-2.5-pro","gemini-2.5-flash","gemini-2.5-flash-lite","gemini-2.0-flash","gemini-2.0-flash-live-001"]
SHOW_SEARCH_LINK: bool = True
SHOW_THINKING_PROCESS: bool = True
THINKING_MODELS: List[str] = []
THINKING_BUDGET_MAP: Dict[str, float] = {}
# TTS-related settings
TTS_MODEL: str = "gemini-2.5-flash-preview-tts"
TTS_VOICE_NAME: str = "Zephyr"
TTS_SPEED: str = "normal"
# Image generation settings
PAID_KEY: str = ""
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
@@ -82,6 +96,7 @@ class Settings(BaseSettings):
PICGO_API_KEY: str = ""
CLOUDFLARE_IMGBED_URL: str = ""
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
CLOUDFLARE_IMGBED_UPLOAD_FOLDER: str = ""
# Stream output optimizer settings
STREAM_OPTIMIZER_ENABLED: bool = False
@@ -111,6 +126,11 @@ class Settings(BaseSettings):
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
# Files API
FILES_CLEANUP_ENABLED: bool = True
FILES_CLEANUP_INTERVAL_HOURS: int = 1
FILES_USER_ISOLATION_ENABLED: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Set a default AUTH_TOKEN (if none was provided)
@@ -128,86 +148,106 @@ def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
logger = get_config_logger()
try:
# 处理 List[str]
if target_type == List[str]:
try:
parsed = json.loads(db_value)
if isinstance(parsed, list):
return [str(item) for item in parsed]
except json.JSONDecodeError:
origin_type = get_origin(target_type)
args = get_args(target_type)
# 处理 List 类型
if origin_type is list:
# 处理 List[str]
if args and args[0] == str:
try:
parsed = json.loads(db_value)
if isinstance(parsed, list):
return [str(item) for item in parsed]
except json.JSONDecodeError:
return [item.strip() for item in db_value.split(",") if item.strip()]
logger.warning(
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
)
return [item.strip() for item in db_value.split(",") if item.strip()]
logger.warning(
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
)
return [item.strip() for item in db_value.split(",") if item.strip()]
# 处理 Dict[str, float]
elif target_type == Dict[str, float]:
parsed_dict = {}
try:
parsed = json.loads(db_value)
if isinstance(parsed, dict):
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
else:
logger.warning(
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
)
except (json.JSONDecodeError, ValueError, TypeError) as e1:
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
logger.warning(
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
)
try:
corrected_db_value = db_value.replace("'", '"')
parsed = json.loads(corrected_db_value)
if isinstance(parsed, dict):
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
# 处理 List[Dict[str, str]]
elif args and get_origin(args[0]) is dict:
try:
parsed = json.loads(db_value)
if isinstance(parsed, list):
valid = all(
isinstance(item, dict)
and all(isinstance(k, str) for k in item.keys())
and all(isinstance(v, str) for v in item.values())
for item in parsed
)
if valid:
return parsed
else:
logger.warning(
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
)
except (json.JSONDecodeError, ValueError, TypeError) as e2:
logger.error(
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
)
else:
logger.error(
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
)
return parsed_dict
# 处理 List[Dict[str, str]]
elif target_type == List[Dict[str, str]]:
try:
parsed = json.loads(db_value)
if isinstance(parsed, list):
# 验证列表中的每个元素是否为字典,并且键和值都是字符串
valid = all(
isinstance(item, dict)
and all(isinstance(k, str) for k in item.keys())
and all(isinstance(v, str) for v in item.values())
for item in parsed
)
if valid:
return parsed
return []
else:
logger.warning(
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
)
return []
else:
logger.warning(
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
except json.JSONDecodeError:
logger.error(
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
)
return []
except json.JSONDecodeError:
logger.error(
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
)
return []
except Exception as e:
logger.error(
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
)
return []
except Exception as e:
logger.error(
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
)
return []
# 处理 Dict 类型
elif origin_type is dict:
# 处理 Dict[str, str]
if args and args == (str, str):
parsed_dict = {}
try:
parsed = json.loads(db_value)
if isinstance(parsed, dict):
parsed_dict = {str(k): str(v) for k, v in parsed.items()}
else:
logger.warning(
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
)
except json.JSONDecodeError:
logger.error(f"Could not parse '{db_value}' as Dict[str, str] for key '{key}'. Returning empty dict.")
return parsed_dict
# 处理 Dict[str, float]
elif args and args == (str, float):
parsed_dict = {}
try:
parsed = json.loads(db_value)
if isinstance(parsed, dict):
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
else:
logger.warning(
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
)
except (json.JSONDecodeError, ValueError, TypeError) as e1:
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
logger.warning(
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
)
try:
corrected_db_value = db_value.replace("'", '"')
parsed = json.loads(corrected_db_value)
if isinstance(parsed, dict):
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
else:
logger.warning(
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
)
except (json.JSONDecodeError, ValueError, TypeError) as e2:
logger.error(
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
)
else:
logger.error(
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
)
return parsed_dict
# 处理 bool
elif target_type == bool:
return db_value.lower() in ("true", "1", "yes", "on")
@@ -296,18 +336,12 @@ async def sync_initial_settings():
if parsed_db_value != memory_value:
# Check that the types match, in case the parser returned an incompatible type
type_match = False
if target_type == List[str] and isinstance(
parsed_db_value, list
):
type_match = True
elif target_type == Dict[str, float] and isinstance(
parsed_db_value, dict
):
type_match = True
elif target_type not in (
List[str],
Dict[str, float],
) and isinstance(parsed_db_value, target_type):
origin_type = get_origin(target_type)
if origin_type: # It's a generic type
if isinstance(parsed_db_value, origin_type):
type_match = True
# It's a non-generic type, or a specific generic we want to handle
elif isinstance(parsed_db_value, target_type):
type_match = True
if type_match:
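
The move from `target_type == List[str]` comparisons to `get_origin` / `get_args` dispatch can be sketched in isolation. The following is a simplified illustration, not the project's `_parse_db_value`: the function name `parse_value` is hypothetical, every list is treated as `List[str]`, and logging is omitted.

```python
import json
from typing import Any, Dict, List, Type, get_args, get_origin


def parse_value(raw: str, target_type: Type) -> Any:
    """Parse a stored string into target_type, dispatching on the generic origin."""
    origin = get_origin(target_type)   # list for List[str], dict for Dict[...], None for int/bool
    args = get_args(target_type)       # (str,) for List[str], (str, float) for Dict[str, float]

    if origin is list:
        # Simplification: treat every list as List[str].
        try:
            parsed = json.loads(raw)
            if isinstance(parsed, list):
                return [str(item) for item in parsed]
        except json.JSONDecodeError:
            pass
        # Fall back to a comma-separated value.
        return [item.strip() for item in raw.split(",") if item.strip()]

    if origin is dict:
        value_type = args[1] if len(args) == 2 else str
        try:
            parsed = json.loads(raw)
            if isinstance(parsed, dict):
                return {str(k): value_type(v) for k, v in parsed.items()}
        except (json.JSONDecodeError, ValueError, TypeError):
            pass
        return {}

    if target_type is bool:
        # bool("false") would be True, so booleans need explicit handling.
        return raw.lower() in ("true", "1", "yes", "on")

    return target_type(raw)
```

Dispatching on the origin means a newly added `Dict[str, str]` or `List[Dict[str, str]]` setting no longer needs its own equality check against a fully parameterized type.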


@@ -15,12 +15,12 @@ DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 40
DEFAULT_FILTER_MODELS = [
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001"
]
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"chat-bison-001",
"text-bison-001",
"embedding-gecko-001",
]
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
# Image generation constants
@@ -38,14 +38,14 @@ DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
DEFAULT_STREAM_CHUNK_SIZE = 5
# Regular expression patterns
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
IMAGE_URL_PATTERN = r"!\[(.*?)\]\((.*?)\)"
DATA_URL_PATTERN = r"data:([^;]+);base64,(.+)"
# Audio/Video Settings
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
# Optional: Define MIME type mappings if needed, or handle directly in converter
AUDIO_FORMAT_TO_MIMETYPE = {
@@ -63,17 +63,50 @@ VIDEO_FORMAT_TO_MIMETYPE = {
}
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]
DEFAULT_SAFETY_SETTINGS = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
TTS_VOICE_NAMES = [
"Zephyr",
"Puck",
"Charon",
"Kore",
"Fenrir",
"Leda",
"Orus",
"Aoede",
"Callirrhoe",
"Autonoe",
"Enceladus",
"Iapetus",
"Umbriel",
"Algieba",
"Despina",
"Erinome",
"Algenib",
"Rasalgethi",
"Laomedeia",
"Achernar",
"Alnilam",
"Schedar",
"Gacrux",
"Pulcherrima",
"Achird",
"Zubenelgenubi",
"Vindemiatrix",
"Sadachbia",
"Sadaltager",
"Sulafat",
]


@@ -2,6 +2,7 @@
Database connection pool module
"""
from pathlib import Path
from urllib.parse import quote_plus
from databases import Database
from sqlalchemy import create_engine, MetaData
from sqlalchemy.ext.declarative import declarative_base
@@ -20,9 +21,9 @@ if settings.DATABASE_TYPE == "sqlite":
DATABASE_URL = f"sqlite:///{db_path}"
elif settings.DATABASE_TYPE == "mysql":
if settings.MYSQL_SOCKET:
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
else:
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
else:
raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")
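
The reason for wrapping the password in `quote_plus` is that characters such as `@`, `:` and `/` would otherwise be read as URL delimiters. A minimal sketch, using a hypothetical helper `build_mysql_url` and made-up credentials:

```python
from urllib.parse import quote_plus


def build_mysql_url(user: str, password: str, host: str, port: int, database: str) -> str:
    """Build a SQLAlchemy-style MySQL URL with a percent-encoded password."""
    return f"mysql+pymysql://{user}:{quote_plus(password)}@{host}:{port}/{database}"


# Made-up credentials; the password contains characters that are URL delimiters.
url = build_mysql_url("app", "p@ss:w/rd", "127.0.0.1", 3306, "gemini")
```

As in the diff, only the password is encoded here; a user name containing reserved characters would need the same treatment.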


@@ -2,7 +2,8 @@
Database models module
"""
import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, BigInteger, Enum
import enum
from app.database.connection import Base
@@ -60,3 +61,69 @@ class RequestLog(Base):
def __repr__(self):
return f"<RequestLog(id='{self.id}', key='{self.api_key[:4]}...', success='{self.is_success}')>"
class FileState(enum.Enum):
    """File state enumeration."""

    PROCESSING = "PROCESSING"
    ACTIVE = "ACTIVE"
    FAILED = "FAILED"


class FileRecord(Base):
    """
    File record table; stores information about files uploaded to Gemini.
    """

    __tablename__ = "t_file_records"

    id = Column(Integer, primary_key=True, autoincrement=True)

    # Basic file information
    name = Column(String(255), unique=True, nullable=False, comment="File name, format: files/{file_id}")
    display_name = Column(String(255), nullable=True, comment="Original file name supplied by the user at upload")
    mime_type = Column(String(100), nullable=False, comment="MIME type")
    size_bytes = Column(BigInteger, nullable=False, comment="File size (bytes)")
    sha256_hash = Column(String(255), nullable=True, comment="SHA256 hash of the file")

    # State information
    state = Column(Enum(FileState), nullable=False, default=FileState.PROCESSING, comment="File state")

    # Timestamps
    create_time = Column(DateTime, nullable=False, comment="Creation time")
    update_time = Column(DateTime, nullable=False, comment="Update time")
    expiration_time = Column(DateTime, nullable=False, comment="Expiration time")

    # API-related
    uri = Column(String(500), nullable=False, comment="File access URI")
    api_key = Column(String(100), nullable=False, comment="API key used for the upload")
    upload_url = Column(Text, nullable=True, comment="Temporary upload URL (used for chunked uploads)")

    # Extra information
    user_token = Column(String(100), nullable=True, comment="Token of the uploading user")
    upload_completed = Column(DateTime, nullable=True, comment="Upload completion time")

    def __repr__(self):
        return f"<FileRecord(name='{self.name}', state='{self.state.value if self.state else 'None'}', api_key='{self.api_key[:8]}...')>"

    def to_dict(self):
        """Convert to a dict for use in API responses."""
        return {
            "name": self.name,
            "displayName": self.display_name,
            "mimeType": self.mime_type,
            "sizeBytes": str(self.size_bytes),
            "createTime": self.create_time.isoformat() + "Z",
            "updateTime": self.update_time.isoformat() + "Z",
            "expirationTime": self.expiration_time.isoformat() + "Z",
            "sha256Hash": self.sha256_hash,
            "uri": self.uri,
            "state": self.state.value if self.state else "PROCESSING"
        }

    def is_expired(self):
        """Check whether the file has expired."""
        # Make sure both sides of the comparison are timezone-aware
        expiration_time = self.expiration_time
        if expiration_time.tzinfo is None:
            expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
        return datetime.datetime.now(datetime.timezone.utc) > expiration_time
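
The normalization in `is_expired` matters because Python refuses to compare naive and timezone-aware datetimes. The standalone sketch below reproduces the same check as a free function; the `now` parameter is an addition for testability and is not part of the model.

```python
import datetime
from typing import Optional


def is_expired(expiration_time: datetime.datetime,
               now: Optional[datetime.datetime] = None) -> bool:
    """Return True if expiration_time lies in the past, treating naive values as UTC."""
    if expiration_time.tzinfo is None:
        # Naive timestamps (as typically returned by a DateTime column) are assumed to be UTC.
        expiration_time = expiration_time.replace(tzinfo=datetime.timezone.utc)
    if now is None:
        now = datetime.datetime.now(datetime.timezone.utc)
    return now > expiration_time


fixed_now = datetime.datetime(2025, 7, 20, 12, 0, tzinfo=datetime.timezone.utc)
naive_past = datetime.datetime(2025, 7, 20, 11, 0)      # no tzinfo
aware_future = datetime.datetime(2025, 7, 21, 0, 0, tzinfo=datetime.timezone.utc)
```

Without the `replace(tzinfo=...)` step, comparing `fixed_now` with `naive_past` would raise a `TypeError`.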


@@ -2,11 +2,11 @@
Database service module
"""
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from datetime import datetime, timezone
from sqlalchemy import func, desc, asc, select, insert, update, delete
import json
from app.database.connection import database
from app.database.models import Settings, ErrorLog, RequestLog
from app.database.models import Settings, ErrorLog, RequestLog, FileRecord, FileState
from app.log.logger import get_database_logger
logger = get_database_logger()
@@ -427,3 +427,264 @@ async def add_request_log(
except Exception as e:
logger.error(f"Failed to add request log: {str(e)}")
return False
# ==================== File record functions ====================


async def create_file_record(
    name: str,
    mime_type: str,
    size_bytes: int,
    api_key: str,
    uri: str,
    create_time: datetime,
    update_time: datetime,
    expiration_time: datetime,
    state: FileState = FileState.PROCESSING,
    display_name: Optional[str] = None,
    sha256_hash: Optional[str] = None,
    upload_url: Optional[str] = None,
    user_token: Optional[str] = None
) -> Dict[str, Any]:
    """
    Create a file record.

    Args:
        name: File name (format: files/{file_id})
        mime_type: MIME type
        size_bytes: File size (bytes)
        api_key: API key used for the upload
        uri: File access URI
        create_time: Creation time
        update_time: Update time
        expiration_time: Expiration time
        display_name: Display name
        sha256_hash: SHA256 hash
        upload_url: Temporary upload URL
        user_token: Token of the uploading user

    Returns:
        Dict[str, Any]: The created file record
    """
    try:
        query = insert(FileRecord).values(
            name=name,
            display_name=display_name,
            mime_type=mime_type,
            size_bytes=size_bytes,
            sha256_hash=sha256_hash,
            state=state,
            create_time=create_time,
            update_time=update_time,
            expiration_time=expiration_time,
            uri=uri,
            api_key=api_key,
            upload_url=upload_url,
            user_token=user_token
        )
        await database.execute(query)

        # Return the created record
        return await get_file_record_by_name(name)
    except Exception as e:
        logger.error(f"Failed to create file record: {str(e)}")
        raise


async def get_file_record_by_name(name: str) -> Optional[Dict[str, Any]]:
    """
    Get a file record by file name.

    Args:
        name: File name (format: files/{file_id})

    Returns:
        Optional[Dict[str, Any]]: The file record, or None if it does not exist
    """
    try:
        query = select(FileRecord).where(FileRecord.name == name)
        result = await database.fetch_one(query)
        return dict(result) if result else None
    except Exception as e:
        logger.error(f"Failed to get file record by name {name}: {str(e)}")
        raise


async def update_file_record_state(
    file_name: str,
    state: FileState,
    update_time: Optional[datetime] = None,
    upload_completed: Optional[datetime] = None,
    sha256_hash: Optional[str] = None
) -> bool:
    """
    Update the state of a file record.

    Args:
        file_name: File name
        state: New state
        update_time: Update time
        upload_completed: Upload completion time
        sha256_hash: SHA256 hash

    Returns:
        bool: Whether the update succeeded
    """
    try:
        values = {"state": state}
        if update_time:
            values["update_time"] = update_time
        if upload_completed:
            values["upload_completed"] = upload_completed
        if sha256_hash:
            values["sha256_hash"] = sha256_hash

        query = update(FileRecord).where(FileRecord.name == file_name).values(**values)
        result = await database.execute(query)

        if result:
            logger.info(f"Updated file record state for {file_name} to {state}")
            return True

        logger.warning(f"File record not found for update: {file_name}")
        return False
    except Exception as e:
        logger.error(f"Failed to update file record state: {str(e)}")
        return False
async def list_file_records(
    user_token: Optional[str] = None,
    api_key: Optional[str] = None,
    page_size: int = 10,
    page_token: Optional[str] = None
) -> tuple[List[Dict[str, Any]], Optional[str]]:
    """
    List file records.

    Args:
        user_token: User token (if provided, only that user's files are returned)
        api_key: API key (if provided, only files uploaded with that key are returned)
        page_size: Page size
        page_token: Pagination token (an offset)

    Returns:
        tuple[List[Dict[str, Any]], Optional[str]]: (list of files, next page token)
    """
    try:
        logger.debug(f"list_file_records called with page_size={page_size}, page_token={page_token}")
        query = select(FileRecord).where(
            FileRecord.expiration_time > datetime.now(timezone.utc)
        )

        if user_token:
            query = query.where(FileRecord.user_token == user_token)
        if api_key:
            query = query.where(FileRecord.api_key == api_key)

        # Paginate using an offset
        offset = 0
        if page_token:
            try:
                offset = int(page_token)
            except ValueError:
                logger.warning(f"Invalid page token: {page_token}")
                offset = 0

        # Order by ID ascending, using OFFSET and LIMIT
        query = query.order_by(FileRecord.id).offset(offset).limit(page_size + 1)
        results = await database.fetch_all(query)

        logger.debug(f"Query returned {len(results)} records")
        if results:
            logger.debug(f"First record ID: {results[0]['id']}, Last record ID: {results[-1]['id']}")

        # Handle pagination
        has_next = len(results) > page_size
        if has_next:
            results = results[:page_size]
            # The next page's offset is the current offset plus the number of records returned on this page
            next_offset = offset + page_size
            next_page_token = str(next_offset)
            logger.debug(f"Has next page, offset={offset}, page_size={page_size}, next_page_token={next_page_token}")
        else:
            next_page_token = None
            logger.debug(f"No next page, returning {len(results)} results")

        return [dict(row) for row in results], next_page_token
    except Exception as e:
        logger.error(f"Failed to list file records: {str(e)}")
        raise
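
The pagination in `list_file_records` fetches `page_size + 1` rows so that it can tell whether another page exists without a separate count query. The sketch below applies the same idea to an in-memory list; the `paginate` helper and its clamping of negative offsets are illustrative assumptions, not code from the service.

```python
from typing import Any, List, Optional, Tuple


def paginate(rows: List[Any], page_size: int,
             page_token: Optional[str] = None) -> Tuple[List[Any], Optional[str]]:
    """Return one page of rows plus the token (an offset) for the next page, if any."""
    try:
        offset = int(page_token) if page_token else 0
    except ValueError:
        offset = 0                      # invalid tokens restart from the first page
    offset = max(offset, 0)             # assumption: negative offsets are clamped to 0

    # Fetch one extra row: its presence signals that a further page exists.
    window = rows[offset:offset + page_size + 1]
    has_next = len(window) > page_size
    page = window[:page_size]
    next_token = str(offset + page_size) if has_next else None
    return page, next_token


rows = list(range(5))
first_page, first_token = paginate(rows, 2)
last_page, last_token = paginate(rows, 2, "4")
```

Offset tokens are simple, but rows inserted or deleted between requests can shift later pages; a cursor based on the last seen `id` would avoid that at the cost of a slightly more involved query.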
async def delete_file_record(name: str) -> bool:
    """
    Delete a file record.

    Args:
        name: File name

    Returns:
        bool: Whether the deletion succeeded
    """
    try:
        query = delete(FileRecord).where(FileRecord.name == name)
        await database.execute(query)
        return True
    except Exception as e:
        logger.error(f"Failed to delete file record: {str(e)}")
        return False


async def delete_expired_file_records() -> List[Dict[str, Any]]:
    """
    Delete expired file records.

    Returns:
        List[Dict[str, Any]]: The deleted records
    """
    try:
        # Fetch the records to be deleted first
        query = select(FileRecord).where(
            FileRecord.expiration_time <= datetime.now(timezone.utc)
        )
        expired_records = await database.fetch_all(query)

        if not expired_records:
            return []

        # Perform the deletion
        delete_query = delete(FileRecord).where(
            FileRecord.expiration_time <= datetime.now(timezone.utc)
        )
        await database.execute(delete_query)

        logger.info(f"Deleted {len(expired_records)} expired file records")
        return [dict(record) for record in expired_records]
    except Exception as e:
        logger.error(f"Failed to delete expired file records: {str(e)}")
        raise


async def get_file_api_key(name: str) -> Optional[str]:
    """
    Get the API key associated with a file.

    Args:
        name: File name

    Returns:
        Optional[str]: The API key, or None if the file does not exist or has expired
    """
    try:
        query = select(FileRecord.api_key).where(
            (FileRecord.name == name) &
            (FileRecord.expiration_time > datetime.now(timezone.utc))
        )
        result = await database.fetch_one(query)
        return result["api_key"] if result else None
    except Exception as e:
        logger.error(f"Failed to get file API key: {str(e)}")
        raise

app/domain/file_models.py

@@ -0,0 +1,69 @@
"""
Domain models for the Files API
"""
from typing import Optional, Dict, Any, List
from datetime import datetime
from pydantic import BaseModel, Field


class FileUploadConfig(BaseModel):
    """File upload configuration."""
    mime_type: Optional[str] = Field(None, description="MIME type")
    display_name: Optional[str] = Field(None, description="Display name (at most 40 characters)")


class CreateFileRequest(BaseModel):
    """Create-file request (used to initialize an upload)."""
    file: Optional[Dict[str, Any]] = Field(None, description="File metadata")


class FileMetadata(BaseModel):
    """File metadata response."""
    name: str = Field(..., description="File name, format: files/{file_id}")
    displayName: Optional[str] = Field(None, description="Display name")
    mimeType: str = Field(..., description="MIME type")
    sizeBytes: str = Field(..., description="File size (bytes)")
    createTime: str = Field(..., description="Creation time (RFC3339)")
    updateTime: str = Field(..., description="Update time (RFC3339)")
    expirationTime: str = Field(..., description="Expiration time (RFC3339)")
    sha256Hash: Optional[str] = Field(None, description="SHA256 hash")
    uri: str = Field(..., description="File access URI")
    state: str = Field(..., description="File state")

    class Config:
        json_encoders = {
            datetime: lambda v: v.isoformat() + "Z"
        }


class ListFilesRequest(BaseModel):
    """List-files request parameters."""
    pageSize: Optional[int] = Field(10, ge=1, le=100, description="Page size")
    pageToken: Optional[str] = Field(None, description="Pagination token")


class ListFilesResponse(BaseModel):
    """List-files response."""
    files: List[FileMetadata] = Field(default_factory=list, description="List of files")
    nextPageToken: Optional[str] = Field(None, description="Next page token")


class UploadInitResponse(BaseModel):
    """Upload initialization response (internal use)."""
    file_metadata: FileMetadata
    upload_url: str


class FileKeyMapping(BaseModel):
    """Mapping between a file and its API key (internal use)."""
    file_name: str
    api_key: str
    user_token: str
    created_at: datetime
    expires_at: datetime


class DeleteFileResponse(BaseModel):
    """Delete-file response."""
    success: bool = Field(..., description="Whether the deletion succeeded")
    message: Optional[str] = Field(None, description="Message")


@@ -41,15 +41,18 @@ class GenerationConfig(BaseModel):
responseLogprobs: Optional[bool] = None
logprobs: Optional[int] = None
thinkingConfig: Optional[Dict[str, Any]] = None
# TTS-related fields
responseModalities: Optional[List[str]] = None
speechConfig: Optional[Dict[str, Any]] = None
class SystemInstruction(BaseModel):
role: str = "system"
parts: List[Dict[str, Any]] | Dict[str, Any]
role: Optional[str] = "system"
parts: Union[List[Dict[str, Any]], Dict[str, Any]]
class GeminiContent(BaseModel):
role: str
role: Optional[str] = None
parts: List[Dict[str, Any]]


@@ -1,23 +1,20 @@
from typing import Union
class ImageMetadata:
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: str | None = None):
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
self.width = width
self.height = height
self.filename = filename
self.size = size
self.url = url
self.delete_url = delete_url
class UploadResponse:
def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
self.success = success
self.code = code
self.message = message
self.data = data
class ImageUploader:
def upload(self, file: bytes, filename: str) -> UploadResponse:
raise NotImplementedError


@@ -33,3 +33,10 @@ class ImageGenerationRequest(BaseModel):
quality: Optional[str] = None
style: Optional[str] = None
response_format: Optional[str] = "url"
class TTSRequest(BaseModel):
model: str = "gemini-2.5-flash-preview-tts"
input: str
voice: str = "Kore"
response_format: Optional[str] = "wav"


@@ -9,6 +9,9 @@ from typing import Any, Dict, List, Optional
from app.config.config import settings
from app.utils.uploader import ImageUploaderFactory
from app.log.logger import get_openai_logger
logger = get_openai_logger()
class ResponseHandler(ABC):
@@ -39,13 +42,13 @@ class GeminiResponseHandler(ResponseHandler):
def _handle_openai_stream_response(
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
text, tool_calls, _ = _extract_result(
text, reasoning_content, tool_calls, _ = _extract_result(
response, model, stream=True, gemini_format=False
)
if not text and not tool_calls:
if not text and not tool_calls and not reasoning_content:
delta = {}
else:
delta = {"content": text, "role": "assistant"}
delta = {"content": text, "reasoning_content": reasoning_content, "role": "assistant"}
if tool_calls:
delta["tool_calls"] = tool_calls
template_chunk = {
@@ -63,7 +66,7 @@ def _handle_openai_stream_response(
def _handle_openai_normal_response(
response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
text, tool_calls, _ = _extract_result(
text, reasoning_content, tool_calls, _ = _extract_result(
response, model, stream=False, gemini_format=False
)
return {
@@ -77,6 +80,7 @@ def _handle_openai_normal_response(
"message": {
"role": "assistant",
"content": text,
"reasoning_content": reasoning_content,
"tool_calls": tool_calls,
},
"finish_reason": finish_reason,
@@ -156,19 +160,24 @@ def _extract_result(
model: str,
stream: bool = False,
gemini_format: bool = False,
) -> tuple[str, List[Dict[str, Any]], Optional[bool]]:
text, tool_calls = "", []
thought = None
) -> tuple[str, Optional[str], List[Dict[str, Any]], Optional[bool]]:
text, reasoning_content, tool_calls, thought = "", "", [], None
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if not parts:
return "", [], None
logger.warning("No parts found in stream response")
return "", None, [], None
if "text" in parts[0]:
text = parts[0].get("text")
if "thought" in parts[0]:
if not gemini_format and settings.SHOW_THINKING_PROCESS:
reasoning_content = text
text = ""
thought = parts[0].get("thought")
elif "executableCode" in parts[0]:
text = _format_code_block(parts[0]["executableCode"])
@@ -187,40 +196,40 @@ def _extract_result(
else:
if response.get("candidates"):
candidate = response["candidates"][0]
if "thinking" in model:
if settings.SHOW_THINKING_PROCESS:
if len(candidate["content"]["parts"]) == 2:
text = (
"> thinking\n\n"
+ candidate["content"]["parts"][0]["text"]
+ "\n\n---\n> output\n\n"
+ candidate["content"]["parts"][1]["text"]
)
else:
text = candidate["content"]["parts"][0]["text"]
else:
if len(candidate["content"]["parts"]) == 2:
text = candidate["content"]["parts"][1]["text"]
else:
text = candidate["content"]["parts"][0]["text"]
else:
text = ""
if "parts" in candidate["content"]:
for part in candidate["content"]["parts"]:
text, reasoning_content = "", ""
# Use safe access
content = candidate.get("content", {})
if content and isinstance(content, dict):
parts = content.get("parts", [])
if parts:
for part in parts:
if "text" in part:
text += part["text"]
if "thought" in part and settings.SHOW_THINKING_PROCESS:
reasoning_content += part["text"]
else:
text += part["text"]
if "thought" in part and thought is None:
thought = part.get("thought")
elif "inlineData" in part:
text += _extract_image_data(part)
else:
logger.warning(f"No parts found in content for model: {model}")
else:
logger.error(f"Invalid content structure for model: {model}")
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(
candidate["content"]["parts"], gemini_format
)
# Safely fetch parts for tool-call extraction
parts = candidate.get("content", {}).get("parts", [])
tool_calls = _extract_tool_calls(parts, gemini_format)
else:
logger.warning(f"No candidates found in response for model: {model}")
text = "暂无返回"
return text, tool_calls, thought
return text, reasoning_content, tool_calls, thought
def _extract_image_data(part: dict) -> str:
@@ -238,6 +247,7 @@ def _extract_image_data(part: dict) -> str:
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
)
current_date = time.strftime("%Y/%m/%d")
filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
@@ -260,8 +270,8 @@ def _extract_tool_calls(
return []
letters = string.ascii_lowercase + string.digits
tool_calls = list()
for i in range(len(parts)):
part = parts[i]
if not part or not isinstance(part, dict):
@@ -270,7 +280,7 @@ def _extract_tool_calls(
item = part.get("functionCall", {})
if not item or not isinstance(item, dict):
continue
if gemini_format:
tool_calls.append(part)
else:
@@ -293,7 +303,7 @@ def _extract_tool_calls(
def _handle_gemini_stream_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
text, tool_calls, thought = _extract_result(
text, reasoning_content, tool_calls, thought = _extract_result(
response, model, stream=stream, gemini_format=True
)
if tool_calls:
@@ -310,16 +320,18 @@ def _handle_gemini_stream_response(
def _handle_gemini_normal_response(
response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
text, tool_calls, thought = _extract_result(
text, reasoning_content, tool_calls, thought = _extract_result(
response, model, stream=stream, gemini_format=True
)
parts = []
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
parts = tool_calls
else:
part = {"text": text}
if thought is not None:
part["thought"] = thought
content = {"parts": [part], "role": "model"}
parts.append({"text": reasoning_content, "thought": thought})
part = {"text": text}
parts.append(part)
content = {"parts": parts, "role": "model"}
response["candidates"][0]["content"] = content
return response
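
Reviewer note: the non-streaming loop in `_extract_result` routes each part's text either to the visible output or to `reasoning_content`. Below is a minimal sketch of that split, mirroring the condition shown in the diff. `split_parts` and its `show_thinking` argument are hypothetical names standing in for the inline loop and `settings.SHOW_THINKING_PROCESS`; they are not part of the commit.

```python
from typing import Any, Dict, List, Tuple


def split_parts(parts: List[Dict[str, Any]], show_thinking: bool) -> Tuple[str, str]:
    """Return (text, reasoning_content) accumulated over Gemini-style parts.

    Hypothetical helper: it mirrors the branch in the diff, where a part goes
    to the reasoning buffer only if it carries a "thought" key AND the
    thinking display is enabled; otherwise its text is appended to the output.
    """
    text, reasoning = "", ""
    for part in parts:
        if "text" not in part:
            continue  # non-text parts (e.g. inlineData) are handled elsewhere
        if "thought" in part and show_thinking:
            reasoning += part["text"]
        else:
            text += part["text"]
    return text, reasoning


if __name__ == "__main__":
    demo = [{"text": "plan", "thought": True}, {"text": "answer"}]
    print(split_parts(demo, show_thinking=True))
```

Note that with `show_thinking=False` this condition leaves thought text in the regular output, which may or may not be intended upstream.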

@@ -228,6 +228,10 @@ def get_request_log_logger():
return Logger.setup_logger("request_log")
def get_files_logger():
return Logger.setup_logger("files")
def get_vertex_express_logger():
return Logger.setup_logger("vertex_express")

@@ -8,6 +8,7 @@ from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
from app.core.constants import API_VERSION
from app.core.security import verify_auth_token
from app.log.logger import get_middleware_logger
@@ -33,6 +34,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
and not request.url.path.startswith("/openai")
and not request.url.path.startswith("/api/version/check")
and not request.url.path.startswith("/vertex-express")
and not request.url.path.startswith("/upload")
):
auth_token = request.cookies.get("auth_token")
@@ -52,6 +54,9 @@ def setup_middlewares(app: FastAPI) -> None:
Args:
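app: The FastAPI application instance
"""
# Add the smart routing middleware (must be registered before the auth middleware)
app.add_middleware(SmartRoutingMiddleware)
# Add the authentication middleware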
app.add_middleware(AuthMiddleware)

@@ -0,0 +1,210 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.config.config import settings
from app.log.logger import get_main_logger
import re
logger = get_main_logger()
class SmartRoutingMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
super().__init__(app)
# Simplified routing rules: route directly on the detection result
pass
async def dispatch(self, request: Request, call_next):
if not settings.URL_NORMALIZATION_ENABLED:
return await call_next(request)
logger.debug(f"request: {request}")
original_path = str(request.url.path)
method = request.method
# Try to fix the URL
fixed_path, fix_info = self.fix_request_url(original_path, method, request)
if fixed_path != original_path:
logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
if fix_info:
logger.debug(f"Fix details: {fix_info}")
# Rewrite the request path
request.scope["path"] = fixed_path
request.scope["raw_path"] = fixed_path.encode()
return await call_next(request)
def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
"""Simplified URL-fixing logic."""
# First check whether the path is already in a correct format; if so, leave it untouched
if self.is_already_correct_format(path):
return path, None
# 1. Highest priority: contains generateContent → Gemini format
if "generatecontent" in path.lower() or "v1beta/models" in path.lower():
return self.fix_gemini_by_operation(path, method, request)
# 2. Second priority: contains /openai/ → OpenAI format
if "/openai/" in path.lower():
return self.fix_openai_by_operation(path, method)
# 3. Third priority: contains /v1/ → v1 format
if "/v1/" in path.lower():
return self.fix_v1_by_operation(path, method)
# 4. Fourth priority: contains /chat/completions → chat endpoint
if "/chat/completions" in path.lower():
return "/v1/chat/completions", {"type": "v1_chat"}
# 5. Default: pass through unchanged
return path, None
def is_already_correct_format(self, path: str) -> bool:
"""Check whether the path is already in a correct API format."""
# Check whether the path already matches a correct endpoint format
correct_patterns = [
r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Native Gemini
r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini with prefix
r"^/v1beta/models$", # Gemini model list
r"^/gemini/v1beta/models$", # Gemini model list with prefix
r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # v1 format
r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # OpenAI format
r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # HF format
r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Vertex Express Gemini format
r"^/vertex-express/v1beta/models$", # Vertex Express model list
r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$", # Vertex Express OpenAI format
]
for pattern in correct_patterns:
if re.match(pattern, path):
return True
return False
def fix_gemini_by_operation(
self, path: str, method: str, request: Request
) -> tuple:
"""Fix the path based on the Gemini operation, taking the endpoint preference into account."""
if method == "GET":
return "/v1beta/models", {
"role": "gemini_models",
}
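# Extract the model name
try:
model_name = self.extract_model_name(path, request)
except ValueError:
# The model name could not be extracted; return the original path unchanged
return path, None
# Detect whether this is a streaming request
is_stream = self.detect_stream_request(path, request)
# Check for a vertex-express preference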
if "/vertex-express/" in path.lower():
if is_stream:
target_url = (
f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
)
else:
target_url = (
f"/vertex-express/v1beta/models/{model_name}:generateContent"
)
fix_info = {
"rule": (
"vertex_express_generate"
if not is_stream
else "vertex_express_stream"
),
"preference": "vertex_express_format",
"is_stream": is_stream,
"model": model_name,
}
else:
# Standard Gemini endpoint
if is_stream:
target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
else:
target_url = f"/v1beta/models/{model_name}:generateContent"
fix_info = {
"rule": "gemini_generate" if not is_stream else "gemini_stream",
"preference": "gemini_format",
"is_stream": is_stream,
"model": model_name,
}
return target_url, fix_info
def fix_openai_by_operation(self, path: str, method: str) -> tuple:
"""Fix the path to the OpenAI format based on the operation type."""
if method == "POST":
if "chat" in path.lower() or "completion" in path.lower():
return "/openai/v1/chat/completions", {"type": "openai_chat"}
elif "embedding" in path.lower():
return "/openai/v1/embeddings", {"type": "openai_embeddings"}
elif "image" in path.lower():
return "/openai/v1/images/generations", {"type": "openai_images"}
elif "audio" in path.lower():
return "/openai/v1/audio/speech", {"type": "openai_audio"}
elif method == "GET":
if "model" in path.lower():
return "/openai/v1/models", {"type": "openai_models"}
return path, None
def fix_v1_by_operation(self, path: str, method: str) -> tuple:
"""Fix the path to the v1 format based on the operation type."""
if method == "POST":
if "chat" in path.lower() or "completion" in path.lower():
return "/v1/chat/completions", {"type": "v1_chat"}
elif "embedding" in path.lower():
return "/v1/embeddings", {"type": "v1_embeddings"}
elif "image" in path.lower():
return "/v1/images/generations", {"type": "v1_images"}
elif "audio" in path.lower():
return "/v1/audio/speech", {"type": "v1_audio"}
elif method == "GET":
if "model" in path.lower():
return "/v1/models", {"type": "v1_models"}
return path, None
def detect_stream_request(self, path: str, request: Request) -> bool:
"""Detect whether this is a streaming request."""
# 1. The path contains the keyword "stream"
if "stream" in path.lower():
return True
# 2. Query parameter
if request.query_params.get("stream") == "true":
return True
return False
def extract_model_name(self, path: str, request: Request) -> str:
"""Extract the model name from the request, used to build the Gemini API URL."""
# 1. Extract from the request body
try:
if hasattr(request, "_body") and request._body:
import json
body = json.loads(request._body.decode())
if "model" in body and body["model"]:
return body["model"]
except Exception:
pass
# 2. Extract from the query parameters
model_param = request.query_params.get("model")
if model_param:
return model_param
# 3. Extract from the path (for paths that already contain the model name)
match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
if match:
return match.group(1)
# 4. If the model name cannot be extracted, raise an exception
raise ValueError("Unable to extract model name from request")
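
Reviewer note: the normalization flow in `SmartRoutingMiddleware` (check for an already-correct path, extract the model name, choose the streaming or non-streaming action) can be exercised without FastAPI. The sketch below is a standalone approximation under stated assumptions: it covers only the path-based model extraction (step 3 above), uses a subset of the patterns, and `normalize_gemini_path` is a hypothetical function name, not something in the commit.

```python
import re
from typing import Optional

# Subset of the "already correct" patterns from the diff (assumption: the
# remaining prefixes such as /gemini and /vertex-express behave the same way).
CORRECT_PATTERNS = [
    r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",
    r"^/v1beta/models$",
    r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",
]


def is_already_correct(path: str) -> bool:
    """True when the path matches one of the canonical endpoint patterns."""
    return any(re.match(pattern, path) for pattern in CORRECT_PATTERNS)


def extract_model_from_path(path: str) -> Optional[str]:
    """Pull the model name out of a '/models/<name>' segment, if present."""
    match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
    return match.group(1) if match else None


def normalize_gemini_path(path: str, stream: bool = False) -> str:
    """Rewrite a malformed Gemini path to the canonical /v1beta form.

    Hypothetical helper: returns the path unchanged when it is already
    correct or when no model name can be found.
    """
    if is_already_correct(path):
        return path
    model = extract_model_from_path(path)
    if model is None:
        return path
    is_stream = stream or "stream" in path.lower()
    action = "streamGenerateContent" if is_stream else "generateContent"
    return f"/v1beta/models/{model}:{action}"


if __name__ == "__main__":
    print(normalize_gemini_path("/api/v1beta/models/gemini-pro:generateContent"))
```

Because stream detection is a substring match on the whole path, a model name that itself contains "stream" would be treated as a streaming request; the same holds for the check in the diff.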

app/router/files_routes.py Normal file

@@ -0,0 +1,295 @@
"""
Files API routes
"""
from typing import Optional
from fastapi import APIRouter, Request, Query, Depends, Header, HTTPException
from fastapi.responses import JSONResponse
from app.config.config import settings
from app.domain.file_models import (
FileMetadata,
ListFilesResponse,
DeleteFileResponse
)
from app.log.logger import get_files_logger
from app.core.security import SecurityService
from app.service.files.files_service import get_files_service
from app.service.files.file_upload_handler import get_upload_handler
logger = get_files_logger()
router = APIRouter()
security_service = SecurityService()
@router.post("/upload/v1beta/files")
async def upload_file_init(
request: Request,
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
x_goog_upload_protocol: Optional[str] = Header(None),
x_goog_upload_command: Optional[str] = Header(None),
x_goog_upload_header_content_length: Optional[str] = Header(None),
x_goog_upload_header_content_type: Optional[str] = Header(None),
):
"""Initialize a file upload."""
logger.debug(f"Upload file request: {request.method=}, {request.url=}, {auth_token=}, {x_goog_upload_protocol=}, {x_goog_upload_command=}, {x_goog_upload_header_content_length=}, {x_goog_upload_header_content_type=}")
# Check whether this is an actual upload request (has an upload_id)
if request.query_params.get("upload_id") and x_goog_upload_command in ["upload", "upload, finalize"]:
logger.debug("This is an upload request, not initialization. Redirecting to handle_upload.")
return await handle_upload(
upload_path="v1beta/files",
request=request,
key=request.query_params.get("key"),
auth_token=auth_token
)
try:
# Use the auth token as the user_token
user_token = auth_token
# Read the request body
body = await request.body()
# Build the request host URL
request_host = f"{request.url.scheme}://{request.url.netloc}"
logger.info(f"Request host: {request_host}")
# Prepare the request headers
headers = {
"x-goog-upload-protocol": x_goog_upload_protocol or "resumable",
"x-goog-upload-command": x_goog_upload_command or "start",
}
if x_goog_upload_header_content_length:
headers["x-goog-upload-header-content-length"] = x_goog_upload_header_content_length
if x_goog_upload_header_content_type:
headers["x-goog-upload-header-content-type"] = x_goog_upload_header_content_type
# Call the service
files_service = await get_files_service()
response_data, response_headers = await files_service.initialize_upload(
headers=headers,
body=body,
user_token=user_token,
request_host=request_host  # Pass the request host
)
logger.info(f"Upload initialization response: {response_data}")
logger.info(f"Upload initialization response headers: {response_headers}")
# Return the response
return JSONResponse(
content=response_data,
headers=response_headers
)
except HTTPException as e:
logger.error(f"Upload initialization failed: {e.detail}")
return JSONResponse(
content={"error": {"message": e.detail}},
status_code=e.status_code
)
except Exception as e:
logger.error(f"Unexpected error in upload initialization: {str(e)}")
return JSONResponse(
content={"error": {"message": "Internal server error"}},
status_code=500
)
@router.get("/v1beta/files")
async def list_files(
page_size: int = Query(10, ge=1, le=100, description="Page size", alias="pageSize"),
page_token: Optional[str] = Query(None, description="Page token", alias="pageToken"),
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> ListFilesResponse:
"""List files."""
logger.debug(f"List files: {page_size=}, {page_token=}, {auth_token=}")
try:
# Use the auth token as the user_token (if user isolation is enabled)
user_token = auth_token if settings.FILES_USER_ISOLATION_ENABLED else None
# Call the service
files_service = await get_files_service()
return await files_service.list_files(
page_size=page_size,
page_token=page_token,
user_token=user_token
)
except HTTPException as e:
logger.error(f"List files failed: {e.detail}")
return JSONResponse(
content={"error": {"message": e.detail}},
status_code=e.status_code
)
except Exception as e:
logger.error(f"Unexpected error in list files: {str(e)}")
return JSONResponse(
content={"error": {"message": "Internal server error"}},
status_code=500
)
@router.get("/v1beta/files/{file_id:path}")
async def get_file(
file_id: str,
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> FileMetadata:
"""Get file metadata."""
logger.debug(f"Get file request: {file_id=}, {auth_token=}")
try:
# Use the auth token as the user_token
user_token = auth_token
# Call the service
files_service = await get_files_service()
return await files_service.get_file(f"files/{file_id}", user_token)
except HTTPException as e:
logger.error(f"Get file failed: {e.detail}")
return JSONResponse(
content={"error": {"message": e.detail}},
status_code=e.status_code
)
except Exception as e:
logger.error(f"Unexpected error in get file: {str(e)}")
return JSONResponse(
content={"error": {"message": "Internal server error"}},
status_code=500
)
@router.delete("/v1beta/files/{file_id:path}")
async def delete_file(
file_id: str,
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> DeleteFileResponse:
"""Delete a file."""
logger.info(f"Delete file: {file_id=}, {auth_token=}")
try:
# Use the auth token as the user_token
user_token = auth_token
# Call the service
files_service = await get_files_service()
success = await files_service.delete_file(f"files/{file_id}", user_token)
return DeleteFileResponse(
success=success,
message="File deleted successfully" if success else "Failed to delete file"
)
except HTTPException as e:
logger.error(f"Delete file failed: {e.detail}")
return JSONResponse(
content={"error": {"message": e.detail}},
status_code=e.status_code
)
except Exception as e:
logger.error(f"Unexpected error in delete file: {str(e)}")
return JSONResponse(
content={"error": {"message": "Internal server error"}},
status_code=500
)
# Wildcard route that handles upload requests
@router.api_route("/upload/{upload_path:path}", methods=["GET", "POST", "PUT"])
async def handle_upload(
upload_path: str,
request: Request,
key: Optional[str] = Query(None),  # Read the key from the query parameters
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
):
"""Handle a file upload request."""
try:
logger.info(f"Handling upload request: {request.method} {upload_path}, key={key}")
# Read upload_id from the query parameters
upload_id = request.query_params.get("upload_id")
if not upload_id:
raise HTTPException(status_code=400, detail="Missing upload_id")
# Get the real API key from the session
files_service = await get_files_service()
session_info = await files_service.get_upload_session(upload_id)
if not session_info:
logger.error(f"No session found for upload_id: {upload_id}")
raise HTTPException(status_code=404, detail="Upload session not found")
real_api_key = session_info["api_key"]
original_upload_url = session_info["upload_url"]
# Use the real API key to build the full Google upload URL
# Keep all parameters of the original URL, but use the real API key
upload_url = original_upload_url
logger.info(f"Using real API key for upload: {real_api_key[:8]}...{real_api_key[-4:]}")
# Proxy the upload request
upload_handler = get_upload_handler()
return await upload_handler.proxy_upload_request(
request=request,
upload_url=upload_url,
files_service=files_service
)
except HTTPException as e:
logger.error(f"Upload handling failed: {e.detail}")
return JSONResponse(
content={"error": {"message": e.detail}},
status_code=e.status_code
)
except Exception as e:
logger.error(f"Unexpected error in upload handling: {str(e)}")
return JSONResponse(
content={"error": {"message": "Internal server error"}},
status_code=500
)
# Routes with the /gemini prefix, added for compatibility
@router.post("/gemini/upload/v1beta/files")
async def gemini_upload_file_init(
request: Request,
auth_token: str = Depends(security_service.verify_key_or_goog_api_key),
x_goog_upload_protocol: Optional[str] = Header(None),
x_goog_upload_command: Optional[str] = Header(None),
x_goog_upload_header_content_length: Optional[str] = Header(None),
x_goog_upload_header_content_type: Optional[str] = Header(None),
):
"""Initialize a file upload (Gemini prefix)."""
return await upload_file_init(
request,
auth_token,
x_goog_upload_protocol,
x_goog_upload_command,
x_goog_upload_header_content_length,
x_goog_upload_header_content_type
)
@router.get("/gemini/v1beta/files")
async def gemini_list_files(
page_size: int = Query(10, ge=1, le=100, alias="pageSize"),
page_token: Optional[str] = Query(None, alias="pageToken"),
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> ListFilesResponse:
"""List files (Gemini prefix)."""
return await list_files(page_size, page_token, auth_token)
@router.get("/gemini/v1beta/files/{file_id:path}")
async def gemini_get_file(
file_id: str,
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> FileMetadata:
"""Get file metadata (Gemini prefix)."""
return await get_file(file_id, auth_token)
@router.delete("/gemini/v1beta/files/{file_id:path}")
async def gemini_delete_file(
file_id: str,
auth_token: str = Depends(security_service.verify_key_or_goog_api_key)
) -> DeleteFileResponse:
"""Delete a file (Gemini prefix)."""
return await delete_file(file_id, auth_token)
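
Reviewer note: `handle_upload` logs the real API key in masked form (`real_api_key[:8]...real_api_key[-4:]`). A small helper along the following lines could make that masking reusable for the other log statements in this file. `mask_key` is a hypothetical name, and the handling of short keys is an assumption added for the sketch; the diff only shows the slice expression.

```python
def mask_key(key: str, head: int = 8, tail: int = 4) -> str:
    """Return a log-safe form of a secret: first `head` and last `tail` chars.

    Assumption (not in the diff): keys too short to mask meaningfully are
    replaced entirely by asterisks so that no part of them is revealed.
    """
    if len(key) <= head + tail:
        return "*" * len(key)
    return f"{key[:head]}...{key[-tail:]}"


if __name__ == "__main__":
    print(mask_key("AIzaSyA1234567890abcd"))
```

With such a helper, a call like `logger.debug(f"List files: {page_size=}, token={mask_key(auth_token)}")` would keep full tokens out of the debug log.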

@@ -8,6 +8,7 @@ from app.core.security import SecurityService
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.tts.native.tts_routes import get_tts_chat_service
from app.service.model.model_service import ModelService
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
@@ -109,11 +110,41 @@ async def generate_content(
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
logger.info(f"Handling Gemini content generation request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
# Detect whether this is a native Gemini TTS request
is_native_tts = False
if "tts" in model_name.lower() and request.generationConfig:
# Read the TTS configuration directly from the parsed request object
response_modalities = request.generationConfig.responseModalities or []
speech_config = request.generationConfig.speechConfig or {}
# If it contains the AUDIO modality and a speech config, treat it as a native TTS request
if "AUDIO" in response_modalities and speech_config:
is_native_tts = True
logger.info("Detected native Gemini TTS request")
logger.info(f"TTS responseModalities: {response_modalities}")
logger.info(f"TTS speechConfig: {speech_config}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
# All native TTS requests use the TTS-enhanced service
if is_native_tts:
try:
logger.info("Using native TTS enhanced service")
tts_service = await get_tts_chat_service(key_manager)
response = await tts_service.generate_content(
model=model_name,
request=request,
api_key=api_key
)
return response
except Exception as e:
logger.warning(f"Native TTS processing failed, falling back to standard service: {e}")
# Use the standard service for all other (non-TTS) requests
response = await chat_service.generate_content(
model=model_name,
request=request,
@@ -151,6 +182,35 @@ async def stream_generate_content(
return StreamingResponse(response_stream, media_type="text/event-stream")
@router.post("/models/{model_name}:countTokens")
@router_v1beta.post("/models/{model_name}:countTokens")
@RetryHandler(key_arg="api_key")
async def count_tokens(
model_name: str,
request: GeminiRequest,
_=Depends(security_service.verify_key_or_goog_api_key),
api_key: str = Depends(get_next_working_key),
key_manager: KeyManager = Depends(get_key_manager),
chat_service: GeminiChatService = Depends(get_chat_service)
):
"""Handle a Gemini token-count request."""
operation_name = "gemini_count_tokens"
async with handle_route_errors(logger, operation_name, failure_message="Token counting failed"):
logger.info(f"Handling Gemini token count request for model: {model_name}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
if not await model_service.check_model_support(model_name):
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
response = await chat_service.count_tokens(
model=model_name,
request=request,
api_key=api_key
)
return response
@router.post("/reset-all-fail-counts")
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
"""Reset the failure counts of Gemini API keys in bulk (optionally only for valid or only for invalid keys)."""
@@ -269,7 +329,7 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
parts=[{"text": "hi"}],
)
],
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
)
response = await chat_service.generate_content(
@@ -279,7 +339,9 @@ async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get
)
if response:
return JSONResponse({"status": "valid"})
# If key verification succeeds, reset its failure count
await key_manager.reset_key_failure_count(api_key)
return JSONResponse({"status": "valid"})
except Exception as e:
logger.error(f"Key verification failed: {str(e)}")
@@ -314,7 +376,7 @@ async def verify_selected_keys(
try:
gemini_request = GeminiRequest(
contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
generation_config={"temperature": 0.7, "topP": 1.0, "maxOutputTokens": 10}
)
await chat_service.generate_content(
settings.TEST_MODEL,
@@ -322,6 +384,8 @@ async def verify_selected_keys(
api_key
)
successful_keys.append(api_key)
# If key verification succeeds, reset its failure count
await key_manager.reset_key_failure_count(api_key)
return api_key, "valid", None
except Exception as e:
error_message = str(e)
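
Reviewer note: the native TTS detection added to `generate_content` combines three conditions: the model name contains "tts", `responseModalities` includes `"AUDIO"`, and a non-empty `speechConfig` is present. The sketch below restates that predicate over a plain dictionary. `is_native_tts_request` is a hypothetical name, and the use of a dict instead of the parsed `GeminiRequest` object is an assumption made to keep the example self-contained.

```python
from typing import Any, Dict, Optional


def is_native_tts_request(
    model_name: str, generation_config: Optional[Dict[str, Any]]
) -> bool:
    """Return True when a request looks like a native Gemini TTS call.

    Hypothetical helper mirroring the checks in the diff: a "tts" model name,
    the AUDIO response modality, and a non-empty speech configuration.
    """
    if "tts" not in model_name.lower() or not generation_config:
        return False
    modalities = generation_config.get("responseModalities") or []
    speech_config = generation_config.get("speechConfig") or {}
    return "AUDIO" in modalities and bool(speech_config)


if __name__ == "__main__":
    config = {"responseModalities": ["AUDIO"], "speechConfig": {"voiceConfig": {}}}
    print(is_native_tts_request("gemini-2.5-flash-preview-tts", config))
```

Pulling the predicate out of the route handler like this would also make it straightforward to unit-test separately from the fallback logic.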

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from app.config.config import settings
@@ -7,6 +7,7 @@ from app.domain.openai_models import (
ChatRequest,
EmbeddingRequest,
ImageGenerationRequest,
TTSRequest,
)
from app.handler.retry_handler import RetryHandler
from app.handler.error_handler import handle_route_errors
@@ -14,6 +15,7 @@ from app.log.logger import get_openai_logger
from app.service.chat.openai_chat_service import OpenAIChatService
from app.service.embedding.embedding_service import EmbeddingService
from app.service.image.image_create_service import ImageCreateService
from app.service.tts.tts_service import TTSService
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.model.model_service import ModelService
@@ -24,6 +26,7 @@ security_service = SecurityService()
model_service = ModelService()
embedding_service = EmbeddingService()
image_create_service = ImageCreateService()
tts_service = TTSService()
async def get_key_manager():
@@ -41,6 +44,11 @@ async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_mana
return OpenAIChatService(settings.BASE_URL, key_manager)
async def get_tts_service():
"""Get the TTS service instance."""
return tts_service
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
@@ -147,3 +155,21 @@ async def get_keys_list(
},
"total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
}
@router.post("/v1/audio/speech")
@router.post("/hf/v1/audio/speech")
async def text_to_speech(
request: TTSRequest,
_=Depends(security_service.verify_authorization),
api_key: str = Depends(get_next_working_key_wrapper),
tts_service: TTSService = Depends(get_tts_service),
):
"""Handle an OpenAI TTS request."""
operation_name = "text_to_speech"
async with handle_route_errors(logger, operation_name):
logger.info(f"Handling TTS request for model: {request.model}")
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
logger.info(f"Using API key: {api_key}")
audio_data = await tts_service.create_tts(request, api_key)
return Response(content=audio_data, media_type="audio/wav")

@@ -8,7 +8,7 @@ from fastapi.templating import Jinja2Templates
from app.core.security import verify_auth_token
from app.log.logger import get_routes_logger
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes, files_routes
from app.service.key.key_manager import get_key_manager_instance
from app.service.stats.stats_service import StatsService
@@ -34,6 +34,7 @@ def setup_routers(app: FastAPI) -> None:
app.include_router(version_routes.router)
app.include_router(openai_compatiable_routes.router)
app.include_router(vertex_express_routes.router)
app.include_router(files_routes.router)
setup_page_routes(app)

@@ -8,6 +8,7 @@ from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.error_log.error_log_service import delete_old_error_logs
from app.service.key.key_manager import get_key_manager_instance
from app.service.request_log.request_log_service import delete_old_request_logs_task
from app.service.files.files_service import get_files_service
logger = Logger.setup_logger("scheduler")
@@ -96,6 +97,26 @@ async def check_failed_keys():
)
async def cleanup_expired_files():
"""
Periodically clean up expired file records.
"""
logger.info("Starting scheduled cleanup for expired files...")
try:
files_service = await get_files_service()
deleted_count = await files_service.cleanup_expired_files()
if deleted_count > 0:
logger.info(f"Successfully cleaned up {deleted_count} expired files.")
else:
logger.info("No expired files to clean up.")
except Exception as e:
logger.error(
f"An error occurred during the scheduled file cleanup: {str(e)}", exc_info=True
)
def setup_scheduler():
"""Set up and start APScheduler."""
scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE))  # Read the timezone from the configuration
@@ -134,6 +155,20 @@ def setup_scheduler():
logger.info(
f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
)
# New: scheduled job that cleans up expired files (runs every hour by default)
if getattr(settings, 'FILES_CLEANUP_ENABLED', True):
cleanup_interval = getattr(settings, 'FILES_CLEANUP_INTERVAL_HOURS', 1)
scheduler.add_job(
cleanup_expired_files,
"interval",
hours=cleanup_interval,
id="cleanup_expired_files_job",
name="Cleanup Expired Files",
)
logger.info(
f"File cleanup job scheduled to run every {cleanup_interval} hour(s)."
)
scheduler.start()
logger.info("Scheduler started with all jobs.")

@@ -13,7 +13,7 @@ from app.handler.stream_optimizer import gemini_optimizer
from app.log.logger import get_gemini_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import KeyManager
from app.database.services import add_error_log, add_request_log
from app.database.services import add_error_log, add_request_log, get_file_api_key
logger = get_gemini_logger()
@@ -27,10 +27,74 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
return True
return False
def _extract_file_references(contents: List[Dict[str, Any]]) -> List[str]:
"""Extract file references from the contents."""
file_names = []
for content in contents:
if "parts" in content:
for part in content["parts"]:
if not isinstance(part, dict) or "fileData" not in part:
continue
file_data = part["fileData"]
if "fileUri" not in file_data:
continue
file_uri = file_data["fileUri"]
# Extract the file name from the URI
# 1. https://generativelanguage.googleapis.com/v1beta/files/{file_id}
match = re.match(rf"{re.escape(settings.BASE_URL)}/(files/.*)", file_uri)
if not match:
logger.warning(f"Invalid file URI: {file_uri}")
continue
file_id = match.group(1)
file_names.append(file_id)
logger.info(f"Found file reference: {file_id}")
return file_names
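
Reviewer note: `_extract_file_references` relies on matching each `fileUri` against `settings.BASE_URL` and capturing the `files/...` suffix. The following standalone sketch shows that matching on sample data. The `BASE_URL` constant is an assumed value taken from the URI shape in the comment above, and `extract_file_references` is a hypothetical stand-in that takes the base URL as a parameter instead of reading settings.

```python
import re
from typing import Any, Dict, List

# Assumed value of settings.BASE_URL, based on the URI format in the comment.
BASE_URL = "https://generativelanguage.googleapis.com/v1beta"


def extract_file_references(
    contents: List[Dict[str, Any]], base_url: str = BASE_URL
) -> List[str]:
    """Collect 'files/<id>' names from fileData parts whose URI uses base_url."""
    pattern = re.compile(rf"{re.escape(base_url)}/(files/.*)")
    names: List[str] = []
    for content in contents:
        for part in content.get("parts", []):
            if not isinstance(part, dict):
                continue
            file_data = part.get("fileData")
            if not isinstance(file_data, dict):
                continue
            match = pattern.match(file_data.get("fileUri", ""))
            if match:
                names.append(match.group(1))
    return names


if __name__ == "__main__":
    sample = [{"parts": [{"fileData": {"fileUri": f"{BASE_URL}/files/abc123"}}]}]
    print(extract_file_references(sample))
```

`re.escape` matters here because the base URL contains dots, which would otherwise match any character.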
def _clean_json_schema_properties(obj: Any) -> Any:
"""Remove JSON Schema fields that the Gemini API does not support."""
if not isinstance(obj, dict):
return obj
# JSON Schema fields not supported by the Gemini API
unsupported_fields = {
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
"contentEncoding", "contentMediaType", "if", "then", "else",
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
"$id", "$ref", "$comment", "readOnly", "writeOnly"
}
cleaned = {}
for key, value in obj.items():
if key in unsupported_fields:
continue
if isinstance(value, dict):
cleaned[key] = _clean_json_schema_properties(value)
elif isinstance(value, list):
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
else:
cleaned[key] = value
return cleaned
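
Reviewer note: the recursive cleanup in `_clean_json_schema_properties` is easiest to follow on a concrete schema. The sketch below is a compact variant under stated assumptions: `UNSUPPORTED` is only a subset of the field list in the diff, `clean_schema` is a hypothetical name, and unlike the committed function it also recurses into lists nested inside lists.

```python
from typing import Any

# Subset of the unsupported JSON Schema keys listed in the diff.
UNSUPPORTED = {"$schema", "$ref", "const", "examples", "anyOf", "oneOf", "allOf"}


def clean_schema(obj: Any) -> Any:
    """Recursively drop unsupported keys from a JSON-Schema-like structure."""
    if isinstance(obj, list):
        return [clean_schema(item) for item in obj]
    if not isinstance(obj, dict):
        return obj  # scalars pass through unchanged
    return {
        key: clean_schema(value)
        for key, value in obj.items()
        if key not in UNSUPPORTED
    }


if __name__ == "__main__":
    schema = {
        "type": "object",
        "$schema": "http://json-schema.org/draft-07/schema#",
        "properties": {"n": {"type": "integer", "examples": [1, 2]}},
        "required": ["n"],
    }
    print(clean_schema(schema))
```

One caveat that applies to both versions: filtering by key name also removes a user-defined property that happens to be called, say, `examples` or `const` inside `properties`.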
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Build the tools."""
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
"""Check whether the contents contain a functionCall."""
if not contents or not isinstance(contents, list):
return False
for content in contents:
if not content or not isinstance(content, dict) or "parts" not in content:
continue
parts = content.get("parts", [])
if not parts or not isinstance(parts, list):
continue
for part in parts:
if isinstance(part, dict) and "functionCall" in part:
return True
return False
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
record = dict()
for item in tools:
@@ -40,7 +104,15 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
for k, v in item.items():
if k == "functionDeclarations" and v and isinstance(v, list):
functions = record.get("functionDeclarations", [])
functions.extend(v)
# Strip unsupported fields from each function declaration
cleaned_functions = []
for func in v:
if isinstance(func, dict):
cleaned_func = _clean_json_schema_properties(func)
cleaned_functions.append(cleaned_func)
else:
cleaned_functions.append(func)
functions.extend(cleaned_functions)
record["functionDeclarations"] = functions
else:
record[k] = v
@@ -62,15 +134,32 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
tool["codeExecution"] = {}
if model.endswith("-search"):
tool["googleSearch"] = {}
real_model = _get_real_model(model)
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
tool["urlContext"] = {}
# Works around the "Tool use with function calling is unsupported" error
if tool.get("functionDeclarations"):
if tool.get("functionDeclarations") or _has_function_call(payload.get("contents", [])):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
tool.pop("urlContext", None)
return [tool] if tool else []
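
Reviewer note: the end of `_build_tools` enforces that built-in tools are dropped whenever function calling is in play, either through `functionDeclarations` or through a `functionCall` already present in the conversation. A minimal sketch of that rule follows. `drop_conflicting_tools` is a hypothetical name, and it returns a copy instead of mutating the dictionary in place, which differs from the diff.

```python
from typing import Any, Dict

# Built-in tools that the diff removes when function calling is used.
EXCLUSIVE_TOOLS = ("googleSearch", "codeExecution", "urlContext")


def drop_conflicting_tools(
    tool: Dict[str, Any], has_function_call: bool = False
) -> Dict[str, Any]:
    """Return a copy of `tool` without built-ins that clash with function calling."""
    result = dict(tool)
    if result.get("functionDeclarations") or has_function_call:
        for name in EXCLUSIVE_TOOLS:
            result.pop(name, None)
    return result


if __name__ == "__main__":
    tool = {"googleSearch": {}, "functionDeclarations": [{"name": "get_weather"}]}
    print(drop_conflicting_tools(tool))
```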
def _get_real_model(model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
if model.endswith("-non-thinking"):
model = model[:-13]
if "-search" in model and "-non-thinking" in model:
model = model[:-20]
return model
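
As an alternative to the fixed-length slicing in `_get_real_model`, the suffixes could be stripped in a loop so that the result does not depend on the order in which they were appended. This is only a sketch of a different technique, not the code in this commit; `strip_virtual_suffixes` and `VIRTUAL_SUFFIXES` are hypothetical names.

```python
# Suffixes the proxy appends to model names, as seen in the diff.
VIRTUAL_SUFFIXES = ("-search", "-image", "-non-thinking")


def strip_virtual_suffixes(model: str) -> str:
    """Repeatedly strip known suffixes until none remain (hypothetical helper)."""
    changed = True
    while changed:
        changed = False
        for suffix in VIRTUAL_SUFFIXES:
            if model.endswith(suffix):
                model = model[: -len(suffix)]
                changed = True
    return model
```

With this approach both `...-search-non-thinking` and `...-non-thinking-search` resolve to the same base model name.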
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""Get the safety settings."""
if model == "gemini-2.0-flash-exp":
@@ -78,21 +167,61 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
return settings.SAFETY_SETTINGS
def _filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filters out contents with empty or invalid parts."""
if not contents:
return []
filtered_contents = []
for content in contents:
if not content or "parts" not in content or not isinstance(content.get("parts"), list):
continue
valid_parts = [part for part in content["parts"] if isinstance(part, dict) and part]
if valid_parts:
new_content = content.copy()
new_content["parts"] = valid_parts
filtered_contents.append(new_content)
return filtered_contents
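
The behaviour of `_filter_empty_parts` is easiest to see on a small input. The following is a self-contained copy with indentation restored and a slightly stricter type check on each content entry (an assumption added for the sketch); the function name without the leading underscore is only for illustration.

```python
from typing import Any, Dict, List


def filter_empty_parts(contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop empty or non-dict parts, and contents left with no valid parts."""
    filtered: List[Dict[str, Any]] = []
    for content in contents or []:
        if not isinstance(content, dict) or not isinstance(content.get("parts"), list):
            continue
        valid_parts = [p for p in content["parts"] if isinstance(p, dict) and p]
        if valid_parts:
            # Shallow-copy so the caller's content dict is not mutated.
            new_content = content.copy()
            new_content["parts"] = valid_parts
            filtered.append(new_content)
    return filtered
```

For example, a user turn with parts `[{}, {"text": "hi"}, None]` is reduced to `[{"text": "hi"}]`, while a turn whose parts list is empty is dropped entirely.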
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""Build the request payload."""
request_dict = request.model_dump()
request_dict = request.model_dump(exclude_none=False)
if request.generationConfig:
if request.generationConfig.maxOutputTokens is None:
# If no maximum output length is specified, omit the field to avoid truncated responses
request_dict["generationConfig"].pop("maxOutputTokens")
payload = {
"contents": request_dict.get("contents", []),
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": request_dict.get("generationConfig"),
"systemInstruction": request_dict.get("systemInstruction"),
}
if "maxOutputTokens" in request_dict["generationConfig"]:
request_dict["generationConfig"].pop("maxOutputTokens")
# Check whether this is a TTS model
is_tts_model = "tts" in model.lower()
if is_tts_model:
# TTS models use a simplified payload without tools or safetySettings
payload = {
"contents": _filter_empty_parts(request_dict.get("contents", [])),
"generationConfig": request_dict.get("generationConfig"),
}
# Add systemInstruction only when one is provided
if request_dict.get("systemInstruction"):
payload["systemInstruction"] = request_dict.get("systemInstruction")
else:
# Non-TTS models use the full payload
payload = {
"contents": _filter_empty_parts(request_dict.get("contents", [])),
"tools": _build_tools(model, request_dict),
"safetySettings": _get_safety_settings(model),
"generationConfig": request_dict.get("generationConfig"),
"systemInstruction": request_dict.get("systemInstruction"),
}
# Ensure generationConfig is not None
if payload["generationConfig"] is None:
payload["generationConfig"] = {}
if model.endswith("-image") or model.endswith("-image-generation"):
payload.pop("systemInstruction")
@@ -109,9 +238,18 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
else:
# The client did not provide a thinking config, so use the default
if model.endswith("-non-thinking"):
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
elif model in settings.THINKING_BUDGET_MAP:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
if "gemini-2.5-pro" in model:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
else:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
if settings.SHOW_THINKING_PROCESS:
payload["generationConfig"]["thinkingConfig"] = {
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000),
"includeThoughts": True
}
else:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
return payload
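
The precedence applied to `thinkingConfig` in `_build_payload` (client-supplied config first, then the `-non-thinking` alias, then the per-model budget map) can be summarised as a pure function. This is a sketch under assumptions: `pick_thinking_config` is a hypothetical name, the settings are passed in as plain arguments, and the budget is looked up by the real model name rather than by the raw model string.

```python
from typing import Any, Dict, Optional


def pick_thinking_config(
    model: str,
    real_model: str,
    client_config: Optional[Dict[str, Any]],
    budget_map: Dict[str, int],
    show_thinking: bool,
) -> Optional[Dict[str, Any]]:
    """Return the thinkingConfig to send, or None to leave it unset."""
    # 1. A client-supplied config always wins.
    if client_config is not None:
        return client_config
    # 2. "-non-thinking" aliases use the fixed budgets from the diff
    #    (128 for gemini-2.5-pro, otherwise 0).
    if model.endswith("-non-thinking"):
        budget = 128 if "gemini-2.5-pro" in model else 0
        return {"thinkingBudget": budget}
    # 3. Otherwise fall back to the configured per-model budget.
    if real_model in budget_map:
        config: Dict[str, Any] = {"thinkingBudget": budget_map[real_model]}
        if show_thinking:
            config["includeThoughts"] = True
        return config
    return None
```

Keeping the decision in one function makes it straightforward to unit-test each branch without constructing a full request.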
@@ -152,6 +290,17 @@ class GeminiChatService:
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""Generate content."""
# Check for file references and, if present, fetch the API key tied to the file
file_names = _extract_file_references(request.model_dump().get("contents", []))
if file_names:
logger.info(f"Request contains file references: {file_names}")
file_api_key = await get_file_api_key(file_names[0])
if file_api_key:
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
api_key = file_api_key # Use the file's API key
else:
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
payload = _build_payload(model, request)
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
@@ -195,10 +344,69 @@ class GeminiChatService:
request_time=request_datetime
)
async def count_tokens(
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""Count tokens."""
# The countTokens API only needs contents
payload = {"contents": _filter_empty_parts(request.model_dump().get("contents", []))}
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
is_success = False
status_code = None
response = None
try:
response = await self.api_client.count_tokens(payload, model, api_key)
is_success = True
status_code = 200
return response
except Exception as e:
is_success = False
error_log_msg = str(e)
logger.error(f"Count tokens API call failed with error: {error_log_msg}")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="gemini-count-tokens",
error_log=error_log_msg,
error_code=status_code,
request_msg=payload
)
raise e
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
)
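
The error path in `count_tokens` recovers the upstream HTTP status from the exception message, which works because the API client raises exceptions of the form `API call failed with status code <N>, ...`. A minimal sketch of that extraction, with `extract_status_code` as a hypothetical helper name:

```python
import re


def extract_status_code(message: str, default: int = 500) -> int:
    """Pull the HTTP status out of an error message, or fall back to a default."""
    match = re.search(r"status code (\d+)", message)
    return int(match.group(1)) if match else default
```

Messages that carry no status, such as a timeout, fall back to 500, matching the behaviour in the diff.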
async def stream_generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> AsyncGenerator[str, None]:
"""Generate content as a stream."""
# Check for file references and, if present, fetch the API key tied to the file
file_names = _extract_file_references(request.model_dump().get("contents", []))
if file_names:
logger.info(f"Request contains file references: {file_names}")
file_api_key = await get_file_api_key(file_names[0])
if file_api_key:
logger.info(f"Found API key for file {file_names[0]}: {file_api_key[:8]}...{file_api_key[-4:]}")
api_key = file_api_key # Use the file's API key
else:
logger.warning(f"No API key found for file {file_names[0]}, using default key: {api_key[:8]}...{api_key[-4:]}")
retries = 0
max_retries = settings.MAX_RETRIES
payload = _build_payload(model, request)

View File

@@ -26,16 +26,43 @@ from app.service.key.key_manager import KeyManager
logger = get_openai_logger()
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
"""Check whether the messages contain image, audio, or video parts (inline_data)."""
for content in contents:
if content and "parts" in content and isinstance(content["parts"], list):
for part in content["parts"]:
if isinstance(part, dict) and "inline_data" in part:
def _has_media_parts(messages: List[Dict[str, Any]]) -> bool:
"""Check whether the messages contain multimedia parts."""
for message in messages:
if "parts" in message:
for part in message["parts"]:
if "image_url" in part or "inline_data" in part:
return True
return False
def _clean_json_schema_properties(obj: Any) -> Any:
"""Remove JSON Schema fields that the Gemini API does not support."""
if not isinstance(obj, dict):
return obj
# JSON Schema fields not supported by the Gemini API
unsupported_fields = {
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
"contentEncoding", "contentMediaType", "if", "then", "else",
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
"$id", "$ref", "$comment", "readOnly", "writeOnly"
}
cleaned = {}
for key, value in obj.items():
if key in unsupported_fields:
continue
if isinstance(value, dict):
cleaned[key] = _clean_json_schema_properties(value)
elif isinstance(value, list):
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
else:
cleaned[key] = value
return cleaned
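
To show what `_clean_json_schema_properties` does to a function declaration, here is a self-contained copy with indentation restored, followed by a small usage example. The field set is copied from the diff; the public function name and the sample schema are made up for illustration.

```python
from typing import Any

# Field names copied from the diff.
UNSUPPORTED_FIELDS = {
    "exclusiveMaximum", "exclusiveMinimum", "const", "examples",
    "contentEncoding", "contentMediaType", "if", "then", "else",
    "allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
    "$id", "$ref", "$comment", "readOnly", "writeOnly",
}


def clean_json_schema_properties(obj: Any) -> Any:
    """Recursively drop unsupported JSON Schema keys from nested dicts and lists."""
    if not isinstance(obj, dict):
        return obj
    cleaned = {}
    for key, value in obj.items():
        if key in UNSUPPORTED_FIELDS:
            continue
        if isinstance(value, dict):
            cleaned[key] = clean_json_schema_properties(value)
        elif isinstance(value, list):
            cleaned[key] = [clean_json_schema_properties(item) for item in value]
        else:
            cleaned[key] = value
    return cleaned


declaration = {
    "name": "get_price",
    "parameters": {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {
            "amount": {"type": "integer", "exclusiveMinimum": 0, "examples": [1]},
        },
    },
}
cleaned_declaration = clean_json_schema_properties(declaration)
```

Because the filter is applied to every dict key, a parameter that happens to be named like one of these fields (for example a property called `examples`) would also be removed, which may or may not be intended.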
def _build_tools(
request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
@@ -60,6 +87,10 @@ def _build_tools(
if model.endswith("-search"):
tool["googleSearch"] = {}
real_model = _get_real_model(model)
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
tool["urlContext"] = {}
# Merge the tools from the request into tools
if request.tools:
@@ -76,6 +107,8 @@ def _build_tools(
):
function.pop("parameters", None)
# Strip unsupported fields from the function
function = _clean_json_schema_properties(function)
function_declarations.append(function)
if function_declarations:
@@ -97,10 +130,23 @@ def _build_tools(
if tool.get("functionDeclarations"):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
tool.pop("urlContext",None)
return [tool] if tool else []
def _get_real_model(model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
if model.endswith("-non-thinking"):
model = model[:-13]
if "-search" in model and "-non-thinking" in model:
model = model[:-20]
return model
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""Get the safety settings."""
# if (
@@ -113,6 +159,23 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
return settings.SAFETY_SETTINGS
def _validate_and_set_max_tokens(
payload: Dict[str, Any],
max_tokens: Optional[int],
logger_instance
) -> None:
"""Validate and set the max_tokens parameter."""
if max_tokens is None:
return
# Validate and apply the parameter
if max_tokens <= 0:
logger_instance.warning(f"Invalid max_tokens value: {max_tokens}, will not set maxOutputTokens")
# Leave maxOutputTokens unset so the Gemini API uses its default
else:
payload["generationConfig"]["maxOutputTokens"] = max_tokens
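
The validation above maps an OpenAI-style `max_tokens` onto Gemini's `maxOutputTokens` and silently skips non-positive values. A self-contained sketch of the same rule; `set_max_output_tokens` is a hypothetical name, and using `setdefault` to create `generationConfig` when it is missing is an assumption that goes beyond the diff.

```python
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def set_max_output_tokens(payload: Dict[str, Any], max_tokens: Optional[int]) -> None:
    """Copy a positive max_tokens into generationConfig.maxOutputTokens."""
    if max_tokens is None:
        return
    if max_tokens <= 0:
        # Leave the field unset so the upstream API applies its own default.
        logger.warning("Invalid max_tokens value: %s, leaving maxOutputTokens unset", max_tokens)
        return
    payload.setdefault("generationConfig", {})["maxOutputTokens"] = max_tokens
```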
def _build_payload(
request: ChatRequest,
messages: List[Dict[str, Any]],
@@ -130,16 +193,27 @@ def _build_payload(
"tools": _build_tools(request, messages),
"safetySettings": _get_safety_settings(request.model),
}
if request.max_tokens is not None:
payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
# Handle the max_tokens parameter
_validate_and_set_max_tokens(payload, request.max_tokens, logger)
if request.model.endswith("-image") or request.model.endswith("-image-generation"):
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if request.model.endswith("-non-thinking"):
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
if request.model in settings.THINKING_BUDGET_MAP:
payload["generationConfig"]["thinkingConfig"] = {
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
}
if "gemini-2.5-pro" in request.model:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
else:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
if _get_real_model(request.model) in settings.THINKING_BUDGET_MAP:
if settings.SHOW_THINKING_PROCESS:
payload["generationConfig"]["thinkingConfig"] = {
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000),
"includeThoughts": True
}
else:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)}
if (
instruction
@@ -206,27 +280,53 @@ class OpenAIChatService:
is_success = False
status_code = None
response = None
try:
response = await self.api_client.generate_content(payload, model, api_key)
usage_metadata = response.get("usageMetadata", {})
is_success = True
status_code = 200
return self.response_handler.handle_response(
response,
model,
stream=False,
finish_reason="stop",
usage_metadata=usage_metadata,
)
# Try to process the response, catching any response-handling exception
try:
result = self.response_handler.handle_response(
response,
model,
stream=False,
finish_reason="stop",
usage_metadata=usage_metadata,
)
return result
except Exception as response_error:
logger.error(f"Response processing failed for model {model}: {str(response_error)}")
# Log detailed error information
if "parts" in str(response_error):
logger.error("Response structure issue - missing or invalid parts")
if response.get("candidates"):
candidate = response["candidates"][0]
content = candidate.get("content", {})
logger.error(f"Content structure: {content}")
# Re-raise the exception
raise response_error
except Exception as e:
is_success = False
error_log_msg = str(e)
logger.error(f"Normal API call failed with error: {error_log_msg}")
logger.error(f"API call failed for model {model}: {error_log_msg}")
# Log max_tokens-related errors specifically
gen_config = payload.get('generationConfig', {})
if "maxOutputTokens" in gen_config:
logger.error(f"Request had maxOutputTokens: {gen_config['maxOutputTokens']}")
# If this is a response-processing error, log more information
if "parts" in error_log_msg:
logger.error("This is likely a response processing error")
match = re.search(r"status code (\d+)", error_log_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
status_code = int(match.group(1)) if match else 500
await add_error_log(
gemini_key=api_key,
@@ -240,6 +340,8 @@ class OpenAIChatService:
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
logger.info(f"Normal completion finished - Success: {is_success}, Latency: {latency_ms}ms")
await add_request_log(
model_name=model,
api_key=api_key,

View File

@@ -28,9 +28,51 @@ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
return False
def _clean_json_schema_properties(obj: Any) -> Any:
"""Remove JSON Schema fields that the Gemini API does not support."""
if not isinstance(obj, dict):
return obj
# JSON Schema fields not supported by the Gemini API
unsupported_fields = {
"exclusiveMaximum", "exclusiveMinimum", "const", "examples",
"contentEncoding", "contentMediaType", "if", "then", "else",
"allOf", "anyOf", "oneOf", "not", "definitions", "$schema",
"$id", "$ref", "$comment", "readOnly", "writeOnly"
}
cleaned = {}
for key, value in obj.items():
if key in unsupported_fields:
continue
if isinstance(value, dict):
cleaned[key] = _clean_json_schema_properties(value)
elif isinstance(value, list):
cleaned[key] = [_clean_json_schema_properties(item) for item in value]
else:
cleaned[key] = value
return cleaned
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Build the tools list."""
def _has_function_call(contents: List[Dict[str, Any]]) -> bool:
"""Check whether the contents contain a functionCall."""
if not contents or not isinstance(contents, list):
return False
for content in contents:
if not content or not isinstance(content, dict) or "parts" not in content:
continue
parts = content.get("parts", [])
if not parts or not isinstance(parts, list):
continue
for part in parts:
if isinstance(part, dict) and "functionCall" in part:
return True
return False
def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
record = dict()
for item in tools:
@@ -40,7 +82,15 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
for k, v in item.items():
if k == "functionDeclarations" and v and isinstance(v, list):
functions = record.get("functionDeclarations", [])
functions.extend(v)
# Strip unsupported fields from each function declaration
cleaned_functions = []
for func in v:
if isinstance(func, dict):
cleaned_func = _clean_json_schema_properties(func)
cleaned_functions.append(cleaned_func)
else:
cleaned_functions.append(func)
functions.extend(cleaned_functions)
record["functionDeclarations"] = functions
else:
record[k] = v
@@ -62,15 +112,32 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
tool["codeExecution"] = {}
if model.endswith("-search"):
tool["googleSearch"] = {}
real_model = _get_real_model(model)
if real_model in settings.URL_CONTEXT_MODELS and settings.URL_CONTEXT_ENABLED:
tool["urlContext"] = {}
# Work around the "Tool use with function calling is unsupported" error
if tool.get("functionDeclarations"):
if tool.get("functionDeclarations") or _has_function_call(payload.get("contents", [])):
tool.pop("googleSearch", None)
tool.pop("codeExecution", None)
tool.pop("urlContext", None)
return [tool] if tool else []
def _get_real_model(model: str) -> str:
if model.endswith("-search"):
model = model[:-7]
if model.endswith("-image"):
model = model[:-6]
if model.endswith("-non-thinking"):
model = model[:-13]
if "-search" in model and "-non-thinking" in model:
model = model[:-20]
return model
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
"""Get the safety settings."""
if model == "gemini-2.0-flash-exp":
@@ -80,7 +147,7 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
"""Build the request payload."""
request_dict = request.model_dump()
request_dict = request.model_dump(exclude_none=False)
if request.generationConfig:
if request.generationConfig.maxOutputTokens is None:
# If no maximum output length is specified, omit the field to avoid truncated responses
@@ -98,10 +165,29 @@ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
payload.pop("systemInstruction")
payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
if model.endswith("-non-thinking"):
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
if model in settings.THINKING_BUDGET_MAP:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
# Thinking config: prefer the client-provided config, otherwise fall back to the default
client_thinking_config = None
if request.generationConfig and request.generationConfig.thinkingConfig:
client_thinking_config = request.generationConfig.thinkingConfig
if client_thinking_config is not None:
# The client provided a thinking config, so use it as-is
payload["generationConfig"]["thinkingConfig"] = client_thinking_config
else:
# The client did not provide a thinking config, so use the default
if model.endswith("-non-thinking"):
if "gemini-2.5-pro" in model:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 128}
else:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
elif _get_real_model(model) in settings.THINKING_BUDGET_MAP:
if settings.SHOW_THINKING_PROCESS:
payload["generationConfig"]["thinkingConfig"] = {
"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000),
"includeThoughts": True
}
else:
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
return payload

View File

@@ -40,6 +40,13 @@ class GeminiApiClient(ApiClient):
model = model[:-20]
return model
def _prepare_headers(self) -> Dict[str, str]:
headers = {}
if settings.CUSTOM_HEADERS:
headers.update(settings.CUSTOM_HEADERS)
logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
return headers
async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Get the list of available Gemini models."""
timeout = httpx.Timeout(timeout=5)
@@ -52,10 +59,11 @@ class GeminiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models?key={api_key}&pageSize=1000"
try:
response = await client.get(url)
response = await client.get(url, headers=headers)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
@@ -69,7 +77,7 @@ class GeminiApiClient(ApiClient):
async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
@@ -78,13 +86,36 @@ class GeminiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
response = await client.post(url, json=payload)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
try:
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
logger.error(f"API call failed - Status: {response.status_code}, Content: {error_content}")
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
response_data = response.json()
# Basic sanity check on the response structure
if not response_data.get("candidates"):
logger.warning("No candidates found in API response")
return response_data
except httpx.TimeoutException as e:
logger.error(f"Request timeout: {e}")
raise Exception(f"Request timeout: {e}")
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
raise Exception(f"Request error: {e}")
except Exception as e:
logger.error(f"Unexpected error: {e}")
raise
async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
@@ -98,9 +129,10 @@ class GeminiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream(method="POST", url=url, json=payload) as response:
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
if response.status_code != 200:
error_content = await response.aread()
error_msg = error_content.decode("utf-8")
@@ -108,6 +140,27 @@ class GeminiApiClient(ApiClient):
async for line in response.aiter_lines():
yield line
async def count_tokens(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
model = self._get_real_model(model)
proxy_to_use = None
if settings.PROXIES:
if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
else:
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for counting tokens: {proxy_to_use}")
headers = self._prepare_headers()
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/models/{model}:countTokens?key={api_key}"
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
return response.json()
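
One point worth noting about `hash(api_key) % len(settings.PROXIES)`: Python salts the built-in `hash()` of strings per process unless `PYTHONHASHSEED` is fixed, so the key-to-proxy mapping is stable within one process but can change after a restart. If a mapping that survives restarts is wanted, a digest-based index is one option. The sketch below is an alternative technique, not the code in this commit, and `pick_proxy` is a hypothetical name.

```python
import hashlib
from typing import List


def pick_proxy(api_key: str, proxies: List[str]) -> str:
    """Map an API key to a proxy deterministically across process restarts."""
    if not proxies:
        raise ValueError("proxies must not be empty")
    digest = hashlib.sha256(api_key.encode("utf-8")).digest()
    # Use the first 8 bytes of the digest as an unsigned integer index.
    index = int.from_bytes(digest[:8], "big") % len(proxies)
    return proxies[index]
```

The mapping still changes when the proxy list changes length, so this is plain modulo hashing rather than consistent hashing in the strict sense.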
class OpenaiApiClient(ApiClient):
"""OpenAI API client."""
@@ -116,6 +169,13 @@ class OpenaiApiClient(ApiClient):
self.base_url = base_url
self.timeout = timeout
def _prepare_headers(self, api_key: str) -> Dict[str, str]:
headers = {"Authorization": f"Bearer {api_key}"}
if settings.CUSTOM_HEADERS:
headers.update(settings.CUSTOM_HEADERS)
logger.info(f"Using custom headers: {settings.CUSTOM_HEADERS}")
return headers
async def get_models(self, api_key: str) -> Dict[str, Any]:
timeout = httpx.Timeout(self.timeout, read=self.timeout)
@@ -127,9 +187,9 @@ class OpenaiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers(api_key)
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/models"
headers = {"Authorization": f"Bearer {api_key}"}
response = await client.get(url, headers=headers)
if response.status_code != 200:
error_content = response.text
@@ -147,9 +207,9 @@ class OpenaiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers(api_key)
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/chat/completions"
headers = {"Authorization": f"Bearer {api_key}"}
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text
@@ -166,9 +226,9 @@ class OpenaiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers(api_key)
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/chat/completions"
headers = {"Authorization": f"Bearer {api_key}"}
async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
if response.status_code != 200:
error_content = await response.aread()
@@ -188,9 +248,9 @@ class OpenaiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers(api_key)
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/embeddings"
headers = {"Authorization": f"Bearer {api_key}"}
payload = {
"input": input,
"model": model,
@@ -212,9 +272,9 @@ class OpenaiApiClient(ApiClient):
proxy_to_use = random.choice(settings.PROXIES)
logger.info(f"Using proxy for getting models: {proxy_to_use}")
headers = self._prepare_headers(api_key)
async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
url = f"{self.base_url}/openai/images/generations"
headers = {"Authorization": f"Bearer {api_key}"}
response = await client.post(url, json=payload, headers=headers)
if response.status_code != 200:
error_content = response.text

View File

@@ -0,0 +1 @@
# Intentionally empty __init__.py file

View File

@@ -0,0 +1,247 @@
"""
File upload handler.
Handles Google's resumable upload protocol.
"""
from typing import Optional
from datetime import datetime, timezone, timedelta
from httpx import AsyncClient
from fastapi import Request, Response, HTTPException
from app.config.config import settings
from app.database import services as db_services
from app.database.models import FileState
from app.log.logger import get_files_logger
logger = get_files_logger()
class FileUploadHandler:
"""Handles chunked file uploads."""
def __init__(self):
self.chunk_size = 8 * 1024 * 1024 # 8MB
async def handle_upload_chunk(
self,
upload_url: str,
request: Request,
files_service=None # Files service instance
) -> Response:
"""
Handle an upload chunk.
Args:
upload_url: Upload URL
request: FastAPI request object
files_service: Files service instance
Returns:
Response: The response object
"""
try:
# Collect the request headers
headers = {}
# Copy the required upload headers
upload_headers = [
"x-goog-upload-command",
"x-goog-upload-offset",
"content-type",
"content-length"
]
for header in upload_headers:
if header in request.headers:
# Convert to the canonical capitalization
key = "-".join(word.capitalize() for word in header.split("-"))
headers[key] = request.headers[header]
# Read the request body
body = await request.body()
# Check whether this is the final chunk
is_final = "finalize" in headers.get("X-Goog-Upload-Command", "")
logger.debug(f"Upload command: {headers.get('X-Goog-Upload-Command', '')}, is_final: {is_final}")
# Forward to the real upload URL
async with AsyncClient() as client:
response = await client.post(
upload_url,
headers=headers,
content=body,
timeout=300.0 # 5-minute timeout
)
if response.status_code not in [200, 201, 308]:
logger.error(f"Upload chunk failed: {response.status_code} - {response.text}")
raise HTTPException(status_code=response.status_code, detail="Upload failed")
# If this is the final chunk, update the file state
if is_final and response.status_code in [200, 201]:
logger.debug(f"Upload finalized with status {response.status_code}")
try:
# Parse the response to get the file information
response_data = response.json()
logger.debug(f"Upload complete response data: {response_data}")
file_data = response_data.get("file", {})
# Get the real file name
real_file_name = file_data.get("name")
logger.debug(f"Upload response: {response_data}")
if real_file_name and files_service:
logger.info(f"Upload completed, file name: {real_file_name}")
# Get the information from the session
session_info = await files_service.get_upload_session(upload_url)
logger.debug(f"Retrieved session info for {upload_url}: {session_info}")
if session_info:
# Create the file record
now = datetime.now(timezone.utc)
expiration_time = now + timedelta(hours=48)
# Handle the expiration time format (Google may return nanosecond precision)
expiration_time_str = file_data.get("expirationTime", expiration_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"))
# Handle the nanosecond format: 2025-07-11T02:02:52.531916141Z -> 2025-07-11T02:02:52.531916Z
if expiration_time_str.endswith("Z"):
# Remove the trailing Z
expiration_time_str = expiration_time_str[:-1]
# If there are nanoseconds (more than 6 decimal places), truncate to microseconds
if "." in expiration_time_str:
date_part, frac_part = expiration_time_str.rsplit(".", 1)
if len(frac_part) > 6:
frac_part = frac_part[:6]
expiration_time_str = f"{date_part}.{frac_part}"
# Append the time zone
expiration_time_str += "+00:00"
# Get the file state (Google may return PROCESSING)
file_state = file_data.get("state", "PROCESSING")
logger.debug(f"File state from Google: {file_state}")
# Convert the string state to the enum
if file_state == "ACTIVE":
state_enum = FileState.ACTIVE
elif file_state == "PROCESSING":
state_enum = FileState.PROCESSING
elif file_state == "FAILED":
state_enum = FileState.FAILED
else:
logger.warning(f"Unknown file state: {file_state}, defaulting to PROCESSING")
state_enum = FileState.PROCESSING
await db_services.create_file_record(
name=real_file_name,
mime_type=file_data.get("mimeType", session_info["mime_type"]),
size_bytes=int(file_data.get("sizeBytes", session_info["size_bytes"])),
api_key=session_info["api_key"],
uri=file_data.get("uri", f"{settings.BASE_URL}/{real_file_name}"),
create_time=now,
update_time=now,
expiration_time=datetime.fromisoformat(expiration_time_str),
state=state_enum,
display_name=file_data.get("displayName", session_info.get("display_name", "")),
sha256_hash=file_data.get("sha256Hash"),
user_token=session_info["user_token"]
)
logger.info(f"Created file record: name={real_file_name}, api_key={session_info['api_key'][:8]}...{session_info['api_key'][-4:]}")
else:
logger.warning(f"No upload session found for URL: {upload_url}")
else:
logger.warning(f"Missing real_file_name or files_service: real_file_name={real_file_name}, files_service={files_service}")
# Return the full file information
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers)
)
except Exception as e:
logger.error(f"Failed to create file record: {str(e)}", exc_info=True)
else:
logger.debug(f"Upload chunk processed: is_final={is_final}, status={response.status_code}")
# Return the response
response_headers = dict(response.headers)
# Make sure the required headers are present
if response.status_code == 308: # Resume Incomplete
if "x-goog-upload-status" not in response_headers:
response_headers["x-goog-upload-status"] = "active"
return Response(
content=response.content,
status_code=response.status_code,
headers=response_headers
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to handle upload chunk: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
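
The expiration-time handling in `handle_upload_chunk` exists because `datetime.fromisoformat` does not accept nanosecond fractions, and on Python versions before 3.11 it also rejects a trailing `Z`. The same normalization can be written as a standalone helper. This is a sketch; `parse_utc_timestamp` is a hypothetical name, and padding short fractions to six digits is an extra step added here for older Python versions.

```python
from datetime import datetime


def parse_utc_timestamp(value: str) -> datetime:
    """Parse a 'Z'-suffixed timestamp, truncating the fraction to microseconds."""
    if value.endswith("Z"):
        value = value[:-1]
        if "." in value:
            date_part, frac_part = value.rsplit(".", 1)
            # Keep at most 6 digits and pad shorter fractions to exactly 6.
            value = f"{date_part}.{frac_part[:6].ljust(6, '0')}"
        # Replace the stripped "Z" with an explicit UTC offset.
        value += "+00:00"
    return datetime.fromisoformat(value)
```

A nanosecond timestamp such as `2025-07-11T02:02:52.531916141Z` comes back as a timezone-aware datetime with `microsecond == 531916`.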
async def proxy_upload_request(
self,
request: Request,
upload_url: str,
files_service=None
) -> Response:
"""
Proxy an upload request.
Args:
request: FastAPI request object
upload_url: Target upload URL
files_service: Files service instance
Returns:
Response: The proxied response
"""
logger.debug(f"Proxy upload request: {request.method}, {upload_url}")
try:
# For a GET request, return the upload status
if request.method == "GET":
return await self._get_upload_status(upload_url)
# Handle POST/PUT requests
return await self.handle_upload_chunk(upload_url, request, files_service)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to proxy upload request: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
async def _get_upload_status(self, upload_url: str) -> Response:
"""
Get the upload status.
Args:
upload_url: Upload URL
Returns:
Response: The status response
"""
try:
async with AsyncClient() as client:
response = await client.get(upload_url)
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers)
)
except Exception as e:
logger.error(f"Failed to get upload status: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
# Singleton instance
_upload_handler_instance: Optional[FileUploadHandler] = None
def get_upload_handler() -> FileUploadHandler:
"""Get the upload handler singleton."""
global _upload_handler_instance
if _upload_handler_instance is None:
_upload_handler_instance = FileUploadHandler()
return _upload_handler_instance

View File

@@ -0,0 +1,498 @@
"""
File management service
"""
import json
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple
from httpx import AsyncClient
import asyncio
from app.config.config import settings
from app.database import services as db_services
from app.database.models import FileState
from app.domain.file_models import FileMetadata, ListFilesResponse
from fastapi import HTTPException
from app.log.logger import get_files_logger
from app.service.client.api_client import GeminiApiClient
from app.service.key.key_manager import get_key_manager_instance
logger = get_files_logger()
# Global upload session store
_upload_sessions: Dict[str, Dict[str, Any]] = {}
_upload_sessions_lock = asyncio.Lock()
class FilesService:
"""File management service class."""
def __init__(self):
self.api_client = GeminiApiClient(base_url=settings.BASE_URL)
self.key_manager = None
async def _get_key_manager(self):
"""Get the KeyManager instance."""
if not self.key_manager:
self.key_manager = await get_key_manager_instance(
settings.API_KEYS,
settings.VERTEX_API_KEYS
)
return self.key_manager
async def initialize_upload(
self,
headers: Dict[str, str],
body: Optional[bytes],
user_token: str,
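request_host: Optional[str] = None  # Host of the incoming request
) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
Initialize a file upload.
Args:
headers: Request headers
body: Request body
user_token: User token
request_host: Host of the incoming request; when set, the returned upload URL points at this proxy
Returns:
Tuple[Dict[str, Any], Dict[str, str]]: (response body, response headers)
"""
try:
# Get an available API key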
key_manager = await self._get_key_manager()
api_key = await key_manager.get_next_key()
if not api_key:
raise HTTPException(status_code=503, detail="No available API keys")
# Forward the request to the real Gemini API
async with AsyncClient() as client:
# Prepare the request headers
forward_headers = {
"X-Goog-Upload-Protocol": headers.get("x-goog-upload-protocol", "resumable"),
"X-Goog-Upload-Command": headers.get("x-goog-upload-command", "start"),
"Content-Type": headers.get("content-type", "application/json"),
}
# Add other required headers
if "x-goog-upload-header-content-length" in headers:
forward_headers["X-Goog-Upload-Header-Content-Length"] = headers["x-goog-upload-header-content-length"]
if "x-goog-upload-header-content-type" in headers:
forward_headers["X-Goog-Upload-Header-Content-Type"] = headers["x-goog-upload-header-content-type"]
# Send the request
response = await client.post(
"https://generativelanguage.googleapis.com/upload/v1beta/files",
headers=forward_headers,
content=body,
params={"key": api_key}
)
if response.status_code != 200:
logger.error(f"Upload initialization failed: {response.status_code} - {response.text}")
raise HTTPException(status_code=response.status_code, detail="Upload initialization failed")
# Get the upload URL
upload_url = response.headers.get("x-goog-upload-url")
if not upload_url:
raise HTTPException(status_code=500, detail="No upload URL in response")
logger.info(f"Original upload URL from Google: {upload_url}")
# Keep the upload details for later use.
# The database record is not created here; it is created once the upload completes.
logger.info(f"Upload initialized with API key: {api_key[:8]}...{api_key[-4:]}")
# Parse the response; the initialization response may be empty
response_data = {}
# Parse file information from the request body, if any
display_name = ""
if body:
try:
request_data = json.loads(body)
display_name = request_data.get("displayName", "")
except Exception:
pass
# Extract upload_id from the upload URL
import urllib.parse
parsed_url = urllib.parse.urlparse(upload_url)
query_params = urllib.parse.parse_qs(parsed_url.query)
upload_id = query_params.get('upload_id', [None])[0]
if upload_id:
# Store the upload session, keyed by upload_id
async with _upload_sessions_lock:
_upload_sessions[upload_id] = {
"api_key": api_key,
"user_token": user_token,
"display_name": display_name,
"mime_type": headers.get("x-goog-upload-header-content-type", "application/octet-stream"),
"size_bytes": int(headers.get("x-goog-upload-header-content-length", "0")),
"created_at": datetime.now(timezone.utc),
"upload_url": upload_url
}
logger.info(f"Stored upload session for upload_id={upload_id}: api_key={api_key[:8]}...{api_key[-4:]}")
logger.debug(f"Total active sessions: {len(_upload_sessions)}")
else:
logger.warning(f"No upload_id found in upload URL: {upload_url}")
# Clean up expired sessions (older than 1 hour) in the background
asyncio.create_task(self._cleanup_expired_sessions())
# Replace Google's URL with our proxy URL
proxy_upload_url = upload_url
if request_host:
# Original: https://generativelanguage.googleapis.com/upload/v1beta/files?key=AIzaSyDc...&upload_id=xxx&upload_protocol=resumable
# Replaced with: http://request-host/upload/v1beta/files?key=sk-123456&upload_id=xxx&upload_protocol=resumable
# First replace the domain
proxy_upload_url = upload_url.replace(
"https://generativelanguage.googleapis.com",
request_host.rstrip('/')
)
# Then replace the key parameter
import re
# Match the key=xxx parameter
key_pattern = r'(\?|&)key=([^&]+)'
match = re.search(key_pattern, proxy_upload_url)
if match:
# Replace it with our token
proxy_upload_url = proxy_upload_url.replace(
f"{match.group(1)}key={match.group(2)}",
f"{match.group(1)}key={user_token}"
)
logger.info(f"Replaced upload URL: {upload_url} -> {proxy_upload_url}")
return response_data, {
"X-Goog-Upload-URL": proxy_upload_url,
"X-Goog-Upload-Status": "active"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to initialize upload: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
async def _cleanup_expired_sessions(self):
"""Clean up expired upload sessions."""
try:
async with _upload_sessions_lock:
now = datetime.now(timezone.utc)
expired_keys = []
for key, session in _upload_sessions.items():
if now - session["created_at"] > timedelta(hours=1):
expired_keys.append(key)
for key in expired_keys:
del _upload_sessions[key]
if expired_keys:
logger.info(f"Cleaned up {len(expired_keys)} expired upload sessions")
except Exception as e:
logger.error(f"Error cleaning up upload sessions: {str(e)}")
async def get_upload_session(self, key: str) -> Optional[Dict[str, Any]]:
"""Get upload session info (accepts an upload_id or a full URL)."""
async with _upload_sessions_lock:
# Try a direct lookup first
session = _upload_sessions.get(key)
if session:
logger.debug(f"Found session by direct key {key}")
return session
# If the key is a URL, try to extract the upload_id
if key.startswith("http"):
import urllib.parse
parsed_url = urllib.parse.urlparse(key)
query_params = urllib.parse.parse_qs(parsed_url.query)
upload_id = query_params.get('upload_id', [None])[0]
if upload_id:
session = _upload_sessions.get(upload_id)
if session:
logger.debug(f"Found session by upload_id {upload_id} from URL")
return session
logger.debug(f"No session found for key: {key}")
return None
async def get_file(self, file_name: str, user_token: str) -> FileMetadata:
"""
Get file information.
Args:
file_name: File name (format: files/{file_id})
user_token: User token
Returns:
FileMetadata: File metadata
"""
try:
# Look up the file record
file_record = await db_services.get_file_record_by_name(file_name)
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
# Check whether the file has expired
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
# Treat a naive datetime as UTC
if expiration_time.tzinfo is None:
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
if expiration_time <= datetime.now(timezone.utc):
raise HTTPException(status_code=404, detail="File has expired")
# Fetch the file information with the original API key
api_key = file_record["api_key"]
async with AsyncClient() as client:
response = await client.get(
f"{settings.BASE_URL}/{file_name}",
params={"key": api_key}
)
if response.status_code != 200:
logger.error(f"Failed to get file: {response.status_code} - {response.text}")
raise HTTPException(status_code=response.status_code, detail="Failed to get file")
file_data = response.json()
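# Check the file state and update it if it changed
google_state = file_data.get("state", "PROCESSING")
# Resolve the stored state first (None when the record has no state)
stored_state = file_record["state"].value if file_record.get("state") else None
if google_state != stored_state:
logger.info(f"File state changed from {file_record.get('state')} to {google_state}")
# Update the state in the database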
if google_state == "ACTIVE":
await db_services.update_file_record_state(
file_name=file_name,
state=FileState.ACTIVE,
update_time=datetime.now(timezone.utc)
)
elif google_state == "FAILED":
await db_services.update_file_record_state(
file_name=file_name,
state=FileState.FAILED,
update_time=datetime.now(timezone.utc)
)
# Build the response
return FileMetadata(
name=file_data["name"],
displayName=file_data.get("displayName"),
mimeType=file_data["mimeType"],
sizeBytes=str(file_data["sizeBytes"]),
createTime=file_data["createTime"],
updateTime=file_data["updateTime"],
expirationTime=file_data["expirationTime"],
sha256Hash=file_data.get("sha256Hash"),
uri=file_data["uri"],
state=google_state
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get file {file_name}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
async def list_files(
self,
page_size: int = 10,
page_token: Optional[str] = None,
user_token: Optional[str] = None
) -> ListFilesResponse:
"""
List files.
Args:
page_size: Page size
page_token: Pagination token
user_token: User token (optional; when provided, only that user's files are returned)
Returns:
ListFilesResponse: File list response
"""
try:
logger.debug(f"list_files called with page_size={page_size}, page_token={page_token}")
# Fetch the file list from the database
files, next_page_token = await db_services.list_file_records(
user_token=user_token,
page_size=page_size,
page_token=page_token
)
logger.debug(f"Database returned {len(files)} files, next_page_token={next_page_token}")
# Convert to the response format
file_list = []
for file_record in files:
file_list.append(FileMetadata(
name=file_record["name"],
displayName=file_record.get("display_name"),
mimeType=file_record["mime_type"],
sizeBytes=str(file_record["size_bytes"]),
createTime=file_record["create_time"].isoformat() + "Z",
updateTime=file_record["update_time"].isoformat() + "Z",
expirationTime=file_record["expiration_time"].isoformat() + "Z",
sha256Hash=file_record.get("sha256_hash"),
uri=file_record["uri"],
state=file_record["state"].value if file_record.get("state") else "ACTIVE"
))
response = ListFilesResponse(
files=file_list,
nextPageToken=next_page_token
)
logger.debug(f"Returning response with {len(response.files)} files, nextPageToken={response.nextPageToken}")
return response
except Exception as e:
logger.error(f"Failed to list files: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
async def delete_file(self, file_name: str, user_token: str) -> bool:
"""
Delete a file.
Args:
file_name: File name
user_token: User token
Returns:
bool: Whether the deletion succeeded
"""
try:
# Look up the file record
file_record = await db_services.get_file_record_by_name(file_name)
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
# Delete the file with the original API key
api_key = file_record["api_key"]
async with AsyncClient() as client:
response = await client.delete(
f"{settings.BASE_URL}/{file_name}",
params={"key": api_key}
)
if response.status_code not in [200, 204]:
logger.error(f"Failed to delete file: {response.status_code} - {response.text}")
# If the API deletion failed but the file has expired, still delete the database record
expiration_time = datetime.fromisoformat(str(file_record["expiration_time"]))
if expiration_time.tzinfo is None:
expiration_time = expiration_time.replace(tzinfo=timezone.utc)
if expiration_time <= datetime.now(timezone.utc):
await db_services.delete_file_record(file_name)
return True
raise HTTPException(status_code=response.status_code, detail="Failed to delete file")
# Delete the database record
await db_services.delete_file_record(file_name)
return True
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete file {file_name}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
async def check_file_state(self, file_name: str, api_key: str) -> str:
"""
Check the file state and update it.
Args:
file_name: File name
api_key: API key
Returns:
str: Current state
"""
try:
async with AsyncClient() as client:
response = await client.get(
f"{settings.BASE_URL}/{file_name}",
params={"key": api_key}
)
if response.status_code != 200:
logger.error(f"Failed to check file state: {response.status_code}")
return "UNKNOWN"
file_data = response.json()
google_state = file_data.get("state", "PROCESSING")
# Update the state in the database
if google_state == "ACTIVE":
await db_services.update_file_record_state(
file_name=file_name,
state=FileState.ACTIVE,
update_time=datetime.now(timezone.utc)
)
elif google_state == "FAILED":
await db_services.update_file_record_state(
file_name=file_name,
state=FileState.FAILED,
update_time=datetime.now(timezone.utc)
)
return google_state
except Exception as e:
logger.error(f"Failed to check file state: {str(e)}")
return "UNKNOWN"
async def cleanup_expired_files(self) -> int:
"""
Clean up expired files.
Returns:
int: Number of files cleaned up
"""
try:
# Fetch the expired files
expired_files = await db_services.delete_expired_file_records()
if not expired_files:
return 0
# Try to delete the files from the Gemini API
for file_record in expired_files:
try:
api_key = file_record["api_key"]
file_name = file_record["name"]
async with AsyncClient() as client:
await client.delete(
f"{settings.BASE_URL}/{file_name}",
params={"key": api_key}
)
except Exception as e:
# Log the error and continue with the remaining files
logger.error(f"Failed to delete file {file_record['name']} from API: {str(e)}")
return len(expired_files)
except Exception as e:
logger.error(f"Failed to cleanup expired files: {str(e)}")
return 0
# Singleton instance
_files_service_instance: Optional[FilesService] = None
async def get_files_service() -> FilesService:
"""Get the singleton file service instance."""
global _files_service_instance
if _files_service_instance is None:
_files_service_instance = FilesService()
return _files_service_instance

View File

@@ -121,6 +121,7 @@ class ImageCreateService:
provider=settings.UPLOAD_PROVIDER,
base_url=settings.CLOUDFLARE_IMGBED_URL,
auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
upload_folder=settings.CLOUDFLARE_IMGBED_UPLOAD_FOLDER,
)
else:
raise ValueError(

View File

@@ -1,6 +1,6 @@
import asyncio
from itertools import cycle
from typing import Dict
from typing import Dict, Union
from app.config.config import settings
from app.log.logger import get_key_manager_logger
@@ -34,7 +34,7 @@ class KeyManager:
return next(self.key_cycle)
async def get_next_vertex_key(self) -> str:
"""Get the next Vertex API key"""
"""Get the next Vertex Express API key"""
async with self.vertex_key_cycle_lock:
return next(self.vertex_key_cycle)
@@ -98,7 +98,7 @@ class KeyManager:
return current_key
async def get_next_working_vertex_key(self) -> str:
"""Get the next working Vertex API key"""
"""Get the next working Vertex Express API key"""
initial_key = await self.get_next_vertex_key()
current_key = initial_key
@@ -124,12 +124,12 @@ class KeyManager:
return ""
async def handle_vertex_api_failure(self, api_key: str, retries: int) -> str:
"""Handle a failed Vertex API call"""
"""Handle a failed Vertex Express API call"""
async with self.vertex_failure_count_lock:
self.vertex_key_failure_counts[api_key] += 1
if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
logger.warning(
f"Vertex API key {api_key} has failed {self.MAX_FAILURES} times"
f"Vertex Express API key {api_key} has failed {self.MAX_FAILURES} times"
)
def get_fail_count(self, key: str) -> int:
@@ -156,7 +156,7 @@ class KeyManager:
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
async def get_vertex_keys_by_status(self) -> dict:
"""Get the Vertex API keys grouped by status, including failure counts"""
"""Get the Vertex Express API keys grouped by status, including failure counts"""
valid_keys = {}
invalid_keys = {}
@@ -185,12 +185,12 @@ class KeyManager:
_singleton_instance = None
_singleton_lock = asyncio.Lock()
_preserved_failure_counts: Dict[str, int] | None = None
_preserved_vertex_failure_counts: Dict[str, int] | None = None
_preserved_old_api_keys_for_reset: list | None = None
_preserved_vertex_old_api_keys_for_reset: list | None = None
_preserved_next_key_in_cycle: str | None = None
_preserved_vertex_next_key_in_cycle: str | None = None
_preserved_failure_counts: Union[Dict[str, int], None] = None
_preserved_vertex_failure_counts: Union[Dict[str, int], None] = None
_preserved_old_api_keys_for_reset: Union[list, None] = None
_preserved_vertex_old_api_keys_for_reset: Union[list, None] = None
_preserved_next_key_in_cycle: Union[str, None] = None
_preserved_vertex_next_key_in_cycle: Union[str, None] = None
async def get_key_manager_instance(
@@ -213,7 +213,7 @@ async def get_key_manager_instance(
)
if vertex_api_keys is None:
raise ValueError(
"Vertex API keys are required to initialize or re-initialize the KeyManager instance."
"Vertex Express API keys are required to initialize or re-initialize the KeyManager instance."
)
if not api_keys:
@@ -222,12 +222,12 @@ async def get_key_manager_instance(
)
if not vertex_api_keys:
logger.warning(
"Initializing KeyManager with an empty list of Vertex API keys."
"Initializing KeyManager with an empty list of Vertex Express API keys."
)
_singleton_instance = KeyManager(api_keys, vertex_api_keys)
logger.info(
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex API keys."
f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex Express API keys."
)
# 1. Restore the failure counts
@@ -349,7 +349,7 @@ async def get_key_manager_instance(
break
except ValueError:
logger.warning(
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex API keys. "
f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex Express API keys. "
"New cycle will start from the beginning of the new list."
)
except Exception as e:
@@ -357,7 +357,7 @@ async def get_key_manager_instance(
f"Error determining start key for new Vertex key cycle from preserved state: {e}. "
"New cycle will start from the beginning."
)
if start_key_for_new_vertex_cycle and _singleton_instance.vertex_api_keys:
try:
target_idx = _singleton_instance.vertex_api_keys.index(
@@ -370,25 +370,25 @@ async def get_key_manager_instance(
)
except ValueError:
logger.warning(
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex API keys during cycle advancement. "
f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex Express API keys during cycle advancement. "
"New cycle will start from the beginning."
)
except StopIteration:
logger.error(
"StopIteration while advancing Vertex key cycle, implies empty new Vertex API key list previously missed."
"StopIteration while advancing Vertex key cycle, implies empty new Vertex Express API key list previously missed."
)
except Exception as e:
logger.error(
f"Error advancing new Vertex key cycle: {e}. Cycle will start from beginning."
)
)
else:
if _singleton_instance.vertex_api_keys:
logger.info(
"New Vertex key cycle will start from the beginning of the new Vertex API key list (no specific start key determined or needed)."
"New Vertex key cycle will start from the beginning of the new Vertex Express API key list (no specific start key determined or needed)."
)
else:
logger.info(
"New Vertex key cycle not applicable as the new Vertex API key list is empty."
"New Vertex key cycle not applicable as the new Vertex Express API key list is empty."
)
# Clear all preserved state
@@ -409,16 +409,20 @@ async def reset_key_manager_instance():
if _singleton_instance:
# 1. Preserve the failure counts
_preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
_preserved_vertex_failure_counts = _singleton_instance.vertex_key_failure_counts.copy()
_preserved_vertex_failure_counts = (
_singleton_instance.vertex_key_failure_counts.copy()
)
# 2. Preserve the old API key lists
_preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
_preserved_vertex_old_api_keys_for_reset = _singleton_instance.vertex_api_keys.copy()
_preserved_vertex_old_api_keys_for_reset = (
_singleton_instance.vertex_api_keys.copy()
)
# 3. Preserve a hint for the next key in key_cycle
try:
if _singleton_instance.api_keys:
_preserved_next_key_in_cycle = (
_preserved_next_key_in_cycle = (
await _singleton_instance.get_next_key()
)
else:
@@ -427,7 +431,7 @@ async def reset_key_manager_instance():
logger.warning(
"Could not preserve next key hint: key cycle was empty or exhausted in old instance."
)
_preserved_next_key_in_cycle = None
_preserved_next_key_in_cycle = None
except Exception as e:
logger.error(f"Error preserving next key hint during reset: {e}")
_preserved_next_key_in_cycle = None
@@ -443,12 +447,11 @@ async def reset_key_manager_instance():
except StopIteration:
logger.warning(
"Could not preserve next key hint: Vertex key cycle was empty or exhausted in old instance."
)
)
_preserved_vertex_next_key_in_cycle = None
except Exception as e:
logger.error(f"Error preserving next key hint during reset: {e}")
_preserved_vertex_next_key_in_cycle = None
_singleton_instance = None
logger.info(

View File

@@ -1,6 +1,7 @@
# app/service/stats_service.py
import datetime
from typing import Union
from sqlalchemy import and_, case, func, or_, select
@@ -195,10 +196,11 @@ class StatsService:
return details
except Exception as e:
logger.error(f"Failed to get API call details for period '{period}': {e}")
logger.error(
f"Failed to get API call details for period '{period}': {e}")
raise
async def get_key_usage_details_last_24h(self, key: str) -> dict | None:
async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]:
"""
Get the number of calls per model for the given API key over the past 24 hours.
@@ -218,7 +220,8 @@ class StatsService:
try:
query = (
select(
RequestLog.model_name, func.count(RequestLog.id).label("call_count")
RequestLog.model_name, func.count(
RequestLog.id).label("call_count")
)
.where(
RequestLog.api_key == key,
@@ -237,7 +240,8 @@ class StatsService:
)
return {}
usage_details = {row["model_name"]: row["call_count"] for row in results}
usage_details = {row["model_name"]: row["call_count"]
for row in results}
logger.info(
f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
)

View File

@@ -0,0 +1,363 @@
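# Native Gemini TTS
This module adds native Gemini TTS (Text-to-Speech) support to the Gemini Balance project, covering single-speaker and multi-speaker speech synthesis. It is built on smart detection and inheritance so that it stays fully compatible with the original code.
## 🎯 Design Principles
- **Smart detection**: automatically detects every request in the native Gemini TTS format (containing `responseModalities` and `speechConfig`)
- **Inherit, don't modify**: every extension inherits from an original class; the original source is left untouched
- **Full compatibility**: the existing TTS feature (OpenAI-compatible TTS) is not affected
- **Dynamic model selection**: users can specify different TTS models in the request URL
- **Automatic fallback**: when native TTS processing fails, the request falls back to the standard service
- **Complete logging**: request logs, error logs, and performance monitoring
- **Easy to maintain**: updating the original code does not cause conflicts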
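## 📁 File Structure
```
app/service/tts/
├── tts_service.py                  # Existing OpenAI-compatible TTS service
└── native/                         # Native Gemini TTS extension
    ├── __init__.py                 # Module initialization
    ├── README.md                   # Usage guide (this file)
    ├── tts_models.py               # TTS data models (inherit from the original models)
    ├── tts_response_handler.py     # TTS response handler (inherits from the original handler)
    ├── tts_chat_service.py         # TTS chat service (inherits from the original service)
    └── tts_routes.py               # TTS route extension and dependency injection
```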
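## 🚀 Native Gemini TTS
### Smart Detection (Current Implementation)
Native Gemini TTS is enabled automatically through smart detection; no configuration is required:
1. **Automatic activation**:
```bash
# Start the service directly; native TTS is available automatically
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```
2. **No configuration needed**:
   - No environment variables required
   - No config file changes required
   - The decision is based entirely on the request content
### How It Works
The system inspects the content of each request:
- **Native TTS request**: contains `responseModalities: ["AUDIO"]` and `speechConfig` → handled by the TTS-enhanced service
  - **Single-speaker TTS**: contains `voiceConfig.prebuiltVoiceConfig`
  - **Multi-speaker TTS**: contains `multiSpeakerVoiceConfig`
- **Regular request**: non-TTS model → handled by the original Gemini chat service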
```python
# Smart detection logic in app/router/gemini_routes.py
if "tts" in model_name.lower() and request.generationConfig:
    # Read the TTS configuration directly from the parsed request object
    response_modalities = request.generationConfig.responseModalities or []
    speech_config = request.generationConfig.speechConfig or {}
    # Treat the request as native TTS if it has the AUDIO modality and a speech config
    if "AUDIO" in response_modalities and speech_config:
        # Use the TTS-enhanced service
        tts_service = await get_tts_chat_service(key_manager)
        return await tts_service.generate_content(...)
# Otherwise use the original service
```
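
The detection rule above can also be read as a small pure function. The following is a minimal sketch, not the project's actual code: the function name `is_native_tts_request` and the use of a plain dict for `generationConfig` are assumptions made for illustration.

```python
from typing import Any, Dict, Optional


def is_native_tts_request(
    model_name: str, generation_config: Optional[Dict[str, Any]]
) -> bool:
    """Return True when a request matches the native TTS detection rule.

    Hypothetical helper: mirrors the documented rule (model name contains
    "tts", AUDIO modality requested, non-empty speechConfig).
    """
    if "tts" not in model_name.lower() or not generation_config:
        return False
    modalities = generation_config.get("responseModalities") or []
    speech_config = generation_config.get("speechConfig") or {}
    return "AUDIO" in modalities and bool(speech_config)
```

Keeping the rule in one side-effect-free function makes it easy to unit test separately from the route handler.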
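## 📝 Usage Examples
### 1. Native Gemini single-speaker TTS request (uses the TTS-enhanced service)
Native Gemini requests that contain `voiceConfig.prebuiltVoiceConfig` are routed to the TTS-enhanced service automatically: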
```bash
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
-H "Content-Type: application/json" \
-H "x-goog-api-key: your-token" \
-d '{
"contents": [{
"parts": [{
"text": "Hello, this is a single speaker test."
}]
}],
"generationConfig": {
"responseModalities": ["AUDIO"],
"speechConfig": {
"voiceConfig": {
"prebuiltVoiceConfig": {
"voiceName": "Kore"
}
}
}
}
}'
```
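
The JSON response to a request like the one above carries the audio inline. The sketch below shows one way a client might turn it into a playable WAV file. It rests on assumptions that this README does not state: that the audio sits base64-encoded at `candidates[0].content.parts[0].inlineData.data`, and that it is raw 16-bit mono PCM at 24 kHz. Check both against the actual response before relying on it. The helper names are hypothetical.

```python
import base64
import io
import wave
from typing import Any, Dict


def extract_pcm(response: Dict[str, Any]) -> bytes:
    """Decode the inline audio payload (assumed location, see lead-in)."""
    part = response["candidates"][0]["content"]["parts"][0]
    return base64.b64decode(part["inlineData"]["data"])


def pcm_to_wav(
    pcm: bytes, rate: int = 24000, channels: int = 1, sample_width: int = 2
) -> bytes:
    """Wrap raw PCM samples in a WAV container (assumed audio format)."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(rate)
        wav_file.writeframes(pcm)
    return buffer.getvalue()
```

A client would typically call `pcm_to_wav(extract_pcm(response_json))` and write the resulting bytes to a `.wav` file.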
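### 2. Native Gemini multi-speaker TTS request (uses the TTS-enhanced service)
Native Gemini requests that contain `multiSpeakerVoiceConfig` are routed to the TTS-enhanced service automatically: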
```bash
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash-preview-tts:generateContent" \
-H "Content-Type: application/json" \
-H "x-goog-api-key: your-token" \
-d '{
"contents": [{
"parts": [{
"text": "Alice: Hello everyone, welcome to our show today.\nBob: Hi Alice, and hello to all our listeners! Today we are talking about AI development."
}]
}],
"generationConfig": {
"responseModalities": ["AUDIO"],
"speechConfig": {
"multiSpeakerVoiceConfig": {
"speakerVoiceConfigs": [
{
"speaker": "Alice",
"voiceConfig": {
"prebuiltVoiceConfig": {
"voiceName": "Puck"
}
}
},
{
"speaker": "Bob",
"voiceConfig": {
"prebuiltVoiceConfig": {
"voiceName": "Kore"
}
}
}
]
}
}
}
}'
```
### 3. OpenAI-compatible TTS request (uses the original service)
OpenAI-compatible TTS requests use a different API path and are not affected by this module:
```bash
curl -X POST "https://your-domain.com/v1/audio/speech" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer your-token" \
-d '{
"model": "tts-1",
"input": "This is a TTS test in the OpenAI-compatible format.",
"voice": "alloy"
}' \
--output openai_tts_test.wav
```
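**Note**: OpenAI-compatible TTS requests:
- Use the path `/v1/audio/speech`
- Use the `Authorization` header instead of `x-goog-api-key`
- Return an audio file instead of a JSON response
- Are not affected by this module's TTS-enhanced service
### 4. Regular text generation (uses the original service)
Requests to non-TTS models use the original Gemini chat service and are not affected: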
```bash
curl -X POST "https://your-domain.com/v1beta/models/gemini-2.5-flash:generateContent" \
-H "Content-Type: application/json" \
-H "x-goog-api-key: your-token" \
-d '{
"contents": [{
"parts": [{
"text": "Please give a brief overview of the history of artificial intelligence."
}]
}],
"generationConfig": {
"maxOutputTokens": 200,
"temperature": 0.7
}
}'
```
## 🔧 Technical Implementation
### Inheritance
```
GeminiChatService
    ↓ (inherits)
TTSGeminiChatService
    ├── overrides generate_content()
    ├── adds _handle_tts_request()
    └── integrates complete logging
GeminiResponseHandler
    ↓ (inherits)
TTSResponseHandler
    └── overrides handle_response()
GenerationConfig (Pydantic model)
    ↓ (extends)
TTSGenerationConfig
    ├── responseModalities: List[str]
    └── speechConfig: Dict[str, Any]
```
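### Workflow
1. **Receive the request**: the system receives an API request
2. **Smart detection**:
   - Check whether the model name contains "tts"
   - For a TTS model, check `request.generationConfig` for `responseModalities: ["AUDIO"]` and `speechConfig`
3. **Select the service**:
   - **Native TTS request**: use the enhanced `TTSGeminiChatService`
   - **Regular request**: use the original `GeminiChatService`
4. **Process the request**:
   - **Native TTS**: handled by `_handle_tts_request()`
   - **Other requests**: handled by the standard `generate_content()` method
5. **Handle the fields**: read the TTS fields (`responseModalities`, `speechConfig`) directly from `request.generationConfig`
6. **Call the API**: build the simplified payload and call the Gemini API
7. **Automatic fallback**: if native TTS processing fails, fall back to the standard service
8. **Handle the response**:
   - **TTS response**: detect audio data and return the raw response as-is
   - **Regular response**: use the standard handling
9. **Logging**: record the request time, success status, and error information in the database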
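
Step 7 of the workflow (automatic fallback) can be sketched as a thin wrapper around two handlers. This is an illustration under stated assumptions, not the project's routing code: `generate_with_fallback`, `tts_handler`, and `standard_handler` are hypothetical names, and where the real fallback is wired up is not shown in this module.

```python
import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict

logger = logging.getLogger(__name__)

# Hypothetical handler type: takes a request payload, returns a response dict.
Handler = Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]


async def generate_with_fallback(
    payload: Dict[str, Any], tts_handler: Handler, standard_handler: Handler
) -> Dict[str, Any]:
    """Try the native TTS handler first; on any error use the standard one."""
    try:
        return await tts_handler(payload)
    except Exception as exc:
        # Log the failure so the fallback is visible in monitoring.
        logger.warning("Native TTS handling failed, falling back: %s", exc)
        return await standard_handler(payload)
```

Catching the broad `Exception` keeps the sketch short; a real implementation may prefer to fall back only on specific error types so that genuine bugs are not masked.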
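## 📊 Features
### ✅ Implemented
- **Native TTS support**: single-speaker and multi-speaker speech synthesis
- **Single-speaker TTS**: supports the `voiceConfig.prebuiltVoiceConfig` configuration
- **Multi-speaker TTS**: supports the `multiSpeakerVoiceConfig` configuration
- **Smart detection**: automatically detects every request in the native Gemini TTS format
- **Dynamic model selection**: users can specify different TTS models in the URL
- **Backward compatible**: the existing TTS feature (OpenAI-compatible TTS) is not affected
- **Automatic fallback**: when native TTS processing fails, the standard service is used
- **Complete logging**: request logs, error logs, and performance monitoring
- **API quota management**: automatic retries and key rotation
- **Zero-config activation**: no environment variables or config file changes needed
- **Error handling**: exceptions are caught and recorded
### 🎵 Supported Voice Configurations
#### Single-speaker voice configuration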
```json
{
"responseModalities": ["AUDIO"],
"speechConfig": {
"voiceConfig": {
"prebuiltVoiceConfig": {
"voiceName": "Kore|Puck|other prebuilt voice"
}
}
}
}
```
#### Multi-speaker voice configuration
```json
{
"responseModalities": ["AUDIO"],
"speechConfig": {
"multiSpeakerVoiceConfig": {
"speakerVoiceConfigs": [
{
"speaker": "speaker name",
"voiceConfig": {
"prebuiltVoiceConfig": {
"voiceName": "Kore|Puck|other prebuilt voice"
}
}
}
]
}
}
}
```
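
The two configuration shapes above can be built programmatically. The helpers below are a minimal sketch for client code; the function names are hypothetical and are not part of this module. Only the JSON field names come from the examples above.

```python
from typing import Any, Dict, Mapping


def single_speaker_config(voice_name: str) -> Dict[str, Any]:
    """Build a speechConfig for one prebuilt voice."""
    return {"voiceConfig": {"prebuiltVoiceConfig": {"voiceName": voice_name}}}


def multi_speaker_config(voices: Mapping[str, str]) -> Dict[str, Any]:
    """Build a speechConfig from a mapping of speaker name to voice name."""
    return {
        "multiSpeakerVoiceConfig": {
            "speakerVoiceConfigs": [
                {"speaker": speaker, **single_speaker_config(voice)}
                for speaker, voice in voices.items()
            ]
        }
    }


def tts_generation_config(speech_config: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a speechConfig in the generationConfig expected for TTS requests."""
    return {"responseModalities": ["AUDIO"], "speechConfig": speech_config}
```

For example, `tts_generation_config(multi_speaker_config({"Alice": "Puck", "Bob": "Kore"}))` produces the `generationConfig` used in the multi-speaker request shown earlier.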
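## ⚠️ Notes
### API requirements
- Make sure the API key has TTS access
- TTS requires the `gemini-2.5-flash-preview-tts` model
- Mind the API quota limits (free tier: 15 requests per day)
### Performance considerations
- TTS responses are usually larger than text responses because they carry audio data
- Monitoring API call frequency and success rate is recommended
- The extension does not affect the performance or stability of the original features
### Deployment recommendations
- In production, test the regular features first
- Roll out TTS gradually and monitor the logs
- Check API quota usage regularly
## 📈 Monitoring and Debugging
### Viewing logs
- **Server logs**: follow how TTS requests are processed
- **Admin UI**: view request records under "API Call Details"
- **Error logs**: view details of failed requests
### Debugging tips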
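```bash
# Enable verbose logging
export LOG_LEVEL=DEBUG
# Follow the log in real time
tail -f logs/app.log
# Multi-speaker TTS needs no configuration and is enabled automatically;
# it is detected from the request content
```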
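## 🔄 TTS System Comparison
The project now has three TTS systems, each serving a different purpose:
| TTS type | Path | Model selection | Voice configuration | Use case | Impact of this module |
|---------|------|----------|----------|----------|------------|
| **OpenAI-compatible TTS** | `/v1/audio/speech` | Fixed in the config file | Single speaker | OpenAI API compatibility | ✅ None |
| **Gemini single-speaker TTS** | `/v1beta/models/{model}:generateContent` | User-specified | Single speaker | Native Gemini TTS | ✅ Enhanced by this module |
| **Gemini multi-speaker TTS** | `/v1beta/models/{model}:generateContent` | User-specified | Multiple speakers | Dialogue scenarios | ✅ Enhanced by this module |
### Smart Routing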
```mermaid
flowchart TD
A[API request] --> B{Check path}
B -->|/v1/audio/speech| C[OpenAI-compatible TTS service]
B -->|/v1beta/models/{model}:generateContent| D{Model name contains 'tts'?}
D -->|No| E[Standard Gemini chat service]
D -->|Yes| F{Contains responseModalities and speechConfig?}
F -->|No| G[Standard Gemini chat service]
F -->|Yes| H[Native TTS-enhanced service]
H --> I{Processed successfully?}
I -->|Yes| J[Return native TTS response]
I -->|No| K[Fall back to the standard service]
C --> L[Done]
E --> L
G --> L
J --> L
K --> L
```
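## 🎉 Results
The smart-detection-based native Gemini TTS solution has been implemented:
- **Zero-config activation**: no environment variables or configuration changes needed
- **Smart detection**: automatically detects every request in the native Gemini TTS format
- **Backward compatible**: no impact on any existing TTS feature
- **Dynamic model selection**: users can specify different TTS models
- **Automatic fallback**: the standard service is used when processing fails
- **Single- and multi-speaker synthesis**: covers all native Gemini TTS scenarios
- **Complete logging**: every request can be viewed in the admin UI
- **Error handling**: API quota handling and retries
- **Easy to maintain**: updating the original code causes no conflicts
This implementation shows how to extend a complex system without modifying the original code while keeping it backward compatible.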

View File

@@ -0,0 +1,19 @@
"""
Native Gemini TTS module.
Native Gemini TTS functionality for both single and multi-speaker scenarios
"""
from .tts_chat_service import TTSGeminiChatService
from .tts_models import TTSGenerationConfig, MultiSpeakerVoiceConfig, SpeechConfig, TTSRequest
from .tts_response_handler import TTSResponseHandler
from .tts_routes import get_tts_chat_service
__all__ = [
"TTSGeminiChatService",
"TTSGenerationConfig",
"MultiSpeakerVoiceConfig",
"SpeechConfig",
"TTSRequest",
"TTSResponseHandler",
"get_tts_chat_service"
]

View File

@@ -0,0 +1,151 @@
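"""
Native Gemini TTS chat service extension.
Inherits from the original chat service and adds native Gemini TTS support (single- and multi-speaker) while staying backward compatible.
"""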
import time
import datetime
from typing import Any, Dict
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.tts.native.tts_response_handler import TTSResponseHandler
from app.domain.gemini_models import GeminiRequest
from app.log.logger import get_gemini_logger
from app.database.services import add_request_log, add_error_log
logger = get_gemini_logger()
class TTSGeminiChatService(GeminiChatService):
"""
Gemini chat service with TTS support.
Inherits from the original GeminiChatService and adds TTS functionality.
"""
def __init__(self, base_url: str, key_manager):
"""
Initialize the TTS chat service.
"""
super().__init__(base_url, key_manager)
# Replace the original response handler with the TTS response handler
self.response_handler = TTSResponseHandler()
logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")
async def generate_content(
self, model: str, request: GeminiRequest, api_key: str
) -> Dict[str, Any]:
"""
Generate content, with TTS support.
"""
try:
# Debug logging
logger.info(f"TTS request model: {model}")
logger.info(f"TTS request generationConfig: {request.generationConfig}")
# TTS models need special handling
if "tts" in model.lower():
logger.info("Detected TTS model, applying TTS-specific processing")
# For TTS models, make sure the correct fields are passed through
response = await self._handle_tts_request(model, request, api_key)
return response
else:
# For non-TTS models, use the parent class implementation
response = await super().generate_content(model, request, api_key)
return response
except Exception as e:
logger.error(f"TTS API call failed with error: {e}")
raise
async def _handle_tts_request(self, model: str, request: GeminiRequest, api_key: str) -> Dict[str, Any]:
"""
Handle a TTS-specific request, including full request and error logging.
"""
# Record the start time and the request time
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
is_success = False
status_code = None
try:
# Build a TTS-specific payload that excludes tools and safetySettings
from app.service.chat.gemini_chat_service import _filter_empty_parts
request_dict = request.model_dump(exclude_none=False)
# Build the simplified TTS-specific payload
payload = {
"contents": _filter_empty_parts(request_dict.get("contents", [])),
"generationConfig": request_dict.get("generationConfig", {}),
}
# Add systemInstruction only when it is present
if request_dict.get("systemInstruction"):
payload["systemInstruction"] = request_dict.get("systemInstruction")
# Make sure generationConfig is not None
if payload["generationConfig"] is None:
payload["generationConfig"] = {}
# Read the TTS fields directly from request.generationConfig
if request.generationConfig:
# Add the TTS-specific fields
if request.generationConfig.responseModalities:
payload["generationConfig"]["responseModalities"] = request.generationConfig.responseModalities
logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
if request.generationConfig.speechConfig:
payload["generationConfig"]["speechConfig"] = request.generationConfig.speechConfig
logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
else:
logger.warning("No generationConfig found in request, TTS fields may be missing")
logger.info(f"TTS payload before API call: {payload}")
# Call the API
response = await self.api_client.generate_content(payload, model, api_key)
# Reaching this point means the API call succeeded
is_success = True
status_code = 200
# Process the response with the TTS response handler
return self.response_handler.handle_response(response, model, False, None)
except Exception as e:
# Record the error
is_success = False
error_msg = str(e)
# Try to extract the status code from the error message
import re
match = re.search(r"status code (\d+)", error_msg)
if match:
status_code = int(match.group(1))
else:
status_code = 500
# Write the error log
await add_error_log(
gemini_key=api_key,
model_name=model,
error_type="tts-api-error",
error_log=error_msg,
error_code=status_code,
request_msg=request.model_dump(exclude_none=False)
)
logger.error(f"TTS API call failed: {error_msg}")
raise
finally:
# Write the request log
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
await add_request_log(
model_name=model,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
)
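
The status-code handling in the `except` branch above can be read as a small pure function. The following is a minimal sketch, assuming the upstream client embeds the HTTP status in the exception text as `status code NNN`; the helper name `extract_status_code` is hypothetical and does not appear in this change.

```python
import re

# Same pattern as the handler above: "status code <digits>" anywhere in the message.
_STATUS_RE = re.compile(r"status code (\d+)")


def extract_status_code(error_msg: str, default: int = 500) -> int:
    """Return the HTTP status found in an error message, or `default` if none is present."""
    match = _STATUS_RE.search(error_msg)
    return int(match.group(1)) if match else default
```

Pulling this logic into one helper would let both TTS services share it, since the same regular expression is repeated in the OpenAI-compatible `TTSService`.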


@@ -0,0 +1,37 @@
"""
TTS extension configuration.
Controls whether the TTS feature is enabled.
"""
import os
from typing import Union
from app.service.chat.gemini_chat_service import GeminiChatService
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
class TTSConfig:
"""TTS configuration management."""
@staticmethod
def is_tts_enabled() -> bool:
"""
Check whether the TTS feature is enabled.
Controlled by the ENABLE_TTS environment variable; defaults to False.
"""
return os.getenv("ENABLE_TTS", "false").lower() in ("true", "1", "yes", "on")
@staticmethod
def get_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
"""
Factory method: return the appropriate chat service for the current configuration.
"""
if TTSConfig.is_tts_enabled():
return TTSGeminiChatService(base_url, key_manager)
else:
return GeminiChatService(base_url, key_manager)
# Convenience function
def create_chat_service(base_url: str, key_manager) -> Union[GeminiChatService, TTSGeminiChatService]:
"""Create a chat service instance."""
return TTSConfig.get_chat_service(base_url, key_manager)
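
The `ENABLE_TTS` check can be exercised in isolation. This is a minimal sketch that mirrors the comparison in `is_tts_enabled()` as a standalone function; it assumes nothing beyond the standard library and is not the project's own module.

```python
import os

# Values treated as "enabled", matching the tuple used in TTSConfig.is_tts_enabled().
_TRUTHY = ("true", "1", "yes", "on")


def is_tts_enabled() -> bool:
    """Return True when ENABLE_TTS is set to a truthy value (case-insensitive)."""
    return os.getenv("ENABLE_TTS", "false").lower() in _TRUTHY
```

Because the value is read on every call, toggling the environment variable takes effect without rebuilding any cached settings object. Surrounding whitespace is not stripped, so a value such as `" true"` would count as disabled.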


@@ -0,0 +1,36 @@
"""
Native Gemini TTS extension data models.
Inherit from the original models and add native Gemini TTS fields while staying backward compatible.
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from app.domain.gemini_models import GenerationConfig as BaseGenerationConfig
class TTSGenerationConfig(BaseGenerationConfig):
"""
GenerationConfig with TTS support.
Inherits from the original GenerationConfig and adds TTS-related fields.
"""
# TTS-related settings
responseModalities: Optional[List[str]] = None
speechConfig: Optional[Dict[str, Any]] = None
class MultiSpeakerVoiceConfig(BaseModel):
"""Multi-speaker voice configuration."""
speakerVoiceConfigs: List[Dict[str, Any]]
class SpeechConfig(BaseModel):
"""Speech configuration."""
multiSpeakerVoiceConfig: Optional[MultiSpeakerVoiceConfig] = None
voiceConfig: Optional[Dict[str, Any]] = None
class TTSRequest(BaseModel):
"""TTS request model."""
contents: List[Dict[str, Any]]
generationConfig: TTSGenerationConfig
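
For orientation, a request body matching these models might be assembled as below. This is a sketch only: the nested keys inside each speaker entry (`speaker`, `prebuiltVoiceConfig`, `voiceName`) and the voice names are assumptions based on the public Gemini REST format and are not defined anywhere in this diff, and the helper name `build_multi_speaker_config` is hypothetical.

```python
from typing import Any, Dict, List


def build_multi_speaker_config(voices: Dict[str, str]) -> Dict[str, Any]:
    """Build a generationConfig dict for multi-speaker TTS.

    `voices` maps a speaker label used in the prompt to a voice name.
    The nested key names are assumed, not taken from this change.
    """
    speaker_configs: List[Dict[str, Any]] = [
        {
            "speaker": speaker,
            "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": voice}},
        }
        for speaker, voice in voices.items()
    ]
    return {
        "responseModalities": ["AUDIO"],
        "speechConfig": {
            "multiSpeakerVoiceConfig": {"speakerVoiceConfigs": speaker_configs}
        },
    }
```

The resulting dict has the shape that `TTSGenerationConfig.responseModalities` and `TTSGenerationConfig.speechConfig` accept, since both fields are typed loosely as a list of strings and a plain dict.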


@@ -0,0 +1,53 @@
"""
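Native Gemini TTS response handler extension.
Inherits from the original response handler and adds native Gemini TTS support while staying backward compatible.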
"""
from typing import Any, Dict, Optional
from app.handler.response_handler import GeminiResponseHandler
from app.log.logger import get_gemini_logger
logger = get_gemini_logger()
class TTSResponseHandler(GeminiResponseHandler):
"""
Response handler with TTS support.
Inherits from the original GeminiResponseHandler and adds TTS response handling.
"""
def handle_response(
self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
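Handle a response, with support for TTS audio data.
"""
# Check whether this is a TTS response that contains audio data
if self._is_tts_response(response):
logger.info("Detected TTS response with audio data, returning original response")
return response
# For non-TTS responses, fall back to the parent class handler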
return super().handle_response(response, model, stream, usage_metadata)
def _is_tts_response(self, response: Dict[str, Any]) -> bool:
"""
Check whether the response is a TTS response.
"""
try:
if (response.get("candidates") and
len(response["candidates"]) > 0 and
response["candidates"][0].get("content") and
response["candidates"][0]["content"].get("parts") and
len(response["candidates"][0]["content"]["parts"]) > 0):
parts = response["candidates"][0]["content"]["parts"]
for part in parts:
if "inlineData" in part:
mime_type = part["inlineData"].get("mimeType", "")
if mime_type.startswith("audio/"):
return True
return False
except Exception as e:
logger.warning(f"Error checking TTS response: {e}")
return False
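
The detection rule in `_is_tts_response` (first candidate, any part whose `inlineData.mimeType` starts with `audio/`) can be sketched as a standalone function for testing. The name `has_audio_part` and the sample payloads are illustrative assumptions; the sketch treats missing or null fields as "not audio" instead of relying on a broad `except`.

```python
from typing import Any, Dict


def has_audio_part(response: Dict[str, Any]) -> bool:
    """Return True if the first candidate contains an inline audio part."""
    candidates = response.get("candidates") or []
    if not candidates:
        return False
    content = candidates[0].get("content") or {}
    for part in content.get("parts") or []:
        inline = part.get("inlineData") or {}
        # A missing or null mimeType is treated as an empty string.
        if (inline.get("mimeType") or "").startswith("audio/"):
            return True
    return False
```

As in the handler above, only the first candidate is inspected, so audio that appears only in a later candidate would not be detected.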


@@ -0,0 +1,24 @@
"""
TTS route extension.
Provides the native Gemini TTS-enhanced service with single- and multi-speaker voice support.
"""
from fastapi import Depends
from app.config.config import settings
from app.service.key.key_manager import KeyManager, get_key_manager_instance
from app.service.tts.native.tts_chat_service import TTSGeminiChatService
async def get_key_manager():
"""Return the key manager instance."""
return get_key_manager_instance()
async def get_tts_chat_service(key_manager: KeyManager = Depends(get_key_manager)) -> TTSGeminiChatService:
"""
Return the native Gemini TTS-enhanced chat service instance, which supports single- and multi-speaker voices.
"""
return TTSGeminiChatService(settings.BASE_URL, key_manager)


@@ -0,0 +1,95 @@
import datetime
import io
import re
import time
import wave
from typing import Optional
from google import genai
from app.config.config import settings
from app.core.constants import TTS_VOICE_NAMES
from app.database.services import add_error_log, add_request_log
from app.domain.openai_models import TTSRequest
from app.log.logger import get_openai_logger
logger = get_openai_logger()
def _create_wav_file(audio_data: bytes) -> bytes:
"""Creates a WAV file in memory from raw audio data."""
with io.BytesIO() as wav_file:
with wave.open(wav_file, "wb") as wf:
wf.setnchannels(1) # Mono
wf.setsampwidth(2) # 16-bit
wf.setframerate(24000) # 24kHz sample rate
wf.writeframes(audio_data)
return wav_file.getvalue()
class TTSService:
async def create_tts(self, request: TTSRequest, api_key: str) -> Optional[bytes]:
"""
Create audio using the Google Gemini SDK.
"""
start_time = time.perf_counter()
request_datetime = datetime.datetime.now()
is_success = False
status_code = None
response = None
error_log_msg = ""
try:
client = genai.Client(api_key=api_key)
response = await client.aio.models.generate_content(
model=settings.TTS_MODEL,
contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
config={
"response_modalities": ["Audio"],
"speech_config": {
"voice_config": {
"prebuilt_voice_config": {
"voice_name": request.voice if request.voice in TTS_VOICE_NAMES else settings.TTS_VOICE_NAME
}
}
},
},
)
if (
response.candidates
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].inline_data
):
raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
is_success = True
status_code = 200
return _create_wav_file(raw_audio_data)
except Exception as e:
is_success = False
error_log_msg = f"Generic error: {e}"
logger.error(f"An error occurred in TTSService: {error_log_msg}")
match = re.search(r"status code (\d+)", str(e))
if match:
status_code = int(match.group(1))
else:
status_code = 500
raise
finally:
end_time = time.perf_counter()
latency_ms = int((end_time - start_time) * 1000)
if not is_success:
await add_error_log(
gemini_key=api_key,
model_name=settings.TTS_MODEL,
error_type="google-tts",
error_log=error_log_msg,
error_code=status_code,
request_msg=request.input
)
await add_request_log(
model_name=settings.TTS_MODEL,
api_key=api_key,
is_success=is_success,
status_code=status_code,
latency_ms=latency_ms,
request_time=request_datetime
)
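
The WAV wrapping done by `_create_wav_file` can be checked with a round trip through the standard `wave` module. This is a minimal sketch, assuming the model returns raw 16-bit mono PCM at 24 kHz as the hard-coded parameters imply; `create_wav_file` repeats the helper for self-containment and `describe_wav` is a hypothetical inspection function.

```python
import io
import wave
from typing import Tuple


def create_wav_file(audio_data: bytes, rate: int = 24000) -> bytes:
    """Wrap raw 16-bit mono PCM in a WAV container held in memory."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)   # mono
        wf.setsampwidth(2)   # 16-bit samples
        wf.setframerate(rate)
        wf.writeframes(audio_data)
    return buffer.getvalue()


def describe_wav(wav_bytes: bytes) -> Tuple[int, int, int, int]:
    """Return (channels, sample width in bytes, frame rate, frame count)."""
    with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
        return wf.getnchannels(), wf.getsampwidth(), wf.getframerate(), wf.getnframes()
```

One second of silence is 24000 frames of two bytes each, so `describe_wav(create_wav_file(b"\x00\x00" * 24000))` should report one channel, a sample width of 2, a rate of 24000, and 24000 frames.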


@@ -5,14 +5,17 @@ const ARRAY_INPUT_CLASS = "array-input";
const MAP_ITEM_CLASS = "map-item";
const MAP_KEY_INPUT_CLASS = "map-key-input";
const MAP_VALUE_INPUT_CLASS = "map-value-input";
const CUSTOM_HEADER_ITEM_CLASS = "custom-header-item";
const CUSTOM_HEADER_KEY_INPUT_CLASS = "custom-header-key-input";
const CUSTOM_HEADER_VALUE_INPUT_CLASS = "custom-header-value-input";
const SAFETY_SETTING_ITEM_CLASS = "safety-setting-item";
const SHOW_CLASS = "show"; // For modals
const API_KEY_REGEX = /AIzaSy\S{33}/g;
const PROXY_REGEX =
/(?:https?|socks5):\/\/(?:[^:@\/]+(?::[^@\/]+)?@)?(?:[^:\/\s]+)(?::\d+)?/g;
const VERTEX_API_KEY_REGEX = /AQ\.[a-zA-Z0-9_]{50}/g; // Added: Vertex API key regex
const VERTEX_API_KEY_REGEX = /AQ\.[a-zA-Z0-9_\-]{50}/g; // Added: Vertex Express API key regex
const MASKED_VALUE = "••••••••";
// DOM Elements - Global Scope for frequently accessed elements
const safetySettingsContainer = document.getElementById(
"SAFETY_SETTINGS_container"
@@ -31,8 +34,8 @@ const bulkDeleteProxyModal = document.getElementById("bulkDeleteProxyModal");
const bulkDeleteProxyInput = document.getElementById("bulkDeleteProxyInput");
const resetConfirmModal = document.getElementById("resetConfirmModal");
const configForm = document.getElementById("configForm"); // Added for frequent use
// Vertex API Key Modal Elements
// Vertex Express API Key Modal Elements
const vertexApiKeyModal = document.getElementById("vertexApiKeyModal");
const vertexApiKeyBulkInput = document.getElementById("vertexApiKeyBulkInput");
const bulkDeleteVertexApiKeyModal = document.getElementById(
@@ -41,7 +44,7 @@ const bulkDeleteVertexApiKeyModal = document.getElementById(
const bulkDeleteVertexApiKeyInput = document.getElementById(
"bulkDeleteVertexApiKeyInput"
);
// Model Helper Modal Elements
const modelHelperModal = document.getElementById("modelHelperModal");
const modelHelperTitleElement = document.getElementById("modelHelperTitle");
@@ -383,9 +386,15 @@ document.addEventListener("DOMContentLoaded", function () {
addSafetySettingBtn.addEventListener("click", () => addSafetySettingItem());
}
// Add Custom Header button
const addCustomHeaderBtn = document.getElementById("addCustomHeaderBtn");
if (addCustomHeaderBtn) {
addCustomHeaderBtn.addEventListener("click", () => addCustomHeaderItem());
}
initializeSensitiveFields(); // Initialize sensitive field handling
// Vertex API Key Modal Elements and Events
// Vertex Express API Key Modal Elements and Events
const addVertexApiKeyBtn = document.getElementById("addVertexApiKeyBtn");
const closeVertexApiKeyModalBtn = document.getElementById(
"closeVertexApiKeyModalBtn"
@@ -408,7 +417,7 @@ document.addEventListener("DOMContentLoaded", function () {
const confirmBulkDeleteVertexApiKeyBtn = document.getElementById(
"confirmBulkDeleteVertexApiKeyBtn"
);
if (addVertexApiKeyBtn) {
addVertexApiKeyBtn.addEventListener("click", () => {
openModal(vertexApiKeyModal);
@@ -428,7 +437,7 @@ document.addEventListener("DOMContentLoaded", function () {
"click",
handleBulkAddVertexApiKeys
);
if (bulkDeleteVertexApiKeyBtn) {
bulkDeleteVertexApiKeyBtn.addEventListener("click", () => {
openModal(bulkDeleteVertexApiKeyModal);
@@ -448,7 +457,7 @@ document.addEventListener("DOMContentLoaded", function () {
"click",
handleBulkDeleteVertexApiKeys
);
// Model Helper Modal Event Listeners
if (closeModelHelperModalBtn) {
closeModelHelperModalBtn.addEventListener("click", () =>
@@ -691,12 +700,26 @@ async function initConfig() {
) {
config.THINKING_BUDGET_MAP = {}; // Default to an empty object
}
// --- Added: default value for CUSTOM_HEADERS ---
if (
!config.CUSTOM_HEADERS ||
typeof config.CUSTOM_HEADERS !== "object" ||
config.CUSTOM_HEADERS === null
) {
config.CUSTOM_HEADERS = {}; // Default to an empty object
}
// --- Added: default value for SAFETY_SETTINGS ---
if (!config.SAFETY_SETTINGS || !Array.isArray(config.SAFETY_SETTINGS)) {
config.SAFETY_SETTINGS = []; // Default to an empty array
}
// --- End: default value for SAFETY_SETTINGS ---
if (typeof config.URL_CONTEXT_ENABLED === "undefined") {
config.URL_CONTEXT_ENABLED = true;
}
if (!config.URL_CONTEXT_MODELS || !Array.isArray(config.URL_CONTEXT_MODELS)) {
config.URL_CONTEXT_MODELS = [];
}
// --- Added: default values for the automatic error-log deletion settings ---
if (typeof config.AUTO_DELETE_ERROR_LOGS_ENABLED === "undefined") {
config.AUTO_DELETE_ERROR_LOGS_ENABLED = false;
@@ -756,6 +779,7 @@ async function initConfig() {
VERTEX_EXPRESS_BASE_URL: "", // Ensure a default value exists
THINKING_MODELS: [],
THINKING_BUDGET_MAP: {},
CUSTOM_HEADERS: {},
AUTO_DELETE_ERROR_LOGS_ENABLED: false,
AUTO_DELETE_ERROR_LOGS_DAYS: 7, // Added default value
AUTO_DELETE_REQUEST_LOGS_ENABLED: false, // Added default value
@@ -765,7 +789,7 @@ async function initConfig() {
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: 5,
// --- End: default values for the fake-streaming settings ---
};
populateForm(defaultConfig);
if (configForm) {
// Ensure form exists
@@ -854,6 +878,26 @@ function populateForm(config) {
'<div class="text-gray-500 text-sm italic">请在上方添加思考模型,预算将自动关联。</div>';
}
// Populate CUSTOM_HEADERS
const customHeadersContainer = document.getElementById(
"CUSTOM_HEADERS_container"
);
let customHeadersAdded = false;
if (
customHeadersContainer &&
config.CUSTOM_HEADERS &&
typeof config.CUSTOM_HEADERS === "object"
) {
for (const [key, value] of Object.entries(config.CUSTOM_HEADERS)) {
createAndAppendCustomHeaderItem(key, value);
customHeadersAdded = true;
}
}
if (!customHeadersAdded && customHeadersContainer) {
customHeadersContainer.innerHTML =
'<div class="text-gray-500 text-sm italic">添加自定义请求头,例如 X-Api-Key: your-key</div>';
}
// 4. Populate other array fields (excluding THINKING_MODELS)
for (const [key, value] of Object.entries(config)) {
if (Array.isArray(value) && key !== "THINKING_MODELS") {
@@ -1177,25 +1221,21 @@ function handleBulkDeleteProxies() {
}
bulkDeleteProxyInput.value = "";
}
/**
* Handles the bulk addition of Vertex API keys from the modal input.
* Handles the bulk addition of Vertex Express API keys from the modal input.
*/
function handleBulkAddVertexApiKeys() {
const vertexApiKeyContainer = document.getElementById(
"VERTEX_API_KEYS_container"
);
if (
!vertexApiKeyBulkInput ||
!vertexApiKeyContainer ||
!vertexApiKeyModal
) {
if (!vertexApiKeyBulkInput || !vertexApiKeyContainer || !vertexApiKeyModal) {
return;
}
const bulkText = vertexApiKeyBulkInput.value;
const extractedKeys = bulkText.match(VERTEX_API_KEY_REGEX) || [];
const currentKeyInputs = vertexApiKeyContainer.querySelectorAll(
`.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
);
@@ -1206,16 +1246,16 @@ function handleBulkAddVertexApiKeys() {
: input.value;
})
.filter((key) => key && key.trim() !== "" && key !== MASKED_VALUE);
const combinedKeys = new Set([...currentKeys, ...extractedKeys]);
const uniqueKeys = Array.from(combinedKeys);
vertexApiKeyContainer.innerHTML = ""; // Clear existing items
uniqueKeys.forEach((key) => {
addArrayItemWithValue("VERTEX_API_KEYS", key); // VERTEX_API_KEYS are sensitive
});
// Ensure new sensitive inputs are masked
const newKeyInputs = vertexApiKeyContainer.querySelectorAll(
`.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
@@ -1229,7 +1269,7 @@ function handleBulkAddVertexApiKeys() {
input.dispatchEvent(focusoutEvent);
}
});
closeModal(vertexApiKeyModal);
showNotification(
`添加/更新了 ${uniqueKeys.length} 个唯一 Vertex 密钥`,
@@ -1237,9 +1277,9 @@ function handleBulkAddVertexApiKeys() {
);
vertexApiKeyBulkInput.value = "";
}
/**
* Handles the bulk deletion of Vertex API keys based on input from the modal.
* Handles the bulk deletion of Vertex Express API keys based on input from the modal.
*/
function handleBulkDeleteVertexApiKeys() {
const vertexApiKeyContainer = document.getElementById(
@@ -1252,26 +1292,28 @@ function handleBulkDeleteVertexApiKeys() {
) {
return;
}
const bulkText = bulkDeleteVertexApiKeyInput.value;
if (!bulkText.trim()) {
showNotification("请粘贴需要删除的 Vertex API 密钥", "warning");
showNotification("请粘贴需要删除的 Vertex Express API 密钥", "warning");
return;
}
const keysToDelete = new Set(bulkText.match(VERTEX_API_KEY_REGEX) || []);
if (keysToDelete.size === 0) {
showNotification(
"未在输入内容中提取到有效的 Vertex API 密钥格式",
"未在输入内容中提取到有效的 Vertex Express API 密钥格式",
"warning"
);
return;
}
const keyItems = vertexApiKeyContainer.querySelectorAll(`.${ARRAY_ITEM_CLASS}`);
const keyItems = vertexApiKeyContainer.querySelectorAll(
`.${ARRAY_ITEM_CLASS}`
);
let deleteCount = 0;
keyItems.forEach((item) => {
const input = item.querySelector(
`.${ARRAY_INPUT_CLASS}.${SENSITIVE_INPUT_CLASS}`
@@ -1286,17 +1328,20 @@ function handleBulkDeleteVertexApiKeys() {
deleteCount++;
}
});
closeModal(bulkDeleteVertexApiKeyModal);
if (deleteCount > 0) {
showNotification(`成功删除了 ${deleteCount} 个匹配的 Vertex 密钥`, "success");
showNotification(
`成功删除了 ${deleteCount} 个匹配的 Vertex 密钥`,
"success"
);
} else {
showNotification("列表中未找到您输入的任何 Vertex 密钥进行删除", "info");
}
bulkDeleteVertexApiKeyInput.value = "";
}
/**
* Switches the active configuration tab.
* @param {string} tabId - The ID of the tab to switch to.
@@ -1305,8 +1350,10 @@ function switchTab(tabId) {
console.log(`Switching to tab: ${tabId}`);
// Define the styles for the active and inactive states
const activeStyle = "background-color: #3b82f6 !important; color: #ffffff !important; border: 2px solid #2563eb !important; box-shadow: 0 4px 12px -2px rgba(59, 130, 246, 0.4), 0 2px 6px -1px rgba(59, 130, 246, 0.2) !important; transform: translateY(-2px) !important; font-weight: 600 !important;";
const inactiveStyle = "background-color: #f8fafc !important; color: #64748b !important; border: 2px solid #e2e8f0 !important; box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1) !important; font-weight: 500 !important; transform: none !important;";
const activeStyle =
"background-color: #3b82f6 !important; color: #ffffff !important; border: 2px solid #2563eb !important; box-shadow: 0 4px 12px -2px rgba(59, 130, 246, 0.4), 0 2px 6px -1px rgba(59, 130, 246, 0.2) !important; transform: translateY(-2px) !important; font-weight: 600 !important;";
const inactiveStyle =
"background-color: #f8fafc !important; color: #64748b !important; border: 2px solid #e2e8f0 !important; box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1) !important; font-weight: 500 !important; transform: none !important;";
// Update the tab button states
const tabButtons = document.querySelectorAll(".tab-btn");
@@ -1421,7 +1468,7 @@ function addArrayItem(key) {
const modelId = addArrayItemWithValue(key, newItemValue); // This adds the DOM element
if (key === "THINKING_MODELS" && modelId) {
createAndAppendBudgetMapItem(newItemValue, 0, modelId); // Default budget 0
createAndAppendBudgetMapItem(newItemValue, -1, modelId); // Default budget -1
}
}
@@ -1439,10 +1486,9 @@ function addArrayItemWithValue(key, value) {
const isThinkingModel = key === "THINKING_MODELS";
const isAllowedToken = key === "ALLOWED_TOKENS";
const isVertexApiKey = key === "VERTEX_API_KEYS"; // Added check
const isSensitive =
key === "API_KEYS" || isAllowedToken || isVertexApiKey; // Updated sensitivity check
const isSensitive = key === "API_KEYS" || isAllowedToken || isVertexApiKey; // Updated sensitivity check
const modelId = isThinkingModel ? generateUUID() : null;
const arrayItem = document.createElement("div");
arrayItem.className = `${ARRAY_ITEM_CLASS} flex items-center mb-2 gap-2`;
if (isThinkingModel) {
@@ -1532,17 +1578,17 @@ function createAndAppendBudgetMapItem(mapKey, mapValue, modelId) {
const valueInput = document.createElement("input");
valueInput.type = "number";
const intValue = parseInt(mapValue, 10);
valueInput.value = isNaN(intValue) ? 0 : intValue;
valueInput.value = isNaN(intValue) ? -1 : intValue;
valueInput.placeholder = "预算 (整数)";
valueInput.className = `${MAP_VALUE_INPUT_CLASS} w-24 px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50`;
valueInput.min = 0;
valueInput.max = 24576;
valueInput.min = -1;
valueInput.max = 32767;
valueInput.addEventListener("input", function () {
let val = this.value.replace(/[^0-9]/g, "");
let val = this.value.replace(/[^0-9-]/g, "");
if (val !== "") {
val = parseInt(val, 10);
if (val < 0) val = 0;
if (val > 24576) val = 24576;
if (val < -1) val = -1;
if (val > 32767) val = 32767;
}
this.value = val; // Corrected variable name
});
@@ -1562,6 +1608,67 @@ function createAndAppendBudgetMapItem(mapKey, mapValue, modelId) {
container.appendChild(mapItem);
}
/**
* Adds a new custom header item to the DOM.
*/
function addCustomHeaderItem() {
createAndAppendCustomHeaderItem("", "");
}
/**
* Creates and appends a DOM element for a custom header.
* @param {string} key - The header key.
* @param {string} value - The header value.
*/
function createAndAppendCustomHeaderItem(key, value) {
const container = document.getElementById("CUSTOM_HEADERS_container");
if (!container) {
console.error(
"Cannot add custom header: CUSTOM_HEADERS_container not found!"
);
return;
}
const placeholder = container.querySelector(".text-gray-500.italic");
if (
placeholder &&
container.children.length === 1 &&
container.firstChild === placeholder
) {
container.innerHTML = "";
}
const headerItem = document.createElement("div");
headerItem.className = `${CUSTOM_HEADER_ITEM_CLASS} flex items-center mb-2 gap-2`;
const keyInput = document.createElement("input");
keyInput.type = "text";
keyInput.value = key;
keyInput.placeholder = "Header Name";
keyInput.className = `${CUSTOM_HEADER_KEY_INPUT_CLASS} flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none bg-gray-100 text-gray-500`;
const valueInput = document.createElement("input");
valueInput.type = "text";
valueInput.value = value;
valueInput.placeholder = "Header Value";
valueInput.className = `${CUSTOM_HEADER_VALUE_INPUT_CLASS} flex-grow px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:border-primary-500 focus:ring focus:ring-primary-200 focus:ring-opacity-50`;
const removeBtn = createRemoveButton();
removeBtn.addEventListener("click", () => {
headerItem.remove();
if (container.children.length === 0) {
container.innerHTML =
'<div class="text-gray-500 text-sm italic">添加自定义请求头,例如 X-Api-Key: your-key</div>';
}
});
headerItem.appendChild(keyInput);
headerItem.appendChild(valueInput);
headerItem.appendChild(removeBtn);
container.appendChild(headerItem);
}
/**
* Collects all data from the configuration form.
* @returns {object} An object containing all configuration data.
@@ -1632,12 +1739,32 @@ function collectFormData() {
formData["THINKING_BUDGET_MAP"][keyInput.value.trim()] = isNaN(
budgetValue
)
? 0
? -1
: budgetValue;
}
});
}
const customHeadersContainer = document.getElementById(
"CUSTOM_HEADERS_container"
);
if (customHeadersContainer) {
formData["CUSTOM_HEADERS"] = {};
const customHeaderItems = customHeadersContainer.querySelectorAll(
`.${CUSTOM_HEADER_ITEM_CLASS}`
);
customHeaderItems.forEach((item) => {
const keyInput = item.querySelector(`.${CUSTOM_HEADER_KEY_INPUT_CLASS}`);
const valueInput = item.querySelector(
`.${CUSTOM_HEADER_VALUE_INPUT_CLASS}`
);
if (keyInput && valueInput && keyInput.value.trim() !== "") {
formData["CUSTOM_HEADERS"][keyInput.value.trim()] =
valueInput.value.trim();
}
});
}
if (safetySettingsContainer) {
formData["SAFETY_SETTINGS"] = [];
const settingItems = safetySettingsContainer.querySelectorAll(
@@ -2163,7 +2290,7 @@ function handleModelSelection(selectedModelId) {
);
if (currentModelHelperTarget.targetKey === "THINKING_MODELS" && modelId) {
// Automatically add the corresponding budget map item with the default budget (-1, automatic)
createAndAppendBudgetMapItem(selectedModelId, 0, modelId);
createAndAppendBudgetMapItem(selectedModelId, -1, modelId);
}
}


@@ -817,58 +817,11 @@ function toggleSection(header, sectionId) {
}
}
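// Filter valid keys by the failure-count threshold and update the batch-action state
// filterValidKeys has been replaced by filterAndSearchValidKeys; this function is kept as a stub or can be removed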
function filterValidKeys() {
const thresholdInput = document.getElementById("failCountThreshold");
const validKeysList = document.getElementById("validKeys"); // Get the UL element
if (!validKeysList) return; // Exit if the list doesn't exist
const validKeyItems = validKeysList.querySelectorAll("li[data-key]"); // Select li elements within the list
// Read the threshold; default to 0 (no filtering) when the input is invalid or empty
const threshold = parseInt(thresholdInput.value, 10);
const filterThreshold = isNaN(threshold) || threshold < 0 ? 0 : threshold;
let hasVisibleItems = false;
validKeyItems.forEach((item) => {
// Only process li elements that carry data-fail-count
if (item.dataset.failCount !== undefined) {
const failCount = parseInt(item.dataset.failCount, 10);
// Show the item if its failure count is at or above the threshold; otherwise hide it
if (failCount >= filterThreshold) {
item.style.display = "flex"; // Use flex because the li is now a flex container
hasVisibleItems = true;
} else {
item.style.display = "none"; // Hide
// When an item is hidden, also clear its checked state
const checkbox = item.querySelector(".key-checkbox");
if (checkbox && checkbox.checked) {
checkbox.checked = false;
}
}
}
});
// Update the batch-action state and the select-all checkbox for valid keys
updateBatchActions("valid");
// Handle the "no valid keys" message
const noMatchMsgId = "no-valid-keys-msg";
let noMatchMsg = validKeysList.querySelector(`#${noMatchMsgId}`);
const initialKeyCount = validKeysList.querySelectorAll("li[data-key]").length; // Initial number of keys
if (!hasVisibleItems && initialKeyCount > 0) {
// Show only when keys exist but none of them is currently visible
if (!noMatchMsg) {
noMatchMsg = document.createElement("li");
noMatchMsg.id = noMatchMsgId;
noMatchMsg.className = "text-center text-gray-500 py-4 col-span-full";
noMatchMsg.textContent = "没有符合条件的有效密钥";
validKeysList.appendChild(noMatchMsg);
}
noMatchMsg.style.display = "";
} else if (noMatchMsg) {
noMatchMsg.style.display = "none";
}
// This function is now handled by filterAndSearchValidKeys
// Kept for now to avoid breaking any potential legacy calls, but should be removed later.
filterAndSearchValidKeys();
}
// --- Initialization Helper Functions ---

File diff suppressed because it is too large


@@ -1245,6 +1245,7 @@ endblock %} {% block head_extra_styles %}
<option value="20">20</option>
<option value="50">50</option>
<option value="100">100</option>
<option value="500">500</option>
</select>
<span class="text-sm select-none font-semibold" style="color: #1f2937 !important;"></span>
</div>


@@ -261,18 +261,20 @@ class PicGoUploader(ImageUploader):
class CloudFlareImgBedUploader(ImageUploader):
"""CloudFlare image host uploader."""
def __init__(self, auth_code: str, api_url: str):
def __init__(self, auth_code: str, api_url: str, upload_folder: str = ""):
"""
Initialize the CloudFlare image host uploader.
Args:
auth_code: authentication code
api_url: upload API URL
upload_folder: upload folder path (optional)
"""
self.auth_code = auth_code
self.api_url = api_url
self.upload_folder = upload_folder
def upload(self, file: bytes, filename: str) -> UploadResponse:
"""
Upload an image to the CloudFlare image host.
@@ -288,12 +290,16 @@ class CloudFlareImgBedUploader(ImageUploader):
UploadError: raised when the upload fails
"""
try:
# Prepare the request URL (append the auth code parameter when present)
# Prepare the request URL parameters
params = []
if self.upload_folder:
params.append(f"uploadFolder={self.upload_folder}")
if self.auth_code:
request_url = f"{self.api_url}?authCode={self.auth_code}&uploadNameType=origin"
else:
request_url = f"{self.api_url}?uploadNameType=origin"
params.append(f"authCode={self.auth_code}")
params.append("uploadNameType=origin")
request_url = f"{self.api_url}?{'&'.join(params)}"
# Prepare the file data
files = {
"file": (filename, file)
@@ -388,6 +394,7 @@ class ImageUploaderFactory:
elif provider == "cloudflare_imgbed":
return CloudFlareImgBedUploader(
credentials["auth_code"],
credentials["base_url"]
credentials["base_url"],
credentials.get("upload_folder", ""),
)
raise ValueError(f"Unknown provider: {provider}")
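
The new query-string assembly joins raw values with `&`, so an `upload_folder` or `auth_code` containing spaces, `&`, or `#` would produce a malformed URL. A possible hardening is sketched below with `urllib.parse.urlencode`; the function name `build_upload_url` and the example host are hypothetical, and whether the image host expects an encoded `/` in `uploadFolder` is an assumption that should be checked against that service.

```python
from urllib.parse import urlencode


def build_upload_url(api_url: str, auth_code: str = "", upload_folder: str = "") -> str:
    """Build the upload URL with percent-encoded query parameters."""
    params = {}
    if upload_folder:
        params["uploadFolder"] = upload_folder
    if auth_code:
        params["authCode"] = auth_code
    # Always keep the original file name, as in the uploader above.
    params["uploadNameType"] = "origin"
    return f"{api_url}?{urlencode(params)}"
```

Parameter order matches the diff (folder, auth code, name type), so for values without special characters the output is identical to the string-joined version.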